diff --git a/conda/environments/all_cuda-122_arch-aarch64.yaml b/conda/environments/all_cuda-122_arch-aarch64.yaml index 7b50ceb0a..3c5319160 100644 --- a/conda/environments/all_cuda-122_arch-aarch64.yaml +++ b/conda/environments/all_cuda-122_arch-aarch64.yaml @@ -32,6 +32,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev +- libnvjitlink - make - nccl>=2.9.9 - ninja diff --git a/conda/environments/all_cuda-122_arch-x86_64.yaml b/conda/environments/all_cuda-122_arch-x86_64.yaml index 6c933e193..4e84731c7 100644 --- a/conda/environments/all_cuda-122_arch-x86_64.yaml +++ b/conda/environments/all_cuda-122_arch-x86_64.yaml @@ -32,6 +32,7 @@ dependencies: - libcurand-dev - libcusolver-dev - libcusparse-dev +- libnvjitlink - make - nccl>=2.9.9 - ninja diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index eb9de7a9d..8556e2941 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -188,6 +188,35 @@ include(cmake/thirdparty/get_cutlass.cmake) add_library( cuvs SHARED + src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_rbf.cu + src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu + src/distance/distance.cu + src/distance/pairwise_distance.cu src/neighbors/brute_force_index.cu src/neighbors/brute_force.cu src/neighbors/cagra_build_float.cpp @@ -240,16 +269,13 @@ target_include_directories( if(NOT BUILD_CPU_ONLY) # Keep cuVS as lightweight as possible. 
Only CUDA libs and rmm should be used in global target. - target_link_libraries(cuvs - PUBLIC - rmm::rmm - $<$>:raft::raft> - $<$>:raft::compiled> - PRIVATE - $<$:raft::raft> - $<$:raft::compiled_static> - nvidia::cutlass::cutlass - ) + target_link_libraries( + cuvs + PUBLIC rmm::rmm $<$>:raft::raft> + $<$>:raft::compiled> + PRIVATE $<$:raft::raft> + $<$:raft::compiled_static> nvidia::cutlass::cutlass + ) endif() # Endian detection @@ -304,14 +330,14 @@ endif() set_target_properties( cuvs - PROPERTIES BUILD_RPATH "\$ORIGIN" - INSTALL_RPATH "\$ORIGIN" - CXX_STANDARD 17 - CXX_STANDARD_REQUIRED ON - CUDA_STANDARD 17 - CUDA_STANDARD_REQUIRED ON + PROPERTIES BUILD_RPATH "\$ORIGIN" + INSTALL_RPATH "\$ORIGIN" + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON INTERFACE_POSITION_INDEPENDENT_CODE ON - POSITION_INDEPENDENT_CODE ON + POSITION_INDEPENDENT_CODE ON ) target_compile_options( @@ -325,25 +351,21 @@ target_link_options(cuvs PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") # * cuvs_c ------------------------------------------------------------------------------- if(BUILD_C_LIBRARY) add_library( - cuvs_c SHARED - src/core/c_api.cpp - src/neighbors/brute_force_c.cpp - src/neighbors/ivf_flat_c.cpp - src/neighbors/ivf_pq_c.cpp - src/neighbors/cagra_c.cpp + cuvs_c SHARED src/core/c_api.cpp src/neighbors/brute_force_c.cpp src/neighbors/ivf_flat_c.cpp + src/neighbors/ivf_pq_c.cpp src/neighbors/cagra_c.cpp ) add_library(cuvs::c_api ALIAS cuvs_c) set_target_properties( cuvs_c - PROPERTIES BUILD_RPATH "\$ORIGIN" - INSTALL_RPATH "\$ORIGIN" - CXX_STANDARD 17 - CXX_STANDARD_REQUIRED ON - POSITION_INDEPENDENT_CODE ON + PROPERTIES BUILD_RPATH "\$ORIGIN" + INSTALL_RPATH "\$ORIGIN" + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + POSITION_INDEPENDENT_CODE ON INTERFACE_POSITION_INDEPENDENT_CODE ON - EXPORT_NAME c_api + EXPORT_NAME c_api ) target_compile_options(cuvs_c PRIVATE "$<$:${CUVS_CXX_FLAGS}>") @@ -354,12 +376,11 @@ if(BUILD_C_LIBRARY) INTERFACE "$" ) - target_link_libraries(cuvs_c - PUBLIC - cuvs::cuvs - PRIVATE - $<$:raft::raft> - ) + target_link_libraries( + cuvs_c + PUBLIC cuvs::cuvs + PRIVATE $<$:raft::raft> + ) # ensure CUDA symbols aren't relocated to the middle of the debug build binaries target_link_options(cuvs_c PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") diff --git a/cpp/include/cuvs/distance/distance_types.h b/cpp/include/cuvs/distance/distance_types.h index 6cc2a993b..550221e8e 100644 --- a/cpp/include/cuvs/distance/distance_types.h +++ b/cpp/include/cuvs/distance/distance_types.h @@ -20,7 +20,7 @@ extern "C" { #endif /** enum to tell how to compute distance */ -enum DistanceType { +typedef enum { /** evaluate as dist_ij = sum(x_ik^2) + sum(y_ij)^2 - 2*sum(x_ik * y_jk) */ L2Expanded = 0, @@ -64,7 +64,7 @@ enum DistanceType { DiceExpanded = 19, /** Precomputed (special value) **/ Precomputed = 100 -}; +} cuvsDistanceType; #ifdef __cplusplus } diff --git a/cpp/include/cuvs/distance/distance_types.hpp b/cpp/include/cuvs/distance/distance_types.hpp index 0b2fa4c26..7b1864647 100644 --- a/cpp/include/cuvs/distance/distance_types.hpp +++ b/cpp/include/cuvs/distance/distance_types.hpp @@ -14,56 +14,12 @@ * limitations under the License. 
*/ +#include "distance_types.h" #pragma once namespace cuvs::distance { -/** enum to tell how to compute distance */ -enum DistanceType : unsigned short { - - /** evaluate as dist_ij = sum(x_ik^2) + sum(y_ij)^2 - 2*sum(x_ik * y_jk) */ - L2Expanded = 0, - /** same as above, but inside the epilogue, perform square root operation */ - L2SqrtExpanded = 1, - /** cosine distance */ - CosineExpanded = 2, - /** L1 distance */ - L1 = 3, - /** evaluate as dist_ij += (x_ik - y-jk)^2 */ - L2Unexpanded = 4, - /** same as above, but inside the epilogue, perform square root operation */ - L2SqrtUnexpanded = 5, - /** basic inner product **/ - InnerProduct = 6, - /** Chebyshev (Linf) distance **/ - Linf = 7, - /** Canberra distance **/ - Canberra = 8, - /** Generalized Minkowski distance **/ - LpUnexpanded = 9, - /** Correlation distance **/ - CorrelationExpanded = 10, - /** Jaccard distance **/ - JaccardExpanded = 11, - /** Hellinger distance **/ - HellingerExpanded = 12, - /** Haversine distance **/ - Haversine = 13, - /** Bray-Curtis distance **/ - BrayCurtis = 14, - /** Jensen-Shannon distance**/ - JensenShannon = 15, - /** Hamming distance **/ - HammingUnexpanded = 16, - /** KLDivergence **/ - KLDivergence = 17, - /** RusselRao **/ - RusselRaoExpanded = 18, - /** Dice-Sorensen distance **/ - DiceExpanded = 19, - /** Precomputed (special value) **/ - Precomputed = 100 -}; +using DistanceType = cuvsDistanceType; /** * Whether minimal distance corresponds to similar elements (using the given metric). diff --git a/cpp/include/cuvs/distance/pairwise_distance.hpp b/cpp/include/cuvs/distance/pairwise_distance.hpp new file mode 100644 index 000000000..27d9af6c1 --- /dev/null +++ b/cpp/include/cuvs/distance/pairwise_distance.hpp @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +namespace cuvs::distance { + +/** + * @defgroup pairwise_distance Pairwise Distances API + * @{ + */ + +/** + * @brief Compute pairwise distances for two matrices + * + * Note: Only contiguous row- or column-major layouts supported currently. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * + * raft::raft::resources handle; + * int n_samples = 5000; + * int n_features = 50; + * + * auto input = raft::make_device_matrix(handle, n_samples, n_features); + * + * // ... fill input with data ... 
+ * + * auto output = raft::make_device_matrix(handle, n_samples, n_samples); + * + * auto metric = cuvs::distance::DistanceType::L2SqrtExpanded; + * cuvs::distance::pairwise_distance(handle, + * raft::make_const(input.view()), + * raft::make_const(input.view()), + * output.view(), + * metric); + * @endcode + * + * @param[in] handle raft handle for managing expensive resources + * @param[in] x first set of points (size n*k) + * @param[in] y second set of points (size m*k) + * @param[out] dist output distance matrix (size n*m) + * @param[in] metric distance to evaluate + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ +void pairwise_distance( + raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg = 2.0f); + +/** + * @brief Compute pairwise distances for two matrices + * + * Note: Only contiguous row- or column-major layouts supported currently. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * + * raft::raft::resources handle; + * int n_samples = 5000; + * int n_features = 50; + * + * auto input = raft::make_device_matrix(handle, n_samples, n_features); + * + * // ... fill input with data ... + * + * auto output = raft::make_device_matrix(handle, n_samples, n_samples); + * + * auto metric = cuvs::distance::DistanceType::L2SqrtExpanded; + * cuvs::distance::pairwise_distance(handle, + * raft::make_const(input.view()), + * raft::make_const(input.view()), + * output.view(), + * metric); + * @endcode + * + * @param[in] handle raft handle for managing expensive resources + * @param[in] x first set of points (size n*k) + * @param[in] y second set of points (size m*k) + * @param[out] dist output distance matrix (size n*m) + * @param[in] metric distance to evaluate + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ +void pairwise_distance( + raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + double metric_arg = 2.0f); + +/** + * @brief Compute pairwise distances for two matrices + * + * Note: Only contiguous row- or column-major layouts supported currently. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * + * raft::raft::resources handle; + * int n_samples = 5000; + * int n_features = 50; + * + * auto input = raft::make_device_matrix(handle, n_samples, n_features); + * + * // ... fill input with data ... 
+ * + * auto output = raft::make_device_matrix(handle, n_samples, n_samples); + * + * auto metric = cuvs::distance::DistanceType::L2SqrtExpanded; + * cuvs::distance::pairwise_distance(handle, + * raft::make_const(input.view()), + * raft::make_const(input.view()), + * output.view(), + * metric); + * @endcode + * + * @param[in] handle raft handle for managing expensive resources + * @param[in] x first set of points (size n*k) + * @param[in] y second set of points (size m*k) + * @param[out] dist output distance matrix (size n*m) + * @param[in] metric distance to evaluate + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ +void pairwise_distance( + raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg = 2.0f); +/** + * @brief Compute pairwise distances for two matrices + * + * Note: Only contiguous row- or column-major layouts supported currently. + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * + * raft::raft::resources handle; + * int n_samples = 5000; + * int n_features = 50; + * + * auto input = raft::make_device_matrix(handle, n_samples, n_features); + * + * // ... fill input with data ... + * + * auto output = raft::make_device_matrix(handle, n_samples, n_samples); + * + * auto metric = cuvs::distance::DistanceType::L2SqrtExpanded; + * cuvs::distance::pairwise_distance(handle, + * raft::make_const(input.view()), + * raft::make_const(input.view()), + * output.view(), + * metric); + * @endcode + * + * @param[in] handle raft handle for managing expensive resources + * @param[in] x first set of points (size n*k) + * @param[in] y second set of points (size m*k) + * @param[out] dist output distance matrix (size n*m) + * @param[in] metric distance to evaluate + * @param[in] metric_arg metric argument (used for Minkowski distance) + */ +void pairwise_distance( + raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + double metric_arg = 2.0f); + +/** @} */ // end group pairwise_distance_runtime + +} // namespace cuvs::distance diff --git a/cpp/include/cuvs/neighbors/brute_force.h b/cpp/include/cuvs/neighbors/brute_force.h index cb7ee2a30..145bb5555 100644 --- a/cpp/include/cuvs/neighbors/brute_force.h +++ b/cpp/include/cuvs/neighbors/brute_force.h @@ -102,7 +102,7 @@ cuvsError_t cuvsBruteForceIndexDestroy(cuvsBruteForceIndex_t index); */ cuvsError_t cuvsBruteForceBuild(cuvsResources_t res, DLManagedTensor* dataset, - enum DistanceType metric, + cuvsDistanceType metric, float metric_arg, cuvsBruteForceIndex_t index); /** diff --git a/cpp/include/cuvs/neighbors/ivf_flat.h b/cpp/include/cuvs/neighbors/ivf_flat.h index 08200ae7d..22c4d361c 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.h +++ b/cpp/include/cuvs/neighbors/ivf_flat.h @@ -36,7 +36,7 @@ extern "C" { */ struct ivfFlatIndexParams { /** Distance type. */ - enum DistanceType metric; + cuvsDistanceType metric; /** The argument used by some distance metrics. */ float metric_arg; /** diff --git a/cpp/include/cuvs/neighbors/ivf_pq.h b/cpp/include/cuvs/neighbors/ivf_pq.h index c1fcaed86..2a8269eb0 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.h +++ b/cpp/include/cuvs/neighbors/ivf_pq.h @@ -45,7 +45,7 @@ enum codebook_gen { // NOLINT */ struct ivfPqIndexParams { /** Distance type. 
*/ - enum DistanceType metric; + cuvsDistanceType metric; /** The argument used by some distance metrics. */ float metric_arg; /** diff --git a/cpp/src/distance/detail/compress_to_bits.cuh b/cpp/src/distance/detail/compress_to_bits.cuh new file mode 100644 index 000000000..9ce47774a --- /dev/null +++ b/cpp/src/distance/detail/compress_to_bits.cuh @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace cuvs::distance::detail { + +/** + * @brief Compress 2D boolean matrix to bitfield + * + * Utility kernel for masked_l2_nn. + * + * @tparam T + * + * @parameter[in] in An `m x n` boolean matrix. Row major. + * @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of + * type T, where T is of size `bits_per_elem` bits. + * Note: the division (`/`) is a ceilDiv. + */ +template ::value>> +RAFT_KERNEL compress_to_bits_kernel( + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + constexpr int bits_per_element = 8 * sizeof(T); + constexpr int tile_dim_m = bits_per_element; + constexpr int nthreads = 128; + constexpr int tile_dim_n = nthreads; // read 128 bools at once = 1 sector + + // Tile in shared memory is transposed + __shared__ bool smem[tile_dim_n][tile_dim_m]; + + const int num_tiles_per_m = raft::ceildiv(in.extent(0), tile_dim_m); + const int num_tiles_per_n = raft::ceildiv(in.extent(1), tile_dim_n); + + for (int lin_tile_idx = blockIdx.x; true; lin_tile_idx += gridDim.x) { + const int tile_idx_n = tile_dim_n * (lin_tile_idx % num_tiles_per_n); + const int tile_idx_m = tile_dim_m * (lin_tile_idx / num_tiles_per_n); + + if (in.extent(0) <= tile_idx_m) { break; } + // Fill shared memory tile + bool reg_buf[tile_dim_m]; +#pragma unroll + for (int i = 0; i < tile_dim_m; ++i) { + const int in_m = tile_idx_m + i; + const int in_n = tile_idx_n + threadIdx.x; + bool in_bounds = in_m < in.extent(0) && in_n < in.extent(1); + reg_buf[i] = in_bounds ? in(in_m, in_n) : false; + smem[threadIdx.x][i] = reg_buf[i]; + } + __syncthreads(); + + // Drain memory tile into single output element out_elem. + T out_elem{0}; +#pragma unroll + for (int j = 0; j < tile_dim_n; ++j) { + if (smem[threadIdx.x][j]) { out_elem |= T(1) << j; } + } + __syncthreads(); + + // Write output. + int out_m = tile_idx_m / bits_per_element; + int out_n = tile_idx_n + threadIdx.x; + + if (out_m < out.extent(0) && out_n < out.extent(1)) { out(out_m, out_n) = out_elem; } + } +} + +/** + * @brief Compress 2D boolean matrix to bitfield + * + * Utility kernel for masked_l2_nn. + * + * @tparam T + * + * @parameter[in] in An `m x n` boolean matrix. Row major. + * @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of + * type T, where T is of size `bits_per_elem` bits. + * Note: the division (`/`) is a ceilDiv. 
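+ *
+ * A minimal usage sketch (illustrative; `in` is assumed to be a populated
+ * row-major `device_matrix_view<const bool, int>`). With T = uint64_t,
+ * `bits_per_elem` is 64, so a 1000 x 500 boolean input compresses to a
+ * ceildiv(1000, 64) = 16 x 500 output:
+ * @code{.cpp}
+ *   auto out = raft::make_device_matrix<uint64_t, int>(
+ *     handle, raft::ceildiv(in.extent(0), 64), in.extent(1));
+ *   compress_to_bits(handle, in, out.view());
+ * @endcode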
+ */ +template ::value>> +void compress_to_bits(raft::resources const& handle, + raft::device_matrix_view in, + raft::device_matrix_view out) +{ + auto stream = resource::get_cuda_stream(handle); + constexpr int bits_per_element = 8 * sizeof(T); + + RAFT_EXPECTS(raft::ceildiv(in.extent(0), bits_per_element) == out.extent(0), + "Number of output rows must be ceildiv(input rows, bits_per_elem)"); + RAFT_EXPECTS(in.extent(1) == out.extent(1), "Number of output columns must equal input columns."); + + const int num_SMs = raft::getMultiProcessorCount(); + int blocks_per_sm = 0; + constexpr int num_threads = 128; + constexpr int dyn_smem_size = 0; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, compress_to_bits_kernel, num_threads, dyn_smem_size)); + + dim3 grid(num_SMs * blocks_per_sm); + dim3 block(128); + compress_to_bits_kernel<<>>(in, out); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +}; // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/distance.cuh b/cpp/src/distance/detail/distance.cuh new file mode 100644 index 000000000..7765bc672 --- /dev/null +++ b/cpp/src/distance/detail/distance.cuh @@ -0,0 +1,815 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "distance_ops/all_ops.cuh" +#include "pairwise_matrix/dispatch.cuh" +#include "pairwise_matrix/dispatch_sm60.cuh" +#include "pairwise_matrix/dispatch_sm80.cuh" +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace cuvs { +namespace distance { +namespace detail { + +/** + * @brief: A tag type for overload resolution based on DistanceType + * + * It is not possible to partially specialize function templates on a single + * parameter. Instead, it is often easier to use a combination of conventional + * method overloading and a parameter with a specific tag type. The following + * type is used to help method overloading based on the DistanceType enum. + */ +template +using distance_tag = std::integral_constant; + +/** + * @brief Implement pairwise_matrix for specific distance + * + * There are multiple overloads for this function, one for each distance type. + * They are implemented below. The documentation of this function serves as + * documentation for all functions. 
The following overloads are defined: + * + * - DistanceType::Canberra: + * - DistanceType::CorrelationExpanded: + * - DistanceType::CosineExpanded: + * - DistanceType::HammingUnexpanded: + * - DistanceType::HellingerExpanded: + * - DistanceType::JensenShannon: + * - DistanceType::KLDivergence: + * - DistanceType::L1: + * - DistanceType::L2Expanded: + * - DistanceType::L2SqrtExpanded: + * - DistanceType::L2Unexpanded: + * - DistanceType::L2SqrtUnexpanded: + * - DistanceType::Linf: + * - DistanceType::LpUnexpanded: + * - DistanceType::RusselRaoExpanded: + * + * @tparam DataT Input data type + * @tparam AccT Accumulation data type + * @tparam OutT Output data type + * @tparam FinOpT Type of final operation + * @tparam IdxT Index type + * + * @param handle RAFT resources handle + * @param distance_type A tag type to indicate which distance is calculated. + * @param x First set of points + * @param y Second set of points + * @param out Output distance matrix + * @param m Number of points in x + * @param n Number of points in y + * @param k Dimensionality of points in x, y + * @param workspace Temporary workspace needed for computations + * @param worksize Number of bytes of the workspace + * @param is_row_major Whether the matrices are row-major or col-major + * @param metric_arg The `p` argument for Lp. + */ +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, // unused + size_t worksize, // unused + FinOpT fin_op, + bool is_row_major, + DataT metric_arg) // unused +{ + ops::canberra_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + bool is_row_major, + DataT) // unused +{ + ASSERT(!(worksize < 2 * (m + n) * sizeof(AccT)), "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + AccT* x_norm = workspace; + AccT* y_norm = workspace; + AccT* sq_x_norm = workspace; + AccT* sq_y_norm = workspace; + // TODO: Column major case looks to have lower accuracy for X == Y, + // perhaps the use of stridedSummationKernel could be causing this, + // need to investigate and fix. 
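+ // Workspace layout for the general (x != y) case: [x row sums: m | y row
+ // sums: n | x squared L2 norms: m | y squared L2 norms: n], which accounts
+ // for the 2 * (m + n) * sizeof(AccT) bytes checked above. When x == y and
+ // the input is row major, a single set of std::max(m, n) row sums and
+ // squared norms is computed and shared between both inputs.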
+ if (x == y && is_row_major) { + raft::linalg::reduce(x_norm, + x, + k, + std::max(m, n), + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + sq_x_norm += std::max(m, n); + sq_y_norm = sq_x_norm; + raft::linalg::rowNorm( + sq_x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream); + } else { + y_norm += m; + raft::linalg::reduce(x_norm, + x, + k, + m, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + raft::linalg::reduce(y_norm, + y, + k, + n, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + + sq_x_norm += (m + n); + sq_y_norm = sq_x_norm + m; + raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); + raft::linalg::rowNorm(sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream); + } + + using OpT = ops::correlation_distance_op; + OpT corr_op(is_row_major, sq_x_norm, sq_y_norm, m, n, k); + pairwise_matrix_dispatch( + corr_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + bool is_row_major, + DataT) // unused +{ + // raft distance support inputs as float/double and output as uint8_t/float/double. + static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), + "OutT can be uint8_t, float, double," + "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); + + ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + DataT* x_norm = workspace; + DataT* y_norm = workspace; + // TODO: Column major case looks to have lower accuracy for X == Y, + // perhaps the use of stridedSummationKernel could be causing this, + // need to investigate and fix. 
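+ // Workspace layout: L2 norms of the x rows (m entries) followed by the
+ // y rows (n entries), matching the (m + n) * sizeof(AccT) check above.
+ // When x == y and the input is row major, one set of std::max(m, n) norms
+ // serves both inputs.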
+ if (x == y && is_row_major) { + raft::linalg::rowNorm( + x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + } else { + y_norm += m; + raft::linalg::rowNorm( + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + raft::linalg::rowNorm( + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + } + + ops::cosine_distance_op distance_op{}; + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + bool is_row_major, + DataT) // metric_arg unused +{ + ops::hamming_distance_op distance_op{k}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + bool is_row_major, + DataT) // metric_arg unused +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + raft::linalg::gemm(handle, + out, + const_cast(x), + const_cast(y), + m, + n, + k, + !is_row_major, + !is_row_major, + is_row_major, + stream); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + bool is_row_major, + DataT) // metric_arg unused +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + // First sqrt x and y + const auto raft_sqrt = raft::linalg::unaryOp; + + raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); + if (x != y) { raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } + + // Then calculate Hellinger distance + ops::hellinger_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + + // Finally revert sqrt of x and y + raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); + if (x != y) { raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + bool is_row_major, + DataT) // metric_arg unused +{ + ops::jensen_shannon_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT 
fin_op, + bool is_row_major, + DataT) // metric_arg unused +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + auto unaryOp_lambda = [] __device__(DataT input) { + const bool x_zero = (input == 0); + return (!x_zero) * raft::log(input + x_zero); + }; + + auto unaryOp_lambda_reverse = [] __device__(DataT input) { + // reverse previous log (x) back to x using (e ^ log(x)) + const bool x_zero = (input == 0); + return (!x_zero) * raft::exp(input); + }; + + if (x != y) { + raft::linalg::unaryOp( + (DataT*)y, y, n * k, unaryOp_lambda, stream); + } + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + // This op takes some shortcuts when x equals y. So its behavior changes based + // on this. + ops::kl_divergence_op distance_op{is_row_major, x == y}; + + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + + if (x != y) { + // Now reverse previous log (x) back to x using (e ^ log(x)) + raft::linalg::unaryOp( + (DataT*)y, y, n * k, unaryOp_lambda_reverse, stream); + } +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + bool is_row_major, + DataT) // metric_arg unused +{ + ops::l1_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl_l2_expanded( // NOTE: different name + bool perform_sqrt, // dispatch on sqrt + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) +{ + // raft distance support inputs as float/double and output as uint8_t/float/double. + static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), + "OutT can be uint8_t, float, double," + "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); + + ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + DataT* x_norm = workspace; + DataT* y_norm = workspace; + // TODO: Column major case looks to have lower accuracy for X == Y, + // perhaps the use of stridedSummationKernel could be causing this, + // need to investigate and fix. 
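+ // Same workspace layout as the cosine case above, but the norms stay
+ // squared (raft::identity_op) because the expanded L2 computation consumes
+ // sum(x_ik^2) and sum(y_jk^2) directly.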
+ if ((x == y) && is_row_major) { + raft::linalg::rowNorm(x_norm, + x, + k, + std::max(m, n), + raft::linalg::L2Norm, + is_row_major, + stream, + raft::identity_op{}); + } else { + y_norm += m; + raft::linalg::rowNorm( + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + raft::linalg::rowNorm( + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + } + + ops::l2_exp_distance_op distance_op{perform_sqrt}; + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + bool is_row_major, + DataT) // metric_arg unused +{ + bool perform_sqrt = false; + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + distance_impl_l2_expanded( + perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT* workspace, + size_t worksize, + FinOpT fin_op, + bool is_row_major, + DataT) // metric_arg unused +{ + bool perform_sqrt = true; + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + distance_impl_l2_expanded( + perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + bool is_row_major, + DataT) // metric_arg unused +{ + bool perform_sqrt = false; + ops::l2_unexp_distance_op l2_op(perform_sqrt); + + // The unexpanded L2 does not require the norms of a and b to be calculated. + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + pairwise_matrix_dispatch( + l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + bool is_row_major, + DataT) // metric_arg unused +{ + bool perform_sqrt = true; + ops::l2_unexp_distance_op l2_op(perform_sqrt); + + // The unexpanded L2 does not require the norms of a and b to be calculated. 
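+ // The op accumulates (x_ik - y_jk)^2 for every pair directly; with
+ // perform_sqrt set, the square root is applied in the epilog.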
+ const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + pairwise_matrix_dispatch( + l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + bool is_row_major, + DataT) // metric_arg unused +{ + ops::l_inf_distance_op distance_op{}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + bool is_row_major, + DataT metric_arg) +{ + ops::lp_unexp_distance_op distance_op{metric_arg}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +template +void distance_impl(raft::resources const& handle, + distance_tag distance_type, + const DataT* x, + const DataT* y, + OutT* out, + IdxT m, + IdxT n, + IdxT k, + AccT*, // workspace unused + size_t, // worksize unused + FinOpT fin_op, + bool is_row_major, + DataT) // metric_arg unused +{ + ops::russel_rao_distance_op distance_op{k}; + + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; + + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); +} + +/** + * @brief Evaluate pairwise distances with the user epilogue lamba allowed + * @tparam DistanceType which distance to evaluate + * @tparam InType input argument type + * @tparam AccType accumulation type + * @tparam OutType output type + * @tparam FinalLambda user-defined epilogue lamba + * @tparam Index_ Index type + * + * @param x first set of points + * @param y second set of points + * @param out output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param fin_op the final gemm epilogue lambda + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + * + * @note fin_op: This is a device lambda which is supposed to operate upon the + * input which is AccType and returns the output in OutType. It's signature is + * as follows:
OutType fin_op(AccType in, int g_idx);
. If one needs + * any other parameters, feel free to pass them via closure. + */ +template +void distance(raft::resources const& handle, + const InType* x, + const InType* y, + OutType* out, + Index_ m, + Index_ n, + Index_ k, + void* workspace, + size_t worksize, + FinalLambda fin_op, + bool isRowMajor = true, + InType metric_arg = 2.0f) +{ + // raft distance support inputs as float/double and output as uint8_t/float/double. + static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), + "OutType can be uint8_t, float, double," + "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); + + distance_impl( + handle, + distance_tag{}, + x, + y, + out, + m, + n, + k, + reinterpret_cast(workspace), + worksize, + fin_op, + isRowMajor, + metric_arg); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +/** + * @brief Evaluate pairwise distances for the simple use case + * @tparam DistanceType which distance to evaluate + * @tparam InType input argument type + * @tparam AccType accumulation type + * @tparam OutType output type + * @tparam Index_ Index type + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + */ +template +void distance(raft::resources const& handle, + const InType* x, + const InType* y, + OutType* out, + Index_ m, + Index_ n, + Index_ k, + void* workspace, + size_t worksize, + bool isRowMajor = true, + InType metric_arg = 2.0f) +{ + auto fin_op = raft::identity_op(); + + distance( + handle, x, y, out, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); +} + +/** + * @brief Return the exact workspace size to compute the distance + * @tparam DistanceType which distance to evaluate + * @tparam InType input argument type + * @tparam AccType accumulation type + * @tparam OutType output type + * @tparam Index_ Index type + * @param x first set of points + * @param y second set of points + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * + * @note If the specified distanceType doesn't need the workspace at all, it + * returns 0. + */ +template +size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, Index_ k) +{ + size_t worksize = 0; + constexpr bool is_allocated = (distanceType <= cuvs::distance::DistanceType::CosineExpanded) || + (distanceType == cuvs::distance::DistanceType::CorrelationExpanded); + constexpr int numOfBuffers = + (distanceType == cuvs::distance::DistanceType::CorrelationExpanded) ? 2 : 1; + + if (is_allocated) { + // TODO : when X == Y allocate std::max(m, n) instead of m + n when column major input + // accuracy issue is resolved until then we allocate as m + n. + worksize += numOfBuffers * m * sizeof(AccType); + worksize += numOfBuffers * n * sizeof(AccType); + } + + return worksize; +} + +}; // namespace detail +}; // namespace distance +}; // namespace cuvs diff --git a/cpp/src/distance/detail/distance_ops/all_ops.cuh b/cpp/src/distance/detail/distance_ops/all_ops.cuh new file mode 100644 index 000000000..55989697f --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/all_ops.cuh @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// Defines a named requirement "has_cutlass_op" +#include "../distance_ops/cutlass.cuh" + +// The distance operations: +#include "../distance_ops/canberra.cuh" +#include "../distance_ops/correlation.cuh" +#include "../distance_ops/cosine.cuh" +#include "../distance_ops/hamming.cuh" +#include "../distance_ops/hellinger.cuh" +#include "../distance_ops/jensen_shannon.cuh" +#include "../distance_ops/kl_divergence.cuh" +#include "../distance_ops/l1.cuh" +#include "../distance_ops/l2_exp.cuh" +#include "../distance_ops/l2_unexp.cuh" +#include "../distance_ops/l_inf.cuh" +#include "../distance_ops/lp_unexp.cuh" +#include "../distance_ops/russel_rao.cuh" diff --git a/cpp/src/distance/detail/distance_ops/canberra.cuh b/cpp/src/distance/detail/distance_ops/canberra.cuh new file mode 100644 index 000000000..8bbdc9945 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/canberra.cuh @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::abs +#include // DI + +namespace cuvs::distance::detail::ops { + +/** + * @brief The canberra distance matrix calculation + * + * It computes the following equation: + * + * c_ij = sum_k |x_ik - y_kj| / ( |x_ik| + |y_kj| ) + */ +template +struct canberra_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + // Load norms of input data + static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = true; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. 
+ template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + DI void core(AccT& acc, DataT& x, DataT& y) const + { + const auto diff = raft::abs(x - y); + const auto add = raft::abs(x) + raft::abs(y); + // deal with potential for 0 in denominator by + // forcing 0/1 instead + acc += ((add != 0) * diff / (add + (add == 0))); + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + return; + } +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/correlation.cuh b/cpp/src/distance/detail/distance_ops/correlation.cuh new file mode 100644 index 000000000..f033f3dfa --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/correlation.cuh @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // DI + +namespace cuvs::distance::detail::ops { + +/** @brief The correlation distance + * + * It computes the following equation: + * + * d(x, y) = ((x - mean(x)) ⋅ (y - mean(y))) + * / + * (|| x - mean(x) ||_2 || y - mean(y) ||_2) + */ +template +struct correlation_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + const DataT* x2n; + const DataT* y2n; + IdxT m; + IdxT n; + IdxT k; + + correlation_distance_op( + bool is_row_major, const DataT* x2n_, const DataT* y2n_, IdxT m_, IdxT n_, IdxT k_) noexcept + : x2n(x2n_), y2n(y2n_), m(m_), n(n_), k(k_) + { + // The distance op is typically created before the row-major/col-major + // swapping has been done. So we do it here. + if (!is_row_major) { + std::swap(x2n, y2n); + std::swap(m, n); + } + } + + // Load norms of input data + static constexpr bool use_norms = true; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize + (2 * (Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + } + + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + // Note how we can sneakily get a pointer to shared memory here, to store + // more data. If the implementation of PairwiseDistanceMatKernel ever + // changes, this will be where we find the bugs. 
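+ // Shared memory layout assumed here: Policy::SmemSize bytes used by the
+ // main kernel, then (Mblk + Nblk) DataT of norms (the source of
+ // regxn/regyn), then the Mblk DataT for sx2Norm and Nblk DataT for sy2Norm
+ // carved out below. This is why shared_mem_size() adds
+ // 2 * (Mblk + Nblk) * sizeof(DataT) on top of Policy::SmemSize.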
+ extern __shared__ char smem[]; + + DataT regx2n[Policy::AccRowsPerTh], regy2n[Policy::AccColsPerTh]; + + DataT* sx2Norm = + (DataT*)(&smem[Policy::SmemSize + (Policy::Mblk + Policy::Nblk) * sizeof(DataT)]); + DataT* sy2Norm = (&sx2Norm[Policy::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + if (gridStrideX == blockIdx.x * Policy::Nblk) { + for (int i = threadIdx.x; i < Policy::Mblk; i += Policy::Nthreads) { + auto idx = gridStrideY + i; + sx2Norm[i] = idx < m ? x2n[idx] : 0; + } + } + + for (int i = threadIdx.x; i < Policy::Nblk; i += Policy::Nthreads) { + auto idx = gridStrideX + i; + sy2Norm[i] = idx < n ? y2n[idx] : 0; + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + regx2n[i] = sx2Norm[i * Policy::AccThRows + (threadIdx.x / Policy::AccThCols)]; + } +#pragma unroll + for (int i = 0; i < Policy::AccColsPerTh; ++i) { + regy2n[i] = sy2Norm[i * Policy::AccThCols + (threadIdx.x % Policy::AccThCols)]; + } + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + auto numer = k * acc[i][j] - (regxn[i] * regyn[j]); + auto Q_denom = k * regx2n[i] - (regxn[i] * regxn[i]); + auto R_denom = k * regy2n[j] - (regyn[j] * regyn[j]); + + acc[i][j] = 1 - (numer / raft::sqrt(Q_denom * R_denom)); + } + } + } +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/cosine.cuh b/cpp/src/distance/detail/distance_ops/cosine.cuh new file mode 100644 index 000000000..d48731651 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/cosine.cuh @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // DI + +namespace cuvs::distance::detail::ops { + +// Epilogue operator for CUTLASS based kernel +template +struct cosine_cutlass_op { + __device__ cosine_cutlass_op() noexcept {} + __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept + { + return static_cast(1.0) - static_cast(accVal / (aNorm * bNorm)); + } + __device__ AccT operator()(DataT aData) const noexcept { return aData; } +}; + +/** + * @brief the expanded cosine distance matrix calculation + * + * It computes the following equation: + * + * d(x, y) = 1 - (x ⋅ y) / ( ||x||_2 ||y||_2) + */ +template +struct cosine_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + // Load norms of input data + static constexpr bool use_norms = true; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. 
+ template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + } + + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j])); + } + } + } + + constexpr cosine_cutlass_op get_cutlass_op() const + { + return cosine_cutlass_op(); + } +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/cutlass.cuh b/cpp/src/distance/detail/distance_ops/cutlass.cuh new file mode 100644 index 000000000..6d928314d --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/cutlass.cuh @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // std::false_type +#include // std::declval + +namespace cuvs::distance::detail::ops { + +// This file defines the named requirement "has_cutlass_op" that can be used to +// determine if a distance operation has a CUTLASS op that can be used to pass +// to CUTLASS. Examples of distance operations that satisfy this requirement are +// cosine_distance_op and l2_exp_distance_op. + +// Primary template handles types that do not support CUTLASS. +// This pattern is described in: +// https://en.cppreference.com/w/cpp/types/void_t +template +struct has_cutlass_op : std::false_type {}; + +// Specialization recognizes types that do support CUTLASS +template +struct has_cutlass_op().get_cutlass_op())>> + : std::true_type {}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/hamming.cuh b/cpp/src/distance/detail/distance_ops/hamming.cuh new file mode 100644 index 000000000..7c6553f38 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/hamming.cuh @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include // DI + +namespace cuvs::distance::detail::ops { + +/** + * @brief the Hamming Unexpanded distance matrix calculation + * It computes the following equation: + * + * c_ij = sum_k (x_ik != y_kj) / k + */ +template +struct hamming_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + IdxT k; + + hamming_distance_op(IdxT k_) noexcept : k(k_) {} + + // Load norms of input data + static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += (x != y); }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + const DataT one_over_k = DataT(1.0) / k; +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] *= one_over_k; + } + } + } +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/hellinger.cuh b/cpp/src/distance/detail/distance_ops/hellinger.cuh new file mode 100644 index 000000000..ad5ca3156 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/hellinger.cuh @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include // DI + +namespace cuvs::distance::detail::ops { + +/** + * @brief the Hellinger distance matrix calculation + * + * It computes the following equation: + * + * c_ij = sqrt(1 - sum_k sqrt(x_ik * y_kj)) + * + */ +template +struct hellinger_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + // Load norms of input data + static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + DI void core(AccT& acc, DataT& x, DataT& y) const + { + // This is sqrt(x) * sqrt(y). 
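+ // (the caller, distance_impl for HellingerExpanded, has already replaced
+ // x and y in place with their element-wise square roots before dispatching,
+ // and reverts them afterwards)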
+ const auto product = x * y; + acc += product; + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative + const auto finalVal = (1 - acc[i][j]); + const auto rectifier = (!signbit(finalVal)); + acc[i][j] = raft::sqrt(rectifier * finalVal); + } + } + } +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/src/distance/detail/distance_ops/jensen_shannon.cuh new file mode 100644 index 000000000..216639494 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/jensen_shannon.cuh @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include // raft::log +#include // DI + +namespace cuvs::distance::detail::ops { + +// Describes the computation the jensen_shannon distance + +/** + * @brief the Jensen Shannon distance matrix calculation + * + * It computes the following equation: + * + * c_ij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) + * + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) + */ +template +struct jensen_shannon_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + // Load norms of input data + static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = true; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. 
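+  // The Jensen-Shannon op itself only needs the policy's default amount (Policy::SmemSize).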
+ template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + DI void core(AccT& acc, DataT& x, DataT& y) const + { + const DataT m = 0.5f * (x + y); + const bool m_zero = (m == 0); + const auto logM = (!m_zero) * raft::log(m + m_zero); + + const bool x_zero = (x == 0); + const bool y_zero = (y == 0); + acc += (-x * (logM - raft::log(x + x_zero))) + (-y * (logM - raft::log(y + y_zero))); + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = raft::sqrt(0.5 * acc[i][j]); + } + } + } +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/kl_divergence.cuh b/cpp/src/distance/detail/distance_ops/kl_divergence.cuh new file mode 100644 index 000000000..929c3a559 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/kl_divergence.cuh @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include // raft::log +#include // DI + +namespace cuvs::distance::detail::ops { + +/** + * @brief the KL Divergence distance matrix calculation + * + * It computes the following equation: + * + * c_ij = 0.5 * sum(x * log (x / y)); + */ +template +struct kl_divergence_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + const bool is_row_major; + const bool x_equal_y; + + kl_divergence_op(bool row_major_, bool x_equal_y_ = false) noexcept + : is_row_major(row_major_), x_equal_y(x_equal_y_) + { + } + + // Load norms of input data + static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = true; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + DI void core(AccT& acc, DataT& x, DataT& y) const + { + // TODO: make sure that these branches get hoisted out of main loop.. Could + // be quite expensive otherwise. 
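+    // Zero handling below relies on bool-to-arithmetic conversion: when x == 0, x_zero == 1,
+    // so raft::log(x + x_zero) == log(1) == 0, and the (!y_zero) factor drops the other log
+    // term when y == 0. For example, an element with x == 0 contributes exactly 0 to the
+    // accumulator instead of producing -inf from log(0).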
+ if (x_equal_y) { + if (is_row_major) { + const bool x_zero = (x == 0); + const bool y_zero = (y == 0); + acc += x * (raft::log(x + x_zero) - (!y_zero) * raft::log(y + y_zero)); + } else { + const bool y_zero = (y == 0); + const bool x_zero = (x == 0); + acc += y * (raft::log(y + y_zero) - (!x_zero) * raft::log(x + x_zero)); + } + } else { + if (is_row_major) { + const bool x_zero = (x == 0); + acc += x * (raft::log(x + x_zero) - y); + } else { + const bool y_zero = (y == 0); + acc += y * (raft::log(y + y_zero) - x); + } + } + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = (0.5f * acc[i][j]); + } + } + } +}; +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/l1.cuh b/cpp/src/distance/detail/distance_ops/l1.cuh new file mode 100644 index 000000000..76eaffaf3 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/l1.cuh @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include // DI + +namespace cuvs::distance::detail::ops { + +/** + * @brief the L1 distance matrix calculation + * + * It computes the following equation: + * + * c_ij = sum_k abs(x_ik - y_kj) + */ +template +struct l1_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + // Do not load norms of data, the computation of L1 distance does not use them. + static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += raft::abs(x - y); }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + return; + }; +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/l2_exp.cuh b/cpp/src/distance/detail/distance_ops/l2_exp.cuh new file mode 100644 index 000000000..f45c41206 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/l2_exp.cuh @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include // DI + +namespace cuvs::distance::detail::ops { + +/** + * Reserve 1 digit of precision from each floating-point type + * for round-off error tolerance. + * @tparam DataT + */ +template +__device__ constexpr DataT get_clamp_precision() +{ + switch (sizeof(DataT)) { + case 2: return 1e-3; + case 4: return 1e-6; + case 8: return 1e-15; + default: return 0; + } +} + +// Epilogue operator for CUTLASS based kernel +template +struct l2_exp_cutlass_op { + bool sqrt; + + __device__ l2_exp_cutlass_op() noexcept : sqrt(false) {} + __device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {} + inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept + { + AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; + + /** + * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product (accVal) + * can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead. + */ + outVal = outVal * !((outVal * outVal < get_clamp_precision()) * (aNorm == bNorm)); + return sqrt ? raft::sqrt(outVal * (outVal > 0)) : outVal; + } + + __device__ AccT operator()(DataT aData) const noexcept { return aData; } +}; + +/** + * @brief the expanded euclidean distance matrix calculation + * + * It computes the following equation: + * + * c_ij = - 2 sum_k x_ik * y_kj + ||x_i.||_2 + ||y_.j||_2 + * + */ +template +struct l2_exp_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + const bool sqrt; + + l2_exp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} + + // Load norms of input data + static constexpr bool use_norms = true; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); + } + + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + DataT accVal = acc[i][j]; + DataT val = regxn[i] + regyn[j] - (DataT)2.0 * accVal; + + /** + * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product + * (accVal) can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal + * instead. 
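+           * The (val > 0) factor and the negated predicate below force such round-off
+           * residue to exactly zero whenever the two norms match and val * val falls below
+           * the type-dependent tolerance of get_clamp_precision (e.g. 1e-6 for float,
+           * 1e-15 for double), so a point's distance to itself comes out as 0.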
+ */ + acc[i][j] = + val * (val > 0) * !((val * val < get_clamp_precision()) * (regxn[i] == regyn[j])); + } + } + if (sqrt) { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = raft::sqrt(acc[i][j]); + } + } + } + } + + constexpr l2_exp_cutlass_op get_cutlass_op() const + { + return l2_exp_cutlass_op(sqrt); + } +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/l2_unexp.cuh b/cpp/src/distance/detail/distance_ops/l2_unexp.cuh new file mode 100644 index 000000000..aa6cc27f3 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/l2_unexp.cuh @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // DI + +namespace cuvs::distance::detail::ops { + +/** + * @brief the unexpanded euclidean distance matrix calculation + * + * It computes the following equation: + * + * c_ij = optional_sqrt ( sum_k (x_ik - y_kj)^2 ) + */ +template +struct l2_unexp_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + bool sqrt; + + l2_unexp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} + + // Do not load norms of data, the computation of L1 distance does not use them. + static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + DI void core(AccT& acc, DataT& x, DataT& y) const + { + const auto diff = x - y; + acc += diff * diff; + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + if (sqrt) { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = raft::sqrt(acc[i][j]); + } + } + } + }; +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/l_inf.cuh b/cpp/src/distance/detail/distance_ops/l_inf.cuh new file mode 100644 index 000000000..d8f9384d7 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/l_inf.cuh @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // DI + +namespace cuvs::distance::detail::ops { + +/** + * @brief the L_inf (Chebyshev) distance matrix calculation + * + * It computes the following equation: + * + * c_ij = max_k | x_ik - y_kj | + */ +template +struct l_inf_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + // Load norms of input data + static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + DI void core(AccT& acc, DataT& x, DataT& y) const + { + const auto diff = raft::abs(x - y); + acc = raft::max(acc, diff); + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + return; + } +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/lp_unexp.cuh b/cpp/src/distance/detail/distance_ops/lp_unexp.cuh new file mode 100644 index 000000000..6136f9f3e --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/lp_unexp.cuh @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include // raft::pow, raft::abs +#include // DI + +namespace cuvs::distance::detail::ops { + +/** + * @brief the unexpanded Lp (Minkowski) distance matrix calculation + * + * It computes the following equation: + * + * c_ij = (sum_k |x_ik - y_jk|^p)^(1/p) + */ +template +struct lp_unexp_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + DataT p; + + lp_unexp_distance_op(DataT p_) noexcept : p(p_) {} + + // Load norms of input data + static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = true; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. 
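+  // The Minkowski (Lp) op itself only needs the policy's default amount (Policy::SmemSize).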
+ template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + DI void core(AccT& acc, DataT& x, DataT& y) const + { + const auto diff = raft::abs(x - y); + acc += raft::pow(diff, p); + }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + const auto one_over_p = 1.0f / p; +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = raft::pow(acc[i][j], one_over_p); + } + } + } +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/russel_rao.cuh b/cpp/src/distance/detail/distance_ops/russel_rao.cuh new file mode 100644 index 000000000..5dffdcdb8 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/russel_rao.cuh @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // DI + +namespace cuvs::distance::detail::ops { + +/** + * @brief the Russell Rao distance matrix calculation + * + * It computes the following equation: + * + * c_ij = (k - (sum_k x_ik * y_kj)) / k + */ +template +struct russel_rao_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + IdxT k; + const float one_over_k; + + russel_rao_distance_op(IdxT k_) noexcept : k(k_), one_over_k(1.0f / k_) {} + + // Load norms of input data + static constexpr bool use_norms = false; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize; + } + + DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + acc[i][j] = (k - acc[i][j]) * one_over_k; + } + } + } +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/distance_ops/template.cuh b/cpp/src/distance/detail/distance_ops/template.cuh new file mode 100644 index 000000000..bdb933237 --- /dev/null +++ b/cpp/src/distance/detail/distance_ops/template.cuh @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // DI + +namespace cuvs::distance::detail::ops { + +// Describes the computation the template distance +// +// Fill in the TODO items. + +template +struct template_distance_op { + using DataT = DataType; + using AccT = AccType; + using IdxT = IdxType; + + TODO member; + + template_distance_op(TODO member_) noexcept : member(member_) {} + + // Load norms of input data + static constexpr bool use_norms = TODO; + // Whether the core function requires so many instructions that it makes sense + // to reduce loop unrolling, etc. We do this to keep compile times in check. + static constexpr bool expensive_inner_loop = false; + + // Size of shared memory. This is normally decided by the kernel policy, but + // some ops such as correlation_distance_op use more. + template + static constexpr size_t shared_mem_size() + { + return Policy::SmemSize + TODO; + } + + DI void core(AccT& acc, DataT& x, DataT& y) const { TODO; }; + + template + DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT* regxn, + DataT* regyn, + IdxT gridStrideX, + IdxT gridStrideY) const + { + TODO; + } + + // If exist, returns a cutlass op that performs the same operation. + // See cosine and l2_exp distance ops for an example. + constexpr l2_exp_cutlass_op get_cutlass_op() const { TODO; } +}; + +} // namespace cuvs::distance::detail::ops diff --git a/cpp/src/distance/detail/fused_distance_nn.cuh b/cpp/src/distance/detail/fused_distance_nn.cuh new file mode 100644 index 000000000..1bf8793fd --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn.cuh @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "distance_ops/l2_exp.cuh" // ops::l2_exp_distance_op +#include "fused_distance_nn/cutlass_base.cuh" +#include "fused_distance_nn/fused_cosine_nn.cuh" +#include "fused_distance_nn/fused_l2_nn.cuh" +#include "fused_distance_nn/helper_structs.cuh" +#include "fused_distance_nn/simt_kernel.cuh" +#include "pairwise_distance_base.cuh" // PairwiseDistances +#include +#include // raft::KeyValuePair +#include // raft::identity_op +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace cuvs { +namespace distance { + +namespace detail { + +template +void fusedDistanceNNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + bool isRowMajor, + cuvs::distance::DistanceType metric, + float metric_arg, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedDistanceNN. + typedef Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef raft::KeyValuePair KVPair; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + switch (metric) { + case cuvs::distance::DistanceType::CosineExpanded: + fusedCosineNN( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); + break; + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Expanded: + // initOutBuffer is take care by fusedDistanceNNImpl() so we set it false to fusedL2NNImpl. + fusedL2NNImpl( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream); + break; + default: assert("only cosine/l2 metric is supported with fusedDistanceNN\n"); break; + } +} + +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/src/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h new file mode 100644 index 000000000..186715851 --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -0,0 +1,668 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +This file contains a customized version of EpilogueWithBroadcast from CUTLASS 2.9.1 +(https://github.com/NVIDIA/cutlass/blob/v2.9.1/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h) + +Changes: +- customized the compute_source_needed_() and apply_output_operator_() to suit the needs of per row +reduction +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#include +#else +#include + +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This base class is meant to define the concept required of the +/// EpilogueWithBroadcast::OutputOp +template +struct EpilogueWithBroadcastOpBaseCustom { + using ElementOutput = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementZ = ElementZ_; + using ElementT = ElementT_; + static int const kElementsPerAccess = ElementsPerAccess; + + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentZ = Array; + using FragmentT = Array; + + /// If true, the 'Z' tensor is stored + static bool const kStoreZ = StoreZ; + + /// If true, the 'T' tensor is stored + static bool const kStoreT = StoreT; + + /// Parameters structure - required + struct Params {}; + + // + // Methods + // + + /// Constructor from Params + EpilogueWithBroadcastOpBaseCustom(Params const& params_) {} + + /// Determine if the source is needed. May return false if + bool is_source_needed() const { return true; } + + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentC const& frag_C, + FragmentCompute const& V) const + { + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentCompute const& V) const + { + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator with bias vector broadcast over columns. 
+/// +/// Computes the following: +/// +/// +/// Z, T = OutputOp(AB, C, Broadcast) +/// +/// if (ElementwiseOp::kStoreZ) { +/// store(converted_u); +/// } +/// +/// if (ElementwiseOp::kStoreT) { +/// store(v); +/// } +/// +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z) + typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t) + typename ElementVector_, ///< Pointer to broadcast vector + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: + ///< MatrixShape) + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value)> +class EpilogueWithBroadcastCustom : public EpilogueBase { + public: + using Base = EpilogueBase; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using TensorTileIterator = TensorTileIterator_; + using ElementVector = ElementVector_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Compute data type produced by the output op + using ElementCompute = typename OutputOp::ElementCompute; + + /// Compute fragment + using FragmentCompute = Array; + + /// Thread map used by output tile iterators + using ThreadMap = typename OutputTileIterator::ThreadMap; + + /// Fragment object used to store the broadcast values + using BroadcastFragment = + Array; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Data type of additional tensor + using ElementTensor = typename TensorTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = + Array; + + /// Array type used by output functor + using AccumulatorAccessType = + Array; + + /// Array type used by output functor + using ComputeAccessType = Array; + + /// Tensor access type + using TensorAccessType = Array; + + /// Number of warps + using WarpCount = typename 
Base::WarpCount; + + /// Shared memory allocation from epilogue base class + using BaseSharedStorage = typename Base::SharedStorage; + + static int constexpr kSmemTiles = + Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + /// Used for the broadcast + struct BroadcastDetail { + /// Number of threads per warp + static int const kWarpSize = 32; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar column indices handled by each thread + static int const kColumnsPerThread = + ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar row indices handled by each thread + static int const kRowsPerThread = + ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; + + /// Number of threads per threadblock + static int const kThreadCount = kWarpSize * WarpCount::kCount; + + /// Number of distinct threads per row of output tile + static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); + + /// Number of distinct threads which must be reduced during the final reduction phase within the + /// threadblock. + static int const kThreadRows = kThreadCount / kThreadsPerRow; + + /// I'm not sure what I meant here. + static int const kThreadAccessesPerRow = + const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = MatrixShape; + + /// Debug printing + CUTLASS_DEVICE + static void print() + { +#if 0 + printf("BroadcastDetail {\n"); + printf( + " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" + "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", + kColumnsPerThread, + kRowsPerThread, + kThreadCount, + kThreadsPerRow, + kThreadRows, + kThreadAccessesPerRow, + StorageShape::kRow, + StorageShape::kColumn, + StorageShape::kCount + ); + printf("};\n"); +#endif + } + }; + + /// Shared storage structure (shadows base) with additional SMEM buffer for reduction + struct SharedStorage { + union { + BaseSharedStorage base; + }; + + CUTLASS_HOST_DEVICE + SharedStorage() {} + }; + + public: + static_assert(SharedLoadIterator::Fragment::kElements == TensorTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + /// Thread index within the threadblock + int thread_idx_; + + public: + /// Constructor + CUTLASS_DEVICE + EpilogueWithBroadcastCustom(SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage.base, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.base.reference(), thread_idx), + thread_idx_(thread_idx) + { + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + ElementVector const* broadcast_ptr, ///< Broadcast vector + 
AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix + TensorTileIterator + tensor_iterator, ///< Threadblock tile iterator for additional tensor operand + MatrixCoord const& + problem_size = ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord(Shape::kM, Shape::kN), + MatrixCoord const& + threadblock_offset = ///< Threadblock's initial offset within the problem size space + MatrixCoord()) + { + BroadcastFragment broadcast_fragment; + + load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); + + compute_source_needed_( + output_op, broadcast_fragment, accumulators, source_iterator, tensor_iterator); + } + + private: + CUTLASS_DEVICE + void load_broadcast_fragment_( + BroadcastFragment& + broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + ElementVector const* broadcast_ptr, ///< Broadcast vector + MatrixCoord const& + problem_size, ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord const& + threadblock_offset ///< Threadblock's initial offset within the problem size space + ) + { + broadcast_fragment.clear(); + + // If no pointer is supplied, set with all zeros and avoid memory accesses + if (!broadcast_ptr) { return; } + + int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); + + int thread_column_idx = threadblock_offset.column() + thread_initial_column; + broadcast_ptr += thread_initial_column; + + NumericArrayConverter + converter; + using AccessType = AlignedArray; + using ComputeFragmentType = Array; + + ComputeFragmentType* frag_ptr = reinterpret_cast(&broadcast_fragment); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { + AccessType loaded; + + loaded.clear(); + + if (thread_column_idx < problem_size.column()) { + loaded = *reinterpret_cast(broadcast_ptr); + } + + ComputeFragmentType cvt = converter(loaded); + frag_ptr[j] = cvt; + + thread_column_idx += ThreadMap::Delta::kColumn; + broadcast_ptr += ThreadMap::Delta::kColumn; + } + } + + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * + (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) + { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + BroadcastFragment const& + 
      broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additional tensor operand + ) + { + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) + { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + BroadcastFragment const& + broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator + source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additional tensor operand + ) + { + typename OutputTileIterator::Fragment source_fragment; + source_fragment.clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ?
OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + // + // Convert and store fragment + // + + //__syncthreads(); + + acc2smem_source_needed>::push( + iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // + // Apply output operation + // + + typename TensorTileIterator::Fragment frag_T; + + // + // Load the source + // + + source_iterator.load(source_fragment); + ++source_iterator; + + apply_output_operator_( + frag_T, output_op, aligned_accum_fragment[0], source_fragment, broadcast_fragment); + + // + // Conditionally store fragments + // + if (OutputOp::kStoreT) { + tensor_iterator.store(frag_T); + ++tensor_iterator; + } + } + tensor_iterator.dumpToGmem(); + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_(typename TensorTileIterator::Fragment& frag_T, + OutputOp const& output_op, + typename SharedLoadIterator::Fragment const& frag_AB, + typename OutputTileIterator::Fragment const& frag_C, + BroadcastFragment const& frag_Broadcast) + { + using AccessTypeT = Array; + using AccessTypeBroadcast = Array; + + AccessTypeT* frag_T_ptr = reinterpret_cast(&frag_T); + + AccumulatorAccessType const* frag_AB_ptr = + reinterpret_cast(&frag_AB); + + OutputAccessType const* frag_C_ptr = reinterpret_cast(&frag_C); + + AccessTypeBroadcast const* frag_Broadcast_ptr = + reinterpret_cast(&frag_Broadcast); + + int const kOutputOpIterations = + TensorTileIterator::Fragment::kElements / TensorTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + output_op(frag_T_ptr[i], + frag_AB_ptr[i], + frag_C_ptr[(i / ThreadMap::Iterations::kColumn)], + frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + typename OutputTileIterator::Fragment& frag_Z, + typename TensorTileIterator::Fragment& frag_T, + OutputOp const& output_op, + typename SharedLoadIterator::Fragment const& frag_AB, + BroadcastFragment const& frag_Broadcast) + { + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/src/distance/detail/fused_distance_nn/cutlass_base.cuh b/cpp/src/distance/detail/fused_distance_nn/cutlass_base.cuh new file mode 100644 index 000000000..f58c48aff --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/cutlass_base.cuh @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wtautological-compare" + +// We define CUTLASS_NAMESPACE in case +// RAFT cmake is not used +#ifndef CUTLASS_NAMESPACE +#define cutlass raft_cutlass +#endif + +#include "epilogue_elementwise.cuh" // FusedDistanceNNEpilogueElementwise +#include "gemm.h" // FusedDistanceNNGemm +#include // getMultiProcessorCount +#include // RAFT_CUTLASS_TRY + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cuvs { +namespace distance { +namespace detail { + +template +RAFT_KERNEL initBinMutexKernel(cuda::binary_semaphore* mut, IdxT m) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + + if (tid < m) { mut[tid].release(); } +} + +template +void cutlassFusedDistanceNN(const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + int* mutexes, + CGReduceOpT cg_reduce_op, + DistanceFn dist_op, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + cudaStream_t stream) +{ + using EpilogueOutputOp = cutlass::epilogue::thread::FusedDistanceNNEpilogueElementwise< + DataT, // ElementC_ + AccT, // ElementAccumulator_ + DataT, // ElementCompute_ + AccT, // ElementZ_ + OutT, // ElementT_ + // 128 / cutlass::sizeof_bits::value, + 1, // Elements per access 1 + DistanceFn, + CGReduceOpT, + ReduceOpT, + KVPReduceOpT>; + constexpr int batch_count = 1; + + rmm::device_uvector> bin_mutex(m, stream); + + int blks_ = (m / 256) + 1; + + initBinMutexKernel<<>>(bin_mutex.data(), m); + + typename EpilogueOutputOp::Params epilog_op_param( + dist_op, cg_reduce_op, redOp, pairRedOp, mutexes, bin_mutex.data()); + + // Number of pipelines you want to use + constexpr int NumStages = 3; + // Alignment + constexpr int Alignment = VecLen; + + // default initialize problem size with row major inputs + auto problem_size = cutlass::gemm::GemmCoord(m, n, k); + + constexpr bool isRowMajor = true; + + using fusedDistanceNNKernel = + typename cutlass::gemm::kernel::FusedDistanceNNGemm::GemmKernel; + + using fusedDistanceNN = cutlass::gemm::device::GemmGrouped; + + int num_blocks_per_sm = fusedDistanceNN::maximum_active_blocks(); + int num_sms = raft::getMultiProcessorCount(); + int full_wave = num_blocks_per_sm * num_sms; + constexpr int mmaShapeM = fusedDistanceNNKernel::Mma::Shape::kM; + constexpr int mmaShapeN = fusedDistanceNNKernel::Mma::Shape::kN; + int columnTiles = (problem_size.n() - 1 + mmaShapeN) / mmaShapeN; + int rowTiles = (problem_size.m() - 1 + mmaShapeM) / mmaShapeM; + int totalTiles = columnTiles * rowTiles; + int thread_blocks = + rowTiles < full_wave ? (totalTiles < full_wave ? totalTiles : full_wave) : rowTiles; + + typename fusedDistanceNN::Arguments arguments{ + problem_size, + batch_count, // num of problems. 
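+    // number of threadblocks to launch, derived above from the row/column tile counts and
+    // the number of blocks that fit in one wave on the device (full_wave)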
+ thread_blocks, + epilog_op_param, + x, + y, + xn, // C matrix eq vector param, which here is A norm + (DataT*)yn, // this is broadcast vec, which is required to be non-const param + dOutput, // Output distance matrix + (int64_t)lda, // stride A + (int64_t)ldb, // stride B + (int64_t)1, // stride A norm + (int64_t)ldd // stride Output matrix + }; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = fusedDistanceNN::get_workspace_size(arguments); + // Allocate workspace memory + rmm::device_uvector workspace(workspace_size, stream); + // Instantiate CUTLASS kernel depending on templates + fusedDistanceNN fusedDistanceNN_op; + // Check the problem size is supported or not + RAFT_CUTLASS_TRY(fusedDistanceNN_op.can_implement(arguments)); + // Initialize CUTLASS kernel with arguments and workspace pointer + RAFT_CUTLASS_TRY(fusedDistanceNN_op.initialize(arguments, workspace.data(), stream)); + // Launch initialized CUTLASS kernel + RAFT_CUTLASS_TRY(fusedDistanceNN_op.run(stream)); +} + +}; // namespace detail +}; // namespace distance +}; // namespace cuvs + +#pragma GCC diagnostic pop diff --git a/cpp/src/distance/detail/fused_distance_nn/epilogue.cuh b/cpp/src/distance/detail/fused_distance_nn/epilogue.cuh new file mode 100644 index 000000000..06939e2bf --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/epilogue.cuh @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + +This is adapted from DefaultEpilogueWithBroadcastTensorOp from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h#L75) + +This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec +and EpilogueWithBroadcast used for distances L2/cosine as well as applies user-define elementwise +operation. +-- A norm load is provided PredicatedTileIteratorNormVec +-- B norm load is provided by EpilogueWithBroadcast +-- elementwise operation is provided by OutputOp +*/ + +#pragma once + +#include "custom_epilogue_with_broadcast.h" +#include "predicated_tile_iterator_normvec_smem.h" +#include "predicated_tile_iterator_reduced_vec.h" + +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. +template +struct FusedDistanceNNEpilogue { + /// Use defaults related to the existing epilogue + using Base = + DefaultEpilogueTensorOp; + + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using RowNormTileIterator = cutlass::epilogue::threadblock:: + PredicatedTileIteratorNormVecSmem; + + // + // Additional tensor tile iterator - stores t = Elementwise(z) + // + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorReducedVec< + typename Base::OutputTileThreadMap, + ElementTensor, + LayoutT, + typename OutputOp::Params>; + + /// Define the epilogue + using Epilogue = cutlass::epilogue::threadblock::EpilogueWithBroadcastCustom< + Shape, + WarpMmaTensorOp, + PartitionsK, + RowNormTileIterator, + OutputTileIterator, + ElementVector, + typename Base::AccumulatorFragmentIterator, + typename Base::WarpTileIterator, + typename Base::SharedLoadIterator, + OutputOp, + typename Base::Padding, + Base::kFragmentsPerIteration>; +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/src/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/src/distance/detail/fused_distance_nn/epilogue_elementwise.cuh new file mode 100644 index 000000000..e69b2486d --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/epilogue_elementwise.cuh @@ -0,0 +1,220 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +/*! \file + \brief Functor performing distance operations used by epilogues of pairwise distance + * kernels. +* This is adapted from LinearCombinationBiasElementwise from CUTLASS 2.9.0 +* customized for applying elementwise distance formula on accumulated GEMM value +* and applying user-defined operation which can convert distance values to key-value pair. +* . 
+*/ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This base class is meant to define the concept required of the +/// EpilogueWithBroadcast::OutputOp +template +class FusedDistanceNNEpilogueElementwise { + public: + using ElementOutput = ElementC_; + using ElementC = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementZ = ElementZ_; + using ElementT = ElementT_; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + + using DistanceOp = DistanceOp_; + using CGReduceOp = CGReduceOp_; + + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentZ = Array; + using OutValT = typename CGReduceOp::AccTypeT; + using FragmentT = Array; + + using FragmentOutput = FragmentZ; + + static bool const kIsHeavy = true; // ElementwiseOp::kIsHeavy; + + /// If true, the 'Z' tensor is stored + static bool const kStoreZ = false; // We don't store anything in Z, + + /// If true, the 'T' tensor is stored + static bool const kStoreT = true; // this is our final output storage. + + /// Host-constructable parameters structure + struct Params { + CGReduceOp_ cg_reduce_op; + DistanceOp_ dist_op_; + KVPReduceOpT_ pair_redop_; + ReduceOpT_ red_op_; + int* mutexes_; + cuda::binary_semaphore* bin_mutex_; + using CGReduceT = CGReduceOp_; + // + // Methods + // + CUTLASS_HOST_DEVICE + Params(DistanceOp_ dist_op, + CGReduceOp cg_reduce_op, + ReduceOpT_ red_op, + KVPReduceOpT_ pair_redop, + int* mutexes, + cuda::binary_semaphore* bin_mutex) + : cg_reduce_op(cg_reduce_op), + dist_op_(dist_op), + pair_redop_(pair_redop), + red_op_(red_op), + mutexes_(mutexes), + bin_mutex_(bin_mutex) + { + } + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + // + // Data members + // + DistanceOp_ elementwise_op; + KVPReduceOpT_ pair_redop; + + public: + ReduceOpT_ red_op; + + // + // Methods + // + + /// Constructor from Params + CUTLASS_HOST_DEVICE + FusedDistanceNNEpilogueElementwise(Params const& params) + : elementwise_op(params.dist_op_), pair_redop(params.pair_redop_), red_op(params.red_op_) + { + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const + { + // we use for making sure C matrix is used for A mat norm. 
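+    // Returning true forces the epilogue to load the "source" (C) fragment;
+    // in this fused kernel that fragment actually carries the A-row norms
+    // (cutlassFusedDistanceNN passes xn as the C operand), so the elementwise
+    // distance op receives the per-row norm for every element it processes.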
+ return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()(FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentC const& frag_C, + FragmentCompute const& V) const + { + FragmentCompute tmp_Accum = + NumericArrayConverter()(AB); + FragmentCompute tmp_C = + NumericArrayConverter()(frag_C); + FragmentCompute result_Z; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerAccess; ++i) { + ElementCompute res_Z = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); + frag_T[i] = res_Z; + } + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentCompute const& V) const + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/src/distance/detail/fused_distance_nn/fused_cosine_nn.cuh b/cpp/src/distance/detail/fused_distance_nn/fused_cosine_nn.cuh new file mode 100644 index 000000000..bcbc6689c --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/fused_cosine_nn.cuh @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../distance_ops/cosine.cuh" // ops::l2_exp_distance_op +#include "../pairwise_distance_base.cuh" // PairwiseDistances +#include "cutlass_base.cuh" +#include "helper_structs.cuh" +#include "simt_kernel.cuh" +#include // raft::KeyValuePair +#include // raft::identity_op +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace cuvs { +namespace distance { + +namespace detail { + +template +void fusedCosineNN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedL2NN. + typedef Policy P; + + dim3 blk(P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef raft::KeyValuePair KVPair; + + namespace arch = raft::util::arch; + using AccT = DataT; + ops::cosine_distance_op distance_op{}; + + raft::identity_op fin_op{}; + + auto kernel = fusedDistanceNNkernel; + + // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. 
See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using cosineOp = cuvs::distance::detail::ops::cosine_cutlass_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; + cosineOp cosine_dist_op; + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedDistanceNN(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + cosine_dist_op, + redOp, + pairRedOp, + stream); + } else { + // If device less than SM_80, use fp32 SIMT kernel. + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + RAFT_CUDA_TRY(cudaGetLastError()); + } +} + +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/fused_distance_nn/fused_l2_nn.cuh b/cpp/src/distance/detail/fused_distance_nn/fused_l2_nn.cuh new file mode 100644 index 000000000..59fe10ea0 --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/fused_l2_nn.cuh @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../distance_ops/l2_exp.cuh" // ops::l2_exp_distance_op +#include "../pairwise_distance_base.cuh" // PairwiseDistances +#include "cutlass_base.cuh" +#include "helper_structs.cuh" +#include "simt_kernel.cuh" +#include // raft::KeyValuePair +#include // raft::identity_op +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace cuvs { +namespace distance { + +namespace detail { + +template +void fusedL2NNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedL2NN. + typedef Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef raft::KeyValuePair KVPair; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + namespace arch = raft::util::arch; + using AccT = DataT; + ops::l2_exp_distance_op distance_op{sqrt}; + + raft::identity_op fin_op{}; + + auto kernel = fusedDistanceNNkernel; + + // Get pointer to fp32 SIMT kernel to determine the best compute architecture + // out of all for which the kernel was compiled for that matches closely + // to the current device. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. 
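+    // The CUTLASS path fuses the -2*x.y GEMM with the norm additions and the
+    // per-row argmin reduction in the epilogue, so the full m x n distance
+    // matrix is never materialized in global memory.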
+ using L2Op = cuvs::distance::detail::ops::l2_exp_cutlass_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; + L2Op L2_dist_op(sqrt); + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedDistanceNN(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + L2_dist_op, + redOp, + pairRedOp, + stream); + } else { + // If device less than SM_80, use fp32 SIMT kernel. + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + RAFT_CUDA_TRY(cudaGetLastError()); + } +} + +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/fused_distance_nn/gemm.h b/cpp/src/distance/detail/fused_distance_nn/gemm.h new file mode 100644 index 000000000..0385b95cd --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/gemm.h @@ -0,0 +1,409 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "epilogue.cuh" +#include "persistent_gemm.h" + +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/* + * This configuration is used for float inputs with veclen(kAlignmentA/B) = 2 or 4, + * ideal threadblock tile shape is 32x256x16 for such cases as there is no + * registers spills for it. 
+ * + */ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct FusedDistanceNNGemm { + // This struct is specialized for fp32/3xTF32 + + /// Threadblock-level tile size (concept: GemmShape) + // <- threadblock tile M = 32, N = 256, K = 16 + // this is more performant but note that for veclen = 1 + // this shape has register spills + using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 16>; + + // <- threadblock tile M = 32, N = 128, K = 16 + // this shape has high occupancy but less perf + // this is less performant but this shape has *no* register spills + // for any veclens(1, 2, 4) + // using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; + + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + // <- warp tile M = 64, N = 64, K = 16 + // this is more performant for veclen 2,4. + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; + + // this shape has high occupancy but less perf used for 32x128x16 + // using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + // <- MMA Op tile M = 16, N = 8, K = 4 + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastF32; + // using Operator = cutlass::arch::OpMultiplyAdd; // this runs only 1xTF32 for float inputs + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. 
+ // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementAccumulator, + typename EpilogueOutputOp::ElementT, + ElementAccumulator, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = FusedDistanceNNPersistent; +}; + +/* + * This configuration is used for float inputs with veclen(kAlignmentA/B) = 1, + * ideal threadblock tile shape is 32x128x16 for such cases as there is no + * registers spills for it. + * + */ +template < + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct FusedDistanceNNGemm { + // This struct is specialized for fp32/3xTF32 + using ElementA_ = float; + using ElementB_ = float; + + /// Threadblock-level tile size (concept: GemmShape) + // <- threadblock tile M = 32, N = 128, K = 16 + // this shape has high occupancy and no register spills for veclen = 1. + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; + + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + // <- warp tile M = 32, N = 32, K = 16 + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + // <- MMA Op tile M = 16, N = 8, K = 4 + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastF32; + // using Operator = cutlass::arch::OpMultiplyAdd; // this runs only 1xTF32 for float inputs + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. 
+ // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementAccumulator, + typename EpilogueOutputOp::ElementT, + ElementAccumulator, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = FusedDistanceNNPersistent; +}; + +template < + /// Layout type for A matrix operand + int kAlignmentA, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct FusedDistanceNNGemm { + // Threadblock-level tile size (concept: GemmShape) + // <- threadblock tile M = 64, N = 64, K = 16 + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; + // using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + // <- warp tile M = 32, N = 32, K = 16 + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + // using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + // Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAdd; + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. 
+ // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementC_, + typename EpilogueOutputOp::ElementT, + ElementC_, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = FusedDistanceNNPersistent; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh b/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh new file mode 100644 index 000000000..bd439e0a7 --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../distance_ops/l2_exp.cuh" // ops::l2_exp_distance_op +#include "../pairwise_distance_base.cuh" // PairwiseDistances +#include "cutlass_base.cuh" +#include "simt_kernel.cuh" +#include // raft::KeyValuePair +#include // raft::identity_op +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl +#include + +#include // size_t +#include // std::numeric_limits + +namespace cuvs { +namespace distance { + +namespace detail { + +template +struct KVPMinReduceImpl { + typedef raft::KeyValuePair KVP; + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? 
b : a; } + +}; // KVPMinReduce + +template +struct MinAndDistanceReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + + DI void operator()(LabelT rid, KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + DI void operator()(LabelT rid, volatile KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + + DI void operator()(LabelT rid, DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, volatile DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void operator()(LabelT rid, volatile DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } + DI void init(KVP* out, DataT maxVal) const + { + out->value = maxVal; + out->key = 0xfffffff0; + } + + DI void init_key(DataT& out, LabelT idx) const { return; } + DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } + + DI DataT get_value(KVP& out) const { return out.value; } + DI DataT get_value(DataT& out) const { return out; } +}; + +template +struct MinReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + DI void operator()(LabelT rid, DataT* out, const KVP& other) + { + if (other.value < *out) { *out = other.value; } + } + + DI void init(DataT* out, DataT maxVal) { *out = maxVal; } +}; + +template +RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { redOp.init(min + tid, maxVal); } +} + +template +void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) +{ + auto blks = raft::ceildiv(m, 256); + initKernel<<>>(min, m, maxVal, redOp); +} + +// cg::reduce functor for FusedDistanceNN used in its cutlass version +// to output the min distance value & key(loc id). +// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h +// store_with_byte_offset() passed to cg::reduce() & select_reduce. +template +struct kvp_cg_min_reduce_op { + typedef typename raft::KeyValuePair KVP; + + __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; + + using AccTypeT = AccType; + using IndexT = Index; + // functor signature. + __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } + + __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } + + __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } +}; + +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/src/distance/detail/fused_distance_nn/persistent_gemm.h new file mode 100644 index 000000000..f1a7c728e --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/persistent_gemm.h @@ -0,0 +1,512 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Problem visitor for grouped GEMMs +This file contains heavily customized version of GemmGrouped from CUTLASS 2.10.0 +(https://github.com/NVIDIA/cutlass/blob/v2.10.0/include/cutlass/gemm/kernel/gemm_grouped.h) + +Changes: +- adds support for only single problem size to be launched persistently + where each threablock processes more than one tile of the same problem. 
+*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FusedDistanceNNPersistent { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = GemmGroupedProblemVisitor; + + // + // Structures + // + + struct temp_problem_visitor { + int problem_count; + + CUTLASS_HOST_DEVICE temp_problem_visitor() : problem_count(0){}; + CUTLASS_HOST_DEVICE temp_problem_visitor(int problem_count_) : problem_count(problem_count_){}; + }; + + /// Argument structure + struct Arguments { + // + // Data members + // + GemmCoord problem_sizes; + temp_problem_visitor problem_visitor; + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + void const* ptr_A; + void const* ptr_B; + void const* ptr_C; + void* ptr_Vector; + void* ptr_Tensor; + + typename LayoutA::Stride::Index lda; + typename LayoutB::Stride::Index ldb; + typename LayoutC::Stride::Index ldc; + typename LayoutC::Stride::Index ldt; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_Vector(nullptr), + ptr_Tensor(nullptr), + lda(0), + ldb(0), + ldc(0), + ldt(0), + host_problem_sizes(nullptr) + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(GemmCoord problem_sizes, + int problem_count, + int threadblock_count, + typename EpilogueOutputOp::Params output_op, + void const* ptr_A, + void 
const* ptr_B, + void const* ptr_C, + void* ptr_Vector, + void* ptr_Tensor, + typename LayoutA::Stride::Index lda, + typename LayoutB::Stride::Index ldb, + typename LayoutC::Stride::Index ldc, + typename LayoutC::Stride::Index ldt, + GemmCoord* host_problem_sizes = nullptr) + : problem_sizes(problem_sizes), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_Vector(ptr_Vector), + ptr_Tensor(ptr_Tensor), + lda(lda), + ldb(ldb), + ldc(ldc), + ldt(ldt), + host_problem_sizes(host_problem_sizes) + { + problem_visitor.problem_count = problem_count; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + temp_problem_visitor problem_visitor; + int threadblock_count; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::TensorTileIterator::Params params_Tensor; + + typename EpilogueOutputOp::Params output_op; + + void* ptr_A; + void* ptr_B; + void* ptr_C; + void* ptr_Vector; + void* ptr_Tensor; + + GemmCoord problem_size; + typename LayoutA::Stride::Index lda; + typename LayoutB::Stride::Index ldb; + typename LayoutC::Stride::Index ldc; + typename LayoutC::Stride::Index ldt; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : params_A(0), + params_B(0), + params_C(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_Vector(nullptr), + ptr_Tensor(nullptr), + lda(0), + ldb(0), + ldc(0), + ldt(0) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_size(args.problem_sizes), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + params_A(args.lda), + params_B(args.ldb), + params_C(args.ldc), + // Here we pass additional user args via args.output_op + // to the reduction output tile iterator + params_Tensor(args.ldt, args.output_op), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_Vector(args.ptr_Vector), + ptr_Tensor(args.ptr_Tensor), + lda(args.lda), + ldb(args.ldb), + ldc(args.ldc), + ldt(args.ldt) + { + problem_visitor.problem_count = args.problem_visitor.problem_count; + } + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + { + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_Vector = args.ptr_Vector; + ptr_Tensor = args.ptr_Tensor; + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldt = args.ldt; + + problem_size = args.problem_sizes; + } + }; + + /// Shared memory storage structure + struct SharedStorage { + union { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + } kernel; + + typename Epilogue::TensorTileIterator::SharedStorage reduced_store; + typename Epilogue::OutputTileIterator::SharedStorage rownorm_store; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + FusedDistanceNNPersistent() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { return Status::kSuccess; } + + static size_t get_extra_workspace_size(Arguments const& args, + 
cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + return 0; + } + + CUTLASS_DEVICE + static uint32_t tile_count(const cutlass::MatrixCoord& grid) + { + return grid.row() * grid.column(); + } + + /// Get the grid shape + CUTLASS_DEVICE + static cutlass::MatrixCoord grid_shape(const cutlass::gemm::GemmCoord& problem) + { + return cutlass::MatrixCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN)); + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if __CUDA_ARCH__ >= 800 + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + const GemmCoord& problem_size = params.problem_size; + const auto grid_shape_ = grid_shape(problem_size); + const uint32_t problem_chunk = (tile_count(grid_shape_) - 1 + gridDim.x) / gridDim.x; + const uint32_t problem_chunk_end = blockIdx.x * problem_chunk + problem_chunk; + typename LayoutB::Index column = + ((blockIdx.x * problem_chunk) % grid_shape_.column()) * Mma::Shape::kN; + + typename LayoutB::Index row = + ((blockIdx.x * problem_chunk) / grid_shape_.column()) * Mma::Shape::kM; + if (column) { + shared_storage.reduced_store.initSmem(params.output_op); + shared_storage.rownorm_store.initSmem(params.ptr_C, problem_size.m(), row, sizeof(ElementC)); + } + + // Outer 'persistent' loop to iterate over tiles + for (uint32_t tile_idx = blockIdx.x * problem_chunk; tile_idx < problem_chunk_end; tile_idx++) { + const auto grid_shape_ = grid_shape(problem_size); + cutlass::MatrixCoord threadblock_offset( + int(tile_idx / grid_shape_.column()) * Mma::Shape::kM, + int(tile_idx % grid_shape_.column()) * Mma::Shape::kN); + + const bool isNextTile = ((tile_idx + 1) < problem_chunk_end); + const bool doesRowChange = + ((threadblock_offset.column() + Mma::Shape::kN) >= problem_size.n()); + const bool do_gmem_reduce = (doesRowChange || !isNextTile) ? true : false; + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{threadblock_offset.row(), 0}; + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.column()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size.k(), problem_size.n()}, thread_idx, tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
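+      // __shfl_sync with source lane 0 gives every lane the same warp_idx
+      // value, so later warp-level code does not diverge on it.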
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. + //__syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = static_cast(params.ptr_C); + typename Epilogue::ElementTensor* ptr_Tensor = + static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector* ptr_Vector = + static_cast(params.ptr_Vector); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_rownorm(shared_storage.rownorm_store, + params.params_C, + ptr_C, + problem_size.mn(), + thread_idx, + threadblock_offset); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator(shared_storage.reduced_store, + params.params_Tensor, + // Only the final block outputs Tensor + ptr_Tensor, + problem_size.mn(), + thread_idx, + do_gmem_reduce, + threadblock_offset); + + Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + // Move to appropriate location for this output tile + if (ptr_Vector) { ptr_Vector += threadblock_offset.column(); } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + ptr_Vector, + // iterator_D, + accumulators, + iterator_rownorm, + tensor_iterator, + problem_size.mn(), + threadblock_offset); + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h b/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h new file mode 100644 index 000000000..794cd5eb6 --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h @@ -0,0 +1,448 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + +This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) + +Changes: +- added `Layout_` template param +- Only the row index is used to load the data in load_with_byte_offset(). + This way the same normalization data is used across all columns in a row. + +*/ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template +class PredicatedTileIteratorNormVecSmem { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = Layout_; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::Count::kTile * ThreadMap::Delta::kRow; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + /// Shared storage allocation needed by the predicated tile + // iterator for storing rowNorm chunk. 
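+  /// The norms for this threadblock's rows are staged into shared memory once
+  /// (initSmem below issues cp.async copies guarded by the row extent), and
+  /// load_with_byte_offset() then re-reads the same per-row value for every
+  /// column iteration instead of touching global memory again.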
+ struct SharedStorage { + // + // Type definitions + // + using Shape = MatrixShape; + + /// Shape of the shared memory allocation + using StorageShape = MatrixShape; + + // + // Data members + // + // Methods + // + AlignedBuffer storage; + + CUTLASS_DEVICE + Element* data() { return storage.data(); } + + SharedStorage() {} + + CUTLASS_DEVICE + void initSmem(void* pointer, + const Index& num_rows, + const Index& tb_row_offset, + const LongIndex& stride) + { + Element* shared_elem_arr = data(); + uint8_t* first_tile_byte_pointer_ = + reinterpret_cast(pointer) + LongIndex(tb_row_offset) * LongIndex(stride); + const auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + bool guard = (tb_row_offset + row) < num_rows; + cutlass::arch::cp_async(shared_elem_arr + row, gmem_ptr + row, guard); + cutlass::arch::cp_async_wait<0>(); + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + + private: + // + // Methods + // + + protected: + SharedStorage& shared_storage_; + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorNormVecSmem(SharedStorage& shared_storage, + PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord& threadblock_offset, + int const* indices = nullptr) + : params_(params), indices_(indices), shared_storage_(shared_storage) + { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + return; + } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride); + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + if (threadblock_offset.column() == 0) { + shared_storage_.initSmem(pointer, extent_row_, threadblock_offset.row(), params_.stride); + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of 
Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const + { + AccessType* frag_ptr = reinterpret_cast(&frag); + + Element* shared_elem_arr = shared_storage_.data(); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int iter_row = ((row_offset + thread_start_row_) % total_rows); + Element val = shared_elem_arr[iter_row]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerAccess; ++i) { + (*frag_ptr)[frag_row_idx + i] = val; + } + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorNormVecSmem& operator++() + { + ++state_[0]; + + if (!ScatterD) { byte_pointer_ += params_.advance_row; } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h 
b/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h new file mode 100644 index 000000000..5ffb74e9c --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -0,0 +1,610 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + +This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) + +Changes: +- added `Layout_` template param +- PredicatedTileIteratorParams() is customized to not stride by layout.stride(0). +- makes use of `SharedStorage` to store reduced values across warps to gmem in coalesced manner. +- customized the store_with_byte_offset() to perform reduction per row and write final value to +gmem. +- customized the Params() struct to take user inputs from epilogueOp params. 
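As background for the change list above, here is a minimal standalone sketch (illustrative only, not part of this diff) of the warp-level pattern the customized iterator builds on: each thread contributes a candidate value, a cooperative-groups warp tile reduces it with cg::reduce, and a single lane stores the reduced result so the later global-memory write stays coalesced. The kernel name and shared-memory layout below are assumptions made for the sketch.

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cfloat>

namespace cg = cooperative_groups;

// Each warp reduces its values and lane 0 stores the warp minimum to shared
// memory; the first warp then combines the per-warp minima into one value per block.
__global__ void block_min_kernel(const float* in, float* out, int n)
{
  __shared__ float warp_min[32];                       // one slot per warp (<= 1024 threads)
  auto block = cg::this_thread_block();
  auto warp  = cg::tiled_partition<32>(block);

  int idx   = blockIdx.x * blockDim.x + threadIdx.x;
  float val = (idx < n) ? in[idx] : FLT_MAX;           // out-of-range lanes stay neutral

  float w_min = cg::reduce(warp, val, cg::less<float>());
  if (warp.thread_rank() == 0) { warp_min[warp.meta_group_rank()] = w_min; }
  block.sync();

  if (warp.meta_group_rank() == 0) {
    unsigned nwarps = (blockDim.x + 31) / 32;
    float v     = (warp.thread_rank() < nwarps) ? warp_min[warp.thread_rank()] : FLT_MAX;
    float b_min = cg::reduce(warp, v, cg::less<float>());
    if (warp.thread_rank() == 0) { out[blockIdx.x] = b_min; }
  }
}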
+ +*/ + +#pragma once + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cg = cooperative_groups; + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template +class PredicatedTileIteratorReducedVec { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = Layout_; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + using EpilogueOpParams = EpilogueOpParams_; + using OutIdxT = typename EpilogueOpParams::CGReduceT::IndexT; + using OutValT = typename EpilogueOpParams::CGReduceT::AccTypeT; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + static_assert(!UseCUDAStore, "UseCUDAStore path is not supported"); + + static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::Count::kTile * ThreadMap::Delta::kRow; + /// Fragment object + using Fragment = + Array; + + // Memory access size + using AccessType = AlignedArray; + using AccessTypeValT = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + EpilogueOpParams user_param; + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) + { + } + + CUTLASS_HOST_DEVICE + Params(Layout const& layout, EpilogueOpParams const& user_param_) + : PredicatedTileIteratorParams(int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()), + user_param(user_param_) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + // static int const kCount = ThreadMap::Iterations::kColumn; + static int const kCount = ThreadMap::Iterations::kColumn * kElementsPerAccess; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< 
CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + /// Shared storage allocation needed by the predicated tile + // iterator for reduction. + struct SharedStorage { + // + // Type definitions + // + using Shape = MatrixShape; + + /// Shape of the shared memory allocation for the reduced values store + using StorageShape = MatrixShape; + + // + // Data members + + // + // Methods + // + AlignedBuffer storage; + + CUTLASS_DEVICE + Element* data() { return storage.data(); } + + SharedStorage() {} + + CUTLASS_DEVICE + void initSmem(EpilogueOpParams const& user_params) + { + Element* shared_elem_arr = data(); + constexpr auto maxVal = std::numeric_limits::max(); + + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + user_params.red_op_.init(&shared_elem_arr[row], maxVal); + } + } + }; + + template + struct select_reduce { + /// Performs warp level reduction and stores a reduced output to memory + CUTLASS_DEVICE + select_reduce(OutT value, + ValT prev_red_val, + cg_reduce_op_t reduce_op, + cg_group_t cg_warp_group, + OutT& shmem_ptr) + { + if (cg_warp_group.any(reduce_op.isAmin(value, prev_red_val))) { + OutT reduced_val = cg::reduce(cg_warp_group, value, reduce_op); + if (cg_warp_group.thread_rank() == 0) { shmem_ptr = reduced_val; } + } + } + }; + + template + struct select_reduce> { + using ValT = float; + using Ty = raft::KeyValuePair; + /// Performs warp level reduction of key value pair and stores a reduced output to memory + CUTLASS_DEVICE + select_reduce(Ty val_to_red, + float prev_red_val, + cg_reduce_op_t cg_reduce_op, + cg_group_t cg_warp_group, + Ty& shmem_ptr) + { + ValT val = val_to_red.value; + + if (cg_warp_group.any(cg_reduce_op.isAmin(val, prev_red_val))) { + ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); + bool pred = (reduced_val == val); + auto subTile = cg::binary_partition(cg_warp_group, pred); + if (pred) { + if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } + } + } + } + }; + + template + struct select_reduce> { + using ValT = double; + using Ty = raft::KeyValuePair; + /// Performs warp level reduction of key value pair and stores a reduced output to memory + CUTLASS_DEVICE + select_reduce(Ty val_to_red, + double prev_red_val, + cg_reduce_op_t cg_reduce_op, + cg_group_t cg_warp_group, + Ty& shmem_ptr) + { + ValT val = val_to_red.value; + + if (cg_warp_group.any(cg_reduce_op.isAmin(val, prev_red_val))) { + ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); + bool pred = (reduced_val == val); + auto subTile = cg::binary_partition(cg_warp_group, pred); + if (pred) { + if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } + } + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + Params params_; + + /// Byte-level pointer first tile offset of this threadblock. 
+ volatile uint8_t* first_tile_byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + Index block_start_row_first_tile_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + // mutable int shared_tile_id; + + /// Scatter indices + int const* indices_; + + const int do_gmem_reduction_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(Params::stride) == 8, "Expected 64b strides"); + + protected: + SharedStorage& shared_storage_; + + private: + // + // Methods + // + public: + // + // Methods + // + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorReducedVec(SharedStorage& shared_storage, + Params const& params, + volatile Element* pointer, + TensorCoord extent, + int thread_idx, + const bool do_gmem_reduction, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), + indices_(indices), + shared_storage_(shared_storage), + do_gmem_reduction_(do_gmem_reduction) + { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + TensorCoord block_offset = ThreadMap::initial_offset(0) + threadblock_offset; + block_start_row_first_tile_ = block_offset.row(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++c) { + int columnPerAccess = (c / kElementsPerAccess); + int columnWithinPerAccess = c % kElementsPerAccess; + mask_.predicates[c] = ((thread_offset.column() + ThreadMap::Delta::kColumn * columnPerAccess + + columnWithinPerAccess) < extent.column()); + } + + if (threadblock_offset.column() == 0) { + EpilogueOpParams const& user_params = params_.user_param; + shared_storage_.initSmem(user_params); + } + __syncthreads(); + + // Null pointer performs no accesses + if (!pointer) { mask_.clear(); } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize pointer + first_tile_byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(block_offset.row()) * LongIndex(params_.stride); + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + CUTLASS_DEVICE void dumpToGmem() + { + if (block_start_row_first_tile_ >= extent_row_) { return; } + + if (do_gmem_reduction_) { + EpilogueOpParams const& user_params = params_.user_param; + const uint32_t mutex_id = (block_start_row_first_tile_ / total_rows); + const bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); + int row = threadIdx.x; + Element* shared_elem_arr = shared_storage_.data(); + Element row_local_min; + if (row < total_rows) { row_local_min = shared_elem_arr[row]; } + + // single lock per block for multiple rows + if (useGmemMutex && threadIdx.x == 0) { user_params.bin_mutex_[mutex_id].acquire(); } + __syncthreads(); + + if (row < total_rows) { + volatile Element* gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + if 
((block_start_row_first_tile_ + row) < extent_row_) { + user_params.red_op_(block_start_row_first_tile_ + row, (gmem_ptr + row), row_local_min); + } + } + + __syncthreads(); + __threadfence(); + + if (useGmemMutex && (threadIdx.x == 0)) { + // release mutex lock. + user_params.bin_mutex_[mutex_id].release(); + } + shared_storage_.initSmem(user_params); + __syncthreads(); + } + } + + /// Destructor + CUTLASS_DEVICE + ~PredicatedTileIteratorReducedVec() {} + + /// Performs reduction and Stores a reduced output to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment& frag, int64_t byte_offset) const + { + AccessTypeValT* frag_ptr = reinterpret_cast(&frag); + + cg::thread_block cta = cg::this_thread_block(); + // tile_width 16 is required if kElementPerAccess > 1 + constexpr int tile_width = (32 / ThreadMap::Delta::kColumn) ? 32 : 16; + cg::thread_block_tile tile32 = cg::tiled_partition(cta); + EpilogueOpParams const& user_params = params_.user_param; + + using cg_reduce_t = decltype(user_params.cg_reduce_op); + using tile32_t = decltype(tile32); + + Element* shared_elem_arr = shared_storage_.data(); + constexpr auto maxVal = std::numeric_limits::max(); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + const OutIdxT row_id = row_offset + thread_start_row_; + bool row_guard = (row_id < extent_row_); + + const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn * kElementsPerAccess; + Element red_val; + user_params.red_op_.init(&red_val, maxVal); + + if (row_guard) { + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn * kElementsPerAccess; + ++column) { + int columnPerAccess = column / kElementsPerAccess; + int columnWithPerAccess = column % kElementsPerAccess; + bool guard = mask_.predicates[column]; + if (guard) { + const OutIdxT key_id = thread_start_column_ + + ThreadMap::Delta::kColumn * columnPerAccess + + columnWithPerAccess; + const int frag_col_idx = frag_idx + column; + + Element this_val; + user_params.red_op_.init(&this_val, (*frag_ptr)[frag_col_idx]); + user_params.red_op_.init_key(this_val, key_id); + user_params.red_op_(row_id, &red_val, this_val); + } + } + } + const int iter_row = (row_id % total_rows); + const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); + if (row_guard) { + // select_reduce doesn't need to use `red_op_` as at the warp level we use cg_reduce_op, + // this satisfies the requirement of mst/single linkage of checking colors buffer. 
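The select_reduce call that follows relies on a warp-level argmin idiom: cg::reduce finds the minimum value across the tile, then cg::binary_partition groups the lanes that hold that minimum so exactly one of them writes the winning key/value pair. A minimal standalone sketch of the idiom (illustrative only; the struct and function names are assumptions, not code from this diff):

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

struct kv_pair {
  int key;
  float value;
};

// Rank 0 of the "winners" sub-group is the single writer of the reduced pair.
__device__ void warp_argmin_write(cg::thread_block_tile<32> tile, kv_pair mine, kv_pair* out)
{
  float tile_min  = cg::reduce(tile, mine.value, cg::less<float>());  // value-only reduction
  bool  holds_min = (mine.value == tile_min);
  auto  winners   = cg::binary_partition(tile, holds_min);            // lanes holding the minimum
  if (holds_min && winners.thread_rank() == 0) { *out = mine; }
}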
+ select_reduce red_obj( + red_val, prev_red_val, user_params.cg_reduce_op, tile32, shared_elem_arr[iter_row]); + } + } + } + } + __syncthreads(); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment& frag) const { store_with_byte_offset(frag, 0); } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorReducedVec& operator++() + { + ++state_[0]; + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; } + } + } + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/src/distance/detail/fused_distance_nn/simt_kernel.cuh b/cpp/src/distance/detail/fused_distance_nn/simt_kernel.cuh new file mode 100644 index 000000000..184063c8b --- /dev/null +++ b/cpp/src/distance/detail/fused_distance_nn/simt_kernel.cuh @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "../distance_ops/l2_exp.cuh" // ops::l2_exp_distance_op +#include "../pairwise_distance_base.cuh" // PairwiseDistances +#include // raft::KeyValuePair +#include // Policy + +#include // size_t +#include // std::numeric_limits + +namespace cuvs { +namespace distance { +namespace detail { + +// TODO: specialize this function for MinAndDistanceReduceOp +// with atomicCAS of 64 bit which will eliminate mutex and shfls +template +DI void updateReducedVal( + int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) +{ + const auto lid = threadIdx.x % raft::WarpSize; + const auto accrowid = threadIdx.x / P::AccThCols; + + // Update each output row in order within a warp. This will resolve hang + // issues with pre-Volta architectures +#pragma unroll + for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { + if (lid == j * P::AccThCols) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rid = gridStrideY + accrowid + i * P::AccThRows; + if (rid < m) { + auto value = val[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + red_op(rid, min + rid, value); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } + } + } + } +} + +template +__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedDistanceNNkernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + OpT distance_op, + FinalLambda fin_op) +{ +// compile only if below non-ampere arch. +#if __CUDA_ARCH__ < 800 + extern __shared__ char smem[]; + + typedef raft::KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < n) { + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + } + }; + + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, + // but the shfl op applies the modulo internally. + auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, m, gridStrideY); + + // reset the val array. 
+#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + constexpr bool row_major = true; + constexpr bool write_out = false; + PairwiseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + xn, + yn, + nullptr, // Output pointer + smem, + distance_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +#endif +} + +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/fused_l2_nn.cuh b/cpp/src/distance/detail/fused_l2_nn.cuh new file mode 100644 index 000000000..3b8c426ea --- /dev/null +++ b/cpp/src/distance/detail/fused_l2_nn.cuh @@ -0,0 +1,386 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../distance_ops/l2_exp.cuh" // ops::l2_exp_distance_op +#include "../fused_distance_nn/cutlass_base.cuh" +#include "../pairwise_distance_base.cuh" // PairwiseDistances +#include // raft::KeyValuePair +#include // raft::identity_op +#include // Policy +#include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl + +#include // size_t +#include // std::numeric_limits + +namespace cuvs { +namespace distance { + +namespace detail { + +template +struct KVPMinReduceImpl { + typedef raft::KeyValuePair KVP; + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? 
b : a; } + +}; // KVPMinReduce + +template +struct MinAndDistanceReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + DI void operator()(LabelT rid, KVP* out, const KVP& other) const + { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + + DI void operator()(LabelT rid, DataT* out, const KVP& other) const + { + if (other.value < *out) { *out = other.value; } + } + + DI void operator()(LabelT rid, DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } + DI void init(KVP* out, DataT maxVal) const { out->value = maxVal; } + + DI void init_key(DataT& out, LabelT idx) const { return; } + DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } + + DI DataT get_value(KVP& out) const + { + return out.value; + ; + } + DI DataT get_value(DataT& out) const { return out; } +}; + +template +struct MinReduceOpImpl { + typedef typename raft::KeyValuePair KVP; + DI void operator()(LabelT rid, DataT* out, const KVP& other) + { + if (other.value < *out) { *out = other.value; } + } + + DI void init(DataT* out, DataT maxVal) { *out = maxVal; } +}; + +template +RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +{ + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { redOp.init(min + tid, maxVal); } +} + +template +void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) +{ + auto blks = raft::ceildiv(m, 256); + initKernel<<>>(min, m, maxVal, redOp); +} + +// TODO: specialize this function for MinAndDistanceReduceOp +// with atomicCAS of 64 bit which will eliminate mutex and shfls +template +DI void updateReducedVal( + int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) +{ + const auto lid = threadIdx.x % raft::WarpSize; + const auto accrowid = threadIdx.x / P::AccThCols; + + // Update each output row in order within a warp. This will resolve hang + // issues with pre-Volta architectures +#pragma unroll + for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { + if (lid == j * P::AccThCols) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rid = gridStrideY + accrowid + i * P::AccThRows; + if (rid < m) { + auto value = val[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + red_op(rid, min + rid, value); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } + } + } + } +} + +template +__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedL2NNkernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + OpT distance_op, + FinalLambda fin_op) +{ +// compile only if below non-ampere arch. 
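A minimal standalone sketch (illustrative only) of the per-row locking pattern used by updateReducedVal above: a spin lock built from atomicCAS guards each output row, with __threadfence() on both sides of the protected update so the write is visible to the next lock holder. The function name and min-update payload are assumptions for the sketch.

__device__ void locked_row_min_update(int* mutex, float* out, int rid, float candidate)
{
  // acquire the lock for row `rid` (0 = free, 1 = taken)
  while (atomicCAS(mutex + rid, 0, 1) == 1) {}
  __threadfence();                                    // observe the previous holder's write
  if (candidate < out[rid]) { out[rid] = candidate; } // protected per-row update
  __threadfence();                                    // publish before releasing
  atomicCAS(mutex + rid, 1, 0);                       // release the lock
}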
+#if __CUDA_ARCH__ < 800 + extern __shared__ char smem[]; + + typedef raft::KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < n) { + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + } + }; + + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, + // but the shfl op applies the modulo internally. + auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, m, gridStrideY); + + // reset the val array. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {0, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + constexpr bool row_major = true; + constexpr bool write_out = false; + PairwiseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + xn, + yn, + nullptr, // Output pointer + smem, + distance_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +#endif +} + +// cg::reduce functor for FusedDistanceNN used in its cutlass version +// to output the min distance value & key(loc id). +// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h +// store_with_byte_offset() passed to cg::reduce() & select_reduce. +template +struct kvp_cg_min_reduce_op { + typedef typename raft::KeyValuePair KVP; + + __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; + + using AccTypeT = AccType; + using IndexT = Index; + // functor signature. + __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } + + __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } + + __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } +}; + +template +void fusedL2NNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedL2NN. 
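The kvp_cg_min_reduce_op defined above exposes three operations because it is used in two roles: as the binary operator handed to cg::reduce on raw distance values, and as the pair/predicate logic consumed by select_reduce. A simplified, self-contained analogue of that interface (illustrative only; the type and member names are placeholders, not the functor from this diff):

struct min_by_value_op {
  struct kvp { int key; float value; };
  // pair vs pair: keep the pair with the smaller distance
  __host__ __device__ kvp operator()(kvp a, kvp b) const { return a.value < b.value ? a : b; }
  // value vs value: usable directly as a cg::reduce operator
  __host__ __device__ float operator()(float a, float b) const { return a < b ? a : b; }
  // "is `a` a new minimum relative to `b`?"
  __host__ __device__ bool isAmin(float a, float b) const { return a < b; }
};
// e.g. min_by_value_op{}(min_by_value_op::kvp{3, 0.25f}, min_by_value_op::kvp{7, 0.10f})
// yields the pair {7, 0.10f}.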
+ typedef Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef raft::KeyValuePair KVPair; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + namespace arch = raft::util::arch; + using AccT = DataT; + ops::l2_exp_distance_op distance_op{sqrt}; + + raft::identity_op fin_op{}; + + auto kernel = fusedL2NNkernel; + + // Get pointer to fp32 SIMT kernel to determine the best compute architecture + // out of all for which the kernel was compiled for that matches closely + // to the current device. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using L2Op = cuvs::distance::detail::ops::l2_exp_cutlass_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; + L2Op L2_dist_op(sqrt); + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedDistanceNN(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + L2_dist_op, + redOp, + pairRedOp, + stream); + } else { + // If device less than SM_80, use fp32 SIMT kernel. + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + RAFT_CUDA_TRY(cudaGetLastError()); + } +} + +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/kernels/gram_matrix.cuh b/cpp/src/distance/detail/kernels/gram_matrix.cuh new file mode 100644 index 000000000..009941fa7 --- /dev/null +++ b/cpp/src/distance/detail/kernels/gram_matrix.cuh @@ -0,0 +1,488 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../../distance.cuh" +#include +#include +#include +#include +// #include +#include +#include +#include +#include + +namespace cuvs::distance::kernels::detail { + +template +using dense_input_matrix_view_t = raft::device_matrix_view; +template +using dense_output_matrix_view_t = raft::device_matrix_view; +template +using csr_input_matrix_view_t = raft::device_csr_matrix_view; + +/** + * Base class for general Gram matrices + * A Gram matrix is the Hermitian matrix of inner probucts G_ik = + * Here, the inner product is evaluated for all elements from vectors sets X1, + * and X2. + * + * To be more precise, on exit the output buffer will store: + * - if is_row_major == true: out[j+k*n1] = , + * - if is_row_major == false: out[j*n2 + k] = , + * where x1_j is the j-th vector from the x1 set and x2_k is the k-th vector + * from the x2 set. + */ +template +class GramMatrixBase { + protected: + cublasHandle_t cublas_handle; + bool legacy_interface; + + public: + GramMatrixBase() : legacy_interface(false){}; + [[deprecated]] GramMatrixBase(cublasHandle_t cublas_handle) + : cublas_handle(cublas_handle), legacy_interface(true){}; + + virtual ~GramMatrixBase(){}; + + /** Convenience function to evaluate the Gram matrix for two vector sets. + * Vector sets are provided in Matrix format + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void operator()(raft::resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1 = nullptr, + math_t* norm_x2 = nullptr) + { + evaluate(handle, x1, x2, out, norm_x1, norm_x2); + } + + /** Convenience function to evaluate the Gram matrix for two vector sets. + * Vector sets are provided in Matrix format + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. 
+ * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void operator()(raft::resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1 = nullptr, + math_t* norm_x2 = nullptr) + { + evaluate(handle, x1, x2, out, norm_x1, norm_x2); + } + + /** Convenience function to evaluate the Gram matrix for two vector sets. + * Vector sets are provided in Matrix format + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void operator()(raft::resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1 = nullptr, + math_t* norm_x2 = nullptr) + { + evaluate(handle, x1, x2, out, norm_x1, norm_x2); + } + + // unfortunately, 'evaluate' cannot be templatized as it needs to be virtual + + /** Evaluate the Gram matrix for two vector sets using simple dot product. + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + virtual void evaluate(raft::resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + linear(handle, x1, x2, out); + } + /** Evaluate the Gram matrix for two vector sets using simple dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + virtual void evaluate(raft::resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + linear(handle, x1, x2, out); + } + /** Evaluate the Gram matrix for two vector sets using simple dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + virtual void evaluate(raft::resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + linear(handle, x1, x2, out); + } + + /** Evaluate the Gram matrix for two vector sets using simple dot product. 
+ * + * @param [in] x1 device array of vectors, size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1 (usually it is n1) + * @param ld2 leading dimension of x2 (usually it is n2) + * @param ld_out leading dimension of out (usually it is n1) + */ + [[deprecated]] virtual void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) + { + linear(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + } + + /** Convenience function to evaluate the Gram matrix for two vector sets. + * + * @param [in] x1 device array of vectors, size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1 + * @param ld2 leading dimension of x2 + * @param ld_out leading dimension of out + */ + [[deprecated]] void operator()(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1 = 0, + int ld2 = 0, + int ld_out = 0) + { + ASSERT(legacy_interface, "Legacy interface can only be used with legacy ctor."); + if (ld1 <= 0) { ld1 = is_row_major ? n_cols : n1; } + if (ld2 <= 0) { ld2 = is_row_major ? n_cols : n2; } + if (ld_out <= 0) { ld_out = is_row_major ? n2 : n1; } + evaluate(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + } + + protected: + /** Calculates the Gram matrix using simple dot product between vector sets. + * + * out = x1 * x2 + * + * Can be used as a building block for more complex kernel functions. 
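For reference, the dense linear() path above amounts to the following host-side computation, which can serve as a test oracle (illustrative only, not part of the diff; the row-major n1 x n2 output layout is a choice made for the sketch): out[i][k] is the dot product of row i of x1 and row k of x2.

#include <vector>
#include <cstddef>

std::vector<float> gram_reference(const std::vector<float>& x1,  // n1 x n_cols, row major
                                  const std::vector<float>& x2,  // n2 x n_cols, row major
                                  std::size_t n1, std::size_t n2, std::size_t n_cols)
{
  std::vector<float> out(n1 * n2, 0.0f);  // n1 x n2, row major
  for (std::size_t i = 0; i < n1; ++i) {
    for (std::size_t k = 0; k < n2; ++k) {
      float dot = 0.0f;
      for (std::size_t c = 0; c < n_cols; ++c) {
        dot += x1[i * n_cols + c] * x2[k * n_cols + c];
      }
      out[i * n2 + k] = dot;
    }
  }
  return out;
}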
+ * + * @param [in] x1 device array of vectors, size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1 + * @param ld2 leading dimension of x2 + * @param ld_out leading dimension of out + */ + [[deprecated]] void linear(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) + { + math_t alpha = 1.0; + math_t beta = 0.0; + if (is_row_major) { + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + n2, + n1, + n_cols, + &alpha, + x2, + ld2, + x1, + ld1, + &beta, + out, + ld_out, + stream)); + } else { + // #TODO: Call from public API when ready + RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + n1, + n2, + n_cols, + &alpha, + x1, + ld1, + x2, + ld2, + &beta, + out, + ld_out, + stream)); + } + } + + protected: + bool get_is_row_major(dense_output_matrix_view_t matrix) + { + return (matrix.stride(1) == 1); + } + + bool get_is_row_major(dense_input_matrix_view_t matrix) + { + return (matrix.stride(1) == 1); + } + + bool get_is_col_major(dense_output_matrix_view_t matrix) + { + return (matrix.stride(0) == 1); + } + + bool get_is_col_major(dense_input_matrix_view_t matrix) + { + return (matrix.stride(0) == 1); + } + + /** Calculates the Gram matrix using simple dot product between vector sets. + * + * out = x1 * x2 + * + * Can be used as a building block for more complex kernel functions. + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + */ + void linear(raft::resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out) + { + // check is_row_major consistency + bool is_row_major = get_is_row_major(x1) && get_is_row_major(x2) && get_is_row_major(out); + bool is_col_major = get_is_col_major(x1) && get_is_col_major(x2) && get_is_col_major(out); + ASSERT(is_row_major || is_col_major, + "GramMatrix leading dimensions for x1, x2 and out do not match"); + + // check dimensions + int n1 = out.extent(0); + int n2 = out.extent(1); + int n_cols = x1.extent(1); + ASSERT(x1.extent(0) == n1, "GramMatrix input matrix dimensions for x1 and out do not match"); + ASSERT(x2.extent(0) == n2, "GramMatrix input matrix dimensions for x2 and out do not match"); + ASSERT(x2.extent(1) == n_cols, "GramMatrix input matrix dimensions for x1 and x2 do not match"); + + // extract major stride + int ld1 = is_row_major ? x1.stride(0) : x1.stride(1); + int ld2 = is_row_major ? x2.stride(0) : x2.stride(1); + int ld_out = is_row_major ? 
out.stride(0) : out.stride(1); + + math_t alpha = 1.0; + math_t beta = 0.0; + if (is_row_major) { + // #TODO: Use mdspan-based API when stride-capable + // https://github.com/rapidsai/raft/issues/875 + raft::linalg::gemm(handle, + true, + false, + n2, + n1, + n_cols, + &alpha, + x2.data_handle(), + ld2, + x1.data_handle(), + ld1, + &beta, + out.data_handle(), + ld_out, + resource::get_cuda_stream(handle)); + } else { + // #TODO: Use mdspan-based API when stride-capable + // https://github.com/rapidsai/raft/issues/875 + raft::linalg::gemm(handle, + false, + true, + n1, + n2, + n_cols, + &alpha, + x1.data_handle(), + ld1, + x2.data_handle(), + ld2, + &beta, + out.data_handle(), + ld_out, + resource::get_cuda_stream(handle)); + } + } + + /** Calculates the Gram matrix using simple dot product between vector sets. + * + * out = x1 * x2 + * + * Can be used as a building block for more complex kernel functions. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + */ + void linear(raft::resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out) + { + // check is_row_major consistency + bool is_row_major = get_is_row_major(x2) && get_is_row_major(out); + bool is_col_major = get_is_col_major(x2) && get_is_col_major(out); + ASSERT(is_row_major || is_col_major, + "GramMatrix leading dimensions for x2 and out do not match"); + + // check dimensions + auto x1_structure = x1.structure_view(); + ASSERT(x1_structure.get_n_rows() == out.extent(0), + "GramMatrix input matrix dimensions for x1 and out do not match"); + ASSERT(x2.extent(0) == out.extent(1), + "GramMatrix input matrix dimensions for x2 and out do not match"); + ASSERT(x2.extent(1) == x1_structure.get_n_cols(), + "GramMatrix input matrix dimensions for x1 and x2 do not match"); + + math_t alpha = 1.0; + math_t beta = 0.0; + + raft::sparse::linalg::spmm(handle, false, true, &alpha, x1, x2, &beta, out); + } + + /** Calculates the Gram matrix using simple dot product between vector sets. + * + * out = x1 * x2 + * + * Can be used as a building block for more complex kernel functions. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + */ + void linear(raft::resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out) + { + // check layout consistency (w.r.t. 
strides a matrix might be both row & col major) + bool is_row_major_nopad = get_is_row_major(out) && out.stride(0) == out.extent(1); + bool is_col_major_nopad = get_is_col_major(out) && out.stride(1) == out.extent(0); + + ASSERT(is_row_major_nopad || is_col_major_nopad, + "Sparse linear Kernel distance does not support ld_out parameter"); + + // switch a,b based on is_row_major + if (is_col_major_nopad) { + auto out_row_major = raft::make_device_matrix_view( + out.data_handle(), out.extent(1), out.extent(0)); + raft::sparse::distance::pairwise_distance( + handle, x2, x1, out_row_major, cuvs::distance::DistanceType::InnerProduct, 0.0); + } else { + auto out_row_major = raft::make_device_matrix_view( + out.data_handle(), out.extent(0), out.extent(1)); + raft::sparse::distance::pairwise_distance( + handle, x1, x2, out_row_major, cuvs::distance::DistanceType::InnerProduct, 0.0); + } + } +}; + +}; // end namespace cuvs::distance::kernels::detail diff --git a/cpp/src/distance/detail/kernels/kernel_factory.cuh b/cpp/src/distance/detail/kernels/kernel_factory.cuh new file mode 100644 index 000000000..534339a15 --- /dev/null +++ b/cpp/src/distance/detail/kernels/kernel_factory.cuh @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "gram_matrix.cuh" +#include "kernel_matrices.cuh" + +#include +#include + +namespace cuvs::distance::kernels::detail { + +template +class KernelFactory { + public: + static GramMatrixBase* create(KernelParams params) + { + GramMatrixBase* res; + // KernelParams is not templated, we convert the parameters to math_t here: + math_t coef0 = params.coef0; + math_t gamma = params.gamma; + switch (params.kernel) { + case LINEAR: res = new GramMatrixBase(); break; + case POLYNOMIAL: res = new PolynomialKernel(params.degree, gamma, coef0); break; + case TANH: res = new TanhKernel(gamma, coef0); break; + case RBF: res = new RBFKernel(gamma); break; + default: throw raft::exception("Kernel not implemented"); + } + return res; + } + + [[deprecated]] static GramMatrixBase* create(KernelParams params, cublasHandle_t handle) + { + GramMatrixBase* res; + // KernelParams is not templated, we convert the parameters to math_t here: + math_t coef0 = params.coef0; + math_t gamma = params.gamma; + switch (params.kernel) { + case LINEAR: res = new GramMatrixBase(handle); break; + case POLYNOMIAL: + res = new PolynomialKernel(params.degree, gamma, coef0, handle); + break; + case TANH: res = new TanhKernel(gamma, coef0, handle); break; + case RBF: res = new RBFKernel(gamma, handle); break; + default: throw raft::exception("Kernel not implemented"); + } + return res; + } +}; + +}; // end namespace cuvs::distance::kernels::detail diff --git a/cpp/src/distance/detail/kernels/kernel_matrices.cuh b/cpp/src/distance/detail/kernels/kernel_matrices.cuh new file mode 100644 index 000000000..bff5bda92 --- /dev/null +++ b/cpp/src/distance/detail/kernels/kernel_matrices.cuh @@ -0,0 +1,777 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "gram_matrix.cuh" + +#include "../detail/kernels/rbf_fin_op.cuh" +#include +#include +#include +#include +#include + +namespace cuvs::distance::kernels::detail { + +/** Epiloge function for polynomial kernel without padding. + * Calculates output = (gain*in + offset)^exponent + * @param inout device vector in column major format, size [len] + * @param len array length + * @param exponent + * @param gain + * @param offset + */ +template +RAFT_KERNEL polynomial_kernel_nopad( + math_t* inout, size_t len, exp_t exponent, math_t gain, math_t offset) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < len; + tid += blockDim.x * gridDim.x) { + inout[tid] = pow(gain * inout[tid] + offset, exponent); + } +} + +/** Epiloge function for polynomial kernel with padding. 
+ * Calculates output = (gain*input + offset)^exponent + * @param inout device vector in column major format, size [ld * cols] + * @param ld leading dimension of the inout buffer + * @param rows number of rows (rows <= ld) + * @param cols number of columns + * @param exponent + * @param gain + * @param offset + */ +template +RAFT_KERNEL polynomial_kernel( + math_t* inout, int ld, int rows, int cols, exp_t exponent, math_t gain, math_t offset) +{ + for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; + tidy += blockDim.y * gridDim.y) + for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; + tidx += blockDim.x * gridDim.x) { + inout[tidx + tidy * ld] = pow(gain * inout[tidx + tidy * ld] + offset, exponent); + } +} + +/** Epiloge function for tanh kernel without padding. + * Calculates output = tanh(gain*input + offset) + * @param inout device vector, size [len] + * @param len length of the input vector + * @param gain + * @param offset + */ +template +RAFT_KERNEL tanh_kernel_nopad(math_t* inout, size_t len, math_t gain, math_t offset) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < len; + tid += blockDim.x * gridDim.x) { + inout[tid] = tanh(gain * inout[tid] + offset); + } +} + +/** Epiloge function for tanh kernel without padding. + * Calculates output = tanh(gain*input + offset) + * @param inout device vector in column major format, size [ld * cols] + * @param ld leading dimension of the inout buffer + * @param rows number of rows (rows <= ld) + * @param cols number of columns + * @param gain + * @param offset + */ +template +RAFT_KERNEL tanh_kernel(math_t* inout, int ld, int rows, int cols, math_t gain, math_t offset) +{ + for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; + tidy += blockDim.y * gridDim.y) + for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; + tidx += blockDim.x * gridDim.x) { + inout[tidx + tidy * ld] = tanh(gain * inout[tidx + tidy * ld] + offset); + } +} + +/** Epiloge function for rbf kernel using expansion. 
+ * + * Calculates output_ij = exp(-gain * (norm_x_i + norm_y_j - 2*input_ij)); + * + * Intended usage + * - input is the product of two matrices X and Y input_ij = sum_k X_ik * Y_jk + * - norm_x_i = l2_norm(x_i), where x_i is the i-th row of matrix X + * - norm_y_j = l2_norm(y_j), where y_j is the j-th row of matrix Y + * + * @param inout device vector in column major format, size [ld * cols] + * @param ld leading dimension of the inout buffer + * @param rows number of rows (rows <= ld) + * @param cols number of columns + * @param norm_x l2-norm of X's rows + * @param norm_y l2-norm of Y's rows + * @param gain + */ +template +RAFT_KERNEL rbf_kernel_expanded( + math_t* inout, int ld, int rows, int cols, math_t* norm_x, math_t* norm_y, math_t gain) +{ + for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; + tidy += blockDim.y * gridDim.y) { + math_t norm_y_val = norm_y[tidy]; + for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; + tidx += blockDim.x * gridDim.x) { + inout[tidx + tidy * ld] = + exp(-1.0 * gain * (norm_x[tidx] + norm_y_val - inout[tidx + tidy * ld] * 2)); + } + } +} + +namespace { +std::tuple generateLaunchConfig2dElementwiseOp(int n1, int n2) +{ + dim3 block_shape = dim3(32, 4); + const int num_blocks_x = raft::ceildiv(n1, 32); + const int num_blocks_y = std::min(raft::ceildiv(n2, 32), (1 << 16) - 1); + dim3 grid_shape = dim3(num_blocks_x, num_blocks_y); + return std::make_tuple(grid_shape, block_shape); +} +} // namespace + +/** + * Create a kernel matrix using polynomial kernel function. + */ +template +class PolynomialKernel : public GramMatrixBase { + exp_t exponent; + math_t gain; + math_t offset; + + void applyKernel( + math_t* inout, int ld, int rows, int cols, bool is_row_major, cudaStream_t stream) + { + const int n_minor = is_row_major ? cols : rows; + if (ld == n_minor) { + polynomial_kernel_nopad<<((size_t)rows * cols, 128), 128, 0, stream>>>( + inout, rows * cols, exponent, gain, offset); + } else { + int n1 = is_row_major ? cols : rows; + int n2 = is_row_major ? rows : cols; + auto [grid_shape, block_shape] = generateLaunchConfig2dElementwiseOp(n1, n2); + polynomial_kernel<<>>( + inout, ld, n1, n2, exponent, gain, offset); + } + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + public: + /** + * Constructs a polynomial kernel object. + * It evaluates the kernel matrix using the following formula: + * K_ij = (gain* + offset)^exponent + * + * @tparam math_t floating point type + * @tparam exp_t type of exponent + * @param exponent + * @param gain + * @param offset + */ + PolynomialKernel(exp_t exponent, math_t gain, math_t offset) + : GramMatrixBase(), exponent(exponent), gain(gain), offset(offset) + { + } + + [[deprecated]] PolynomialKernel(exp_t exponent, math_t gain, math_t offset, cublasHandle_t handle) + : GramMatrixBase(handle), exponent(exponent), gain(gain), offset(offset) + { + } + + /** Evaluate kernel matrix using polynomial kernel. + * + * output[i,k] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. 
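The expansion used by rbf_kernel_expanded above rests on the identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, so the RBF value can be assembled from the GEMM output <x, y> plus precomputed row norms. A host-side sketch showing the two routes agree (illustrative only; it assumes norm_x_sq and norm_y_sq hold squared L2 norms, which is what the identity requires):

#include <cmath>
#include <cstddef>

// direct form: exp(-gain * ||x - y||^2)
float rbf_direct(const float* x, const float* y, std::size_t d, float gain)
{
  float sq = 0.0f;
  for (std::size_t i = 0; i < d; ++i) { sq += (x[i] - y[i]) * (x[i] - y[i]); }
  return std::exp(-gain * sq);
}

// expanded form, matching exp(-gain * (norm_x + norm_y - 2 * input)) in the epilogue above
float rbf_expanded(float norm_x_sq, float norm_y_sq, float dot_xy, float gain)
{
  return std::exp(-gain * (norm_x_sq + norm_y_sq - 2.0f * dot_xy));
}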
+ */ + void evaluate(raft::resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + is_row_major, + resource::get_cuda_stream(handle)); + } + + /** Evaluate kernel matrix using polynomial kernel. + * + * output[i,k] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + is_row_major, + resource::get_cuda_stream(handle)); + } + + /** Evaluate kernel matrix using polynomial kernel. + * + * output[i,k] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + is_row_major, + resource::get_cuda_stream(handle)); + } + + /** Evaluate the Gram matrix using the legacy interface. 
+ * + * @param [in] x1 device array of vectors, size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1 (usually it is n1) + * @param ld2 leading dimension of x2 (usually it is n2) + * @param ld_out leading dimension of out (usually it is n1) + */ + [[deprecated]] void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) + { + ASSERT(GramMatrixBase::legacy_interface, + "Legacy interface can only be used with legacy ctor."); + GramMatrixBase::linear( + x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + applyKernel(out, ld_out, n1, n2, is_row_major, stream); + } +}; + +/** + * Create a kernel matrix using tanh kernel function. + */ +template +class TanhKernel : public GramMatrixBase { + math_t gain, offset; + + void applyKernel( + math_t* inout, int ld, int rows, int cols, bool is_row_major, cudaStream_t stream) + { + const int n_minor = is_row_major ? cols : rows; + if (ld == n_minor) { + tanh_kernel_nopad<<((size_t)rows * cols, 128), 128, 0, stream>>>( + inout, rows * cols, gain, offset); + } else { + int n1 = is_row_major ? cols : rows; + int n2 = is_row_major ? rows : cols; + auto [grid_shape, block_shape] = generateLaunchConfig2dElementwiseOp(n1, n2); + tanh_kernel<<>>(inout, ld, n1, n2, gain, offset); + } + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + public: + /** + * Constructs a tanh kernel object. + * It evaluates the kernel matrix using the following formula: + * K_ij = tanh(gain* + offset) + * + * @tparam math_t floating point type + * @param gain + * @param offset + */ + TanhKernel(math_t gain, math_t offset) : GramMatrixBase(), gain(gain), offset(offset) {} + + [[deprecated]] TanhKernel(math_t gain, math_t offset, cublasHandle_t handle) + : GramMatrixBase(handle), gain(gain), offset(offset) + { + } + + /** Evaluate kernel matrix using tanh kernel. + * + * output_[i + k*n1] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + is_row_major, + resource::get_cuda_stream(handle)); + } + + /** Evaluate kernel matrix using tanh kernel. + * + * output_[i + k*n1] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. 
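+ * (The value computed is tanh(gain * <x1_i, x2_k> + offset), as in the
+ * constructor documentation above.)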
+ * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + is_row_major, + resource::get_cuda_stream(handle)); + } + + /** Evaluate kernel matrix using tanh kernel. + * + * output_[i + k*n1] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + is_row_major, + resource::get_cuda_stream(handle)); + } + + /** Evaluate the Gram matrix using the legacy interface. + * + * @param [in] x1 device array of vectors, size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1 (usually it is n1) + * @param ld2 leading dimension of x2 (usually it is n2) + * @param ld_out leading dimension of out (usually it is n1) + */ + [[deprecated]] void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) + { + ASSERT(GramMatrixBase::legacy_interface, + "Legacy interface can only be used with legacy ctor."); + GramMatrixBase::linear( + x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + applyKernel(out, ld_out, n1, n2, is_row_major, stream); + } +}; + +/** + * Create a kernel matrix using RBF kernel function. + */ +template +class RBFKernel : public GramMatrixBase { + math_t gain; + + void applyKernel(math_t* inout, + int ld, + int rows, + int cols, + math_t* norm_x1, + math_t* norm_x2, + bool is_row_major, + cudaStream_t stream) + { + int n1 = is_row_major ? cols : rows; + int n2 = is_row_major ? rows : cols; + math_t* norm_n1 = is_row_major ? norm_x2 : norm_x1; + math_t* norm_n2 = is_row_major ? 
norm_x1 : norm_x2; + auto [grid_shape, block_shape] = generateLaunchConfig2dElementwiseOp(n1, n2); + rbf_kernel_expanded<<>>( + inout, ld, n1, n2, norm_n1, norm_n2, gain); + } + + public: + /** + * Constructs a RBF kernel object. + * It evaluates the kernel matrix using the following formula: + * K_ij = exp(-gain*|x1_i- x2_k|^2) + * + * @tparam math_t floating point type + * @param gain + */ + RBFKernel(math_t gain) : GramMatrixBase(), gain(gain) {} + + [[deprecated]] RBFKernel(math_t gain, cublasHandle_t handle) + : GramMatrixBase(handle), gain(gain) + { + } + + void matrixRowNormL2(raft::resources const& handle, + dense_input_matrix_view_t matrix, + math_t* target) + { + bool is_row_major = GramMatrixBase::get_is_row_major(matrix); + int minor = is_row_major ? matrix.extent(1) : matrix.extent(0); + int ld = is_row_major ? matrix.stride(0) : matrix.stride(1); + ASSERT(ld == minor, "RBF Kernel lazy rowNorm compute does not support ld parameter"); + raft::linalg::rowNorm(target, + matrix.data_handle(), + matrix.extent(1), + matrix.extent(0), + raft::linalg::NormType::L2Norm, + is_row_major, + resource::get_cuda_stream(handle)); + } + + void matrixRowNormL2(raft::resources const& handle, + csr_input_matrix_view_t matrix, + math_t* target) + { + auto matrix_structure = matrix.structure_view(); + raft::sparse::linalg::rowNormCsr(handle, + matrix_structure.get_indptr().data(), + matrix.get_elements().data(), + matrix_structure.get_nnz(), + matrix_structure.get_n_rows(), + target, + raft::linalg::NormType::L2Norm); + } + + /** Evaluate kernel matrix using RBF kernel. + * + * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and | | euclidean distance. + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void evaluate(raft::resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + cudaStream_t stream = resource::get_cuda_stream(handle); + // lazy compute norms if not given + rmm::device_uvector tmp_norm_x1(0, stream); + rmm::device_uvector tmp_norm_x2(0, stream); + if (norm_x1 == nullptr) { + tmp_norm_x1.reserve(x1.extent(0), stream); + norm_x1 = tmp_norm_x1.data(); + matrixRowNormL2(handle, x1, norm_x1); + } + if (norm_x2 == nullptr) { + tmp_norm_x2.reserve(x2.extent(0), stream); + norm_x2 = tmp_norm_x2.data(); + matrixRowNormL2(handle, x2, norm_x2); + } + + // compute L2expanded + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + norm_x1, + norm_x2, + is_row_major, + resource::get_cuda_stream(handle)); + } + + /** Evaluate kernel matrix using RBF kernel. + * + * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and | | euclidean distance. 
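+ * (The squared distance is evaluated in expanded form,
+ * |x1_i - x2_k|^2 = |x1_i|^2 + |x2_k|^2 - 2 * <x1_i, x2_k>, which is why
+ * the row norms norm_x1 and norm_x2 are computed lazily when not supplied.)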
+ * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void evaluate(raft::resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + cudaStream_t stream = resource::get_cuda_stream(handle); + + // lazy compute norms if not given + rmm::device_uvector tmp_norm_x1(0, stream); + rmm::device_uvector tmp_norm_x2(0, stream); + if (norm_x1 == nullptr) { + tmp_norm_x1.reserve(x1.structure_view().get_n_rows(), stream); + norm_x1 = tmp_norm_x1.data(); + matrixRowNormL2(handle, x1, norm_x1); + } + if (norm_x2 == nullptr) { + tmp_norm_x2.reserve(x2.extent(0), stream); + norm_x2 = tmp_norm_x2.data(); + matrixRowNormL2(handle, x2, norm_x2); + } + + // compute L2expanded + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + norm_x1, + norm_x2, + is_row_major, + resource::get_cuda_stream(handle)); + } + + /** Evaluate kernel matrix using RBF kernel. + * + * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and | | euclidean distance. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void evaluate(raft::resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + cudaStream_t stream = resource::get_cuda_stream(handle); + + // lazy compute norms if not given + rmm::device_uvector tmp_norm_x1(0, stream); + rmm::device_uvector tmp_norm_x2(0, stream); + if (norm_x1 == nullptr) { + tmp_norm_x1.reserve(x1.structure_view().get_n_rows(), stream); + norm_x1 = tmp_norm_x1.data(); + matrixRowNormL2(handle, x1, norm_x1); + } + if (norm_x2 == nullptr) { + tmp_norm_x2.reserve(x2.structure_view().get_n_rows(), stream); + norm_x2 = tmp_norm_x2.data(); + matrixRowNormL2(handle, x2, norm_x2); + } + + // compute L2expanded + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + norm_x1, + norm_x2, + is_row_major, + resource::get_cuda_stream(handle)); + } + + /** Evaluate the Gram matrix using the legacy interface. 
+ * + * @param [in] x1 device array of vectors, size [n1*n_cols] + * @param [in] n1 number vectors in x1 + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] + * @param [in] n2 number vectors in x2 + * @param [out] out device buffer to store the Gram matrix, size [n1*n2] + * @param [in] is_row_major whether the input and output matrices are in row + * major format + * @param [in] stream cuda stream + * @param ld1 leading dimension of x1 (usually it is n1) + * @param ld2 leading dimension of x2 (usually it is n2) + * @param ld_out leading dimension of out (usually it is n1) + */ + [[deprecated]] void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) + { + ASSERT(GramMatrixBase::legacy_interface, + "Legacy interface can only be used with legacy ctor."); + int minor1 = is_row_major ? n_cols : n1; + int minor2 = is_row_major ? n_cols : n2; + int minor_out = is_row_major ? n2 : n1; + ASSERT(ld1 == minor1, "RBF Kernel distance does not support ld1 parameter"); + ASSERT(ld2 == minor2, "RBF Kernel distance does not support ld2 parameter"); + ASSERT(ld_out == minor_out, "RBF Kernel distance does not support ld_out parameter"); + + math_t gain = this->gain; + using index_t = int64_t; + + rbf_fin_op fin_op{gain}; + + raft::resources handle; + resource::set_cuda_stream(handle, stream); + + cuvs::distance::distance(handle, + const_cast(x1), + const_cast(x2), + out, + n1, + n2, + n_cols, + NULL, + 0, + fin_op, + is_row_major); + } +}; + +}; // end namespace cuvs::distance::kernels::detail diff --git a/cpp/src/distance/detail/kernels/rbf_fin_op.cuh b/cpp/src/distance/detail/kernels/rbf_fin_op.cuh new file mode 100644 index 000000000..73588baea --- /dev/null +++ b/cpp/src/distance/detail/kernels/rbf_fin_op.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +/* + * This file defines rbf_fin_op, which is used in GramMatrixBase. + * + * This struct has been moved to a separate file, so that it is cheap to include + * in distance/distance-ext.cuh, where an instance of cuvs::distance::distance + * with the rbf_fin_op is instantiated. + * + */ + +#include // raft::exp +#include // HD + +namespace cuvs::distance::kernels::detail { + +/** @brief: Final op for Gram matrix with RBF kernel. + * + * Calculates output = e^(-gain * in) + * + */ +template +struct rbf_fin_op { + OutT gain; + + explicit HD rbf_fin_op(OutT gain_) noexcept : gain(gain_) {} + + template + HDI OutT operator()(OutT d_val, Args... 
unused_args) + { + return raft::exp(-gain * d_val); + } +}; // struct rbf_fin_op + +} // namespace cuvs::distance::kernels::detail diff --git a/cpp/src/distance/detail/masked_distance_base.cuh b/cpp/src/distance/detail/masked_distance_base.cuh new file mode 100644 index 000000000..d92052c84 --- /dev/null +++ b/cpp/src/distance/detail/masked_distance_base.cuh @@ -0,0 +1,326 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "../pairwise_distance_base.cuh" +#include +#include + +#include + +namespace cuvs { +namespace distance { +namespace detail { + +/** + * @brief Device class for masked nearest neighbor computations. + * + * @tparam useNorms whether norms are needed + * @tparam DataT input data-type (for x and y matrices) + * @tparam AccT accumulation data-type + * @tparam IdxT index data-type + * @tparam Policy struct which tunes the Contraction kernel + * @tparam CoreLambda tells how to accumulate an x and y into + acc. its signature: + template void core_lambda(AccT& acc, + const DataT& x, const DataT& y) + * @tparam EpilogueLambda applies an elementwise function to compute final + values. Its signature is: + template void epilogue_lambda + (AccT acc[][], DataT* regxn, DataT* regyn); + * @tparam FinalLambda the final lambda called on final distance value + * @tparam rowEpilogueLambda epilog lambda that executes when a full row has + * been processed. + * + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of x + * @param[in] n number of columns of y + * @param[in] k number of cols of x and y + * @param[in] lda leading dimension of x + * @param[in] ldb leading dimension of y + * @param[in] ldd parameter to keep Contractions_NT happy.. + * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine + * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine + * @param[in] adj An adjacency matrix encoded as a bitfield indicating for each + * row of `x` and each group in `y` whether to compute the + * distance. Dim = `(m / 64) x num_groups`. + * @param[in] group_idxs An array containing the *end* indices of each group + * in `y`. The value of group_idxs[j] indicates the + * start of group j + 1, i.e., it is the inclusive + * scan of the group lengths. The first group is + * always assumed to start at index 0 and the last + * group typically ends at index `n`. Length = + * `num_groups`. + * @param[in] num_groups The number of groups in group_idxs. + * @param[in] smem shared mem buffer for intermediate storage of x, y, xn & yn. + * @param core_op the core accumulation operation lambda + * @param epilog_op the epilog operation lambda + * @param fin_op the final gemm epilogue lambda + * @param rowEpilog_op epilog lambda that executes when a full row has been processed. 
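+ *
+ * Illustrative example (for clarity): with m = 128 rows in `x` and
+ * num_groups = 3 groups in `y` of sizes {10, 20, 30}, group_idxs is
+ * {10, 30, 60} and `adj` holds (128 / 64) x 3 = 2 x 3 uint64_t words;
+ * bit i of the word at (w, g) is set iff row w * 64 + i of `x` must be
+ * compared against group g of `y`.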
+ */ +template > +struct MaskedDistances : public BaseClass { + private: + typedef Policy P; + const DataT* xn; + const DataT* yn; + const DataT* const yBase; + const uint64_t* adj; + const IdxT* group_idxs; + IdxT num_groups; + char* smem; + CoreLambda core_op; + EpilogueLambda epilog_op; + FinalLambda fin_op; + rowEpilogueLambda rowEpilog_op; + + AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; + + public: + // Constructor + DI MaskedDistances(const DataT* _x, + const DataT* _y, + IdxT _m, + IdxT _n, + IdxT _k, + IdxT _lda, + IdxT _ldb, + IdxT _ldd, + const DataT* _xn, + const DataT* _yn, + const uint64_t* _adj, + const IdxT* _group_idxs, + IdxT _num_groups, + char* _smem, + CoreLambda _core_op, + EpilogueLambda _epilog_op, + FinalLambda _fin_op, + rowEpilogueLambda _rowEpilog_op) + : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), + xn(_xn), + yn(_yn), + yBase(_y), + adj(_adj), + group_idxs(_group_idxs), + num_groups(_num_groups), + smem(_smem), + core_op(_core_op), + epilog_op(_epilog_op), + fin_op(_fin_op), + rowEpilog_op(_rowEpilog_op) + { + } + + DI void run() + { + const auto grid_stride_m = (P::Mblk * gridDim.y); + const auto grid_offset_m = (P::Mblk * blockIdx.y); + + const auto grid_stride_g = gridDim.x; + const auto grid_offset_g = blockIdx.x; + + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + // Start loop over groups + for (auto idx_g = grid_offset_g; idx_g < this->num_groups; idx_g += grid_stride_g) { + const uint64_t block_adj = get_block_adjacency(adj, tile_idx_m, idx_g); + // block_adj is a bitfield that contains a 1 if a row is adjacent to the + // current group. All zero means we can skip this group. + if (block_adj == 0) { continue; } + + // thread_adj is a bitfield that contains a 1 at location i iff we must + // compute row i of acc (the accumulator register tile). That is, + // for i = 0,.., AccRowsPerTh and j = 0,.., AccColsPerTh: + // + // ((1 << i) & thread_adj) > 0 <=> acc[i][j] must be computed. + // + // We precompute this information because it is used in various + // locations to skip thread-local computations, specifically: + // + // 1. To skip computations if thread_adj == 0, i.e., none of the values + // of `acc` have to be computed. + // + // 2. In epilog_op, to consider only values of `acc` to be reduced that + // are not masked of. + // + // Note 1: Even when the computation can be skipped for a specific thread, + // the thread still participates in synchronization operations. + // + // Note 2: In theory, it should be possible to skip computations for + // specific rows of `acc`. In practice, however, this does not improve + // performance. + int thread_adj = compute_thread_adjacency(block_adj); + + auto tile_idx_n = idx_g == 0 ? 0 : group_idxs[idx_g - 1]; + const auto group_end_n = group_idxs[idx_g]; + for (; tile_idx_n < group_end_n; tile_idx_n += P::Nblk) { + // We provide group_end_n to limit the number of unnecessary data + // points that are loaded from y. + this->ldgXY(tile_idx_m, tile_idx_n, 0, group_end_n); + + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx, group_end_n); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. 
+ if (thread_adj != 0) { accumulate(); } + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + if (thread_adj != 0) { + accumulate(); // last iteration + } + // The pre-condition for the loop over tile_idx_n is that write_buffer + // and read_buffer point to the same buffer. This flips read_buffer + // back so that it satisfies the pre-condition of this loop. + this->switch_read_buffer(); + + if (useNorms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + load_norms(tile_idx_m, tile_idx_n, group_end_n, regxn, regyn); + if (thread_adj != 0) { + epilog_op(acc, thread_adj, regxn, regyn, tile_idx_n, tile_idx_m, group_end_n); + } + } else { + if (thread_adj != 0) { + epilog_op(acc, thread_adj, nullptr, nullptr, tile_idx_n, tile_idx_m, group_end_n); + } + } + } // tile_idx_n + } // idx_g + rowEpilog_op(tile_idx_m); + } // tile_idx_m + } + + private: + DI uint64_t get_block_adjacency(const uint64_t* adj, IdxT tile_idx_m, IdxT idx_group) + { + // A single element of `adj` contains exactly enough bits to indicate which + // rows in the current tile to skip and which to compute. + static_assert(P::Mblk == 8 * sizeof(adj[0]), + "masked_l2_nn only supports a policy with 64 rows per block."); + IdxT block_flag_idx = tile_idx_m / P::Mblk; + // Index into adj at row tile_idx_m / 64 and column idx_group. + return adj[block_flag_idx * this->num_groups + idx_group]; + } + + DI uint32_t compute_thread_adjacency(const uint64_t block_adj) + { + // thread_adj is a bitfield that contains a 1 at location i iff we must + // compute row i of acc (the accumulator register tile). It is described in + // more detail in the run() method. + uint32_t thread_adj = 0; +#pragma unroll + for (int thread_row_idx = 0; thread_row_idx < P::AccRowsPerTh; ++thread_row_idx) { + // Index `thread_row_idx` refers to a row of the current threads' register + // tile `acc`, i.e., acc[i][:]. Index `block_row_idx` refers to the + // corresponding row of the current block tile in shared memory. + const int block_row_idx = this->accrowid + thread_row_idx * P::AccThRows; + + // block_row_is_adjacent is true if the current block_row_idx is adjacent + // to the current group. + const uint64_t block_mask = 1ull << block_row_idx; + const bool block_row_is_adjacent = (block_adj & block_mask) != 0; + if (block_row_is_adjacent) { + // If block row is adjacent, write a 1 bit to thread_adj at location + // `thread_row_idx`. + const uint32_t thread_mask = 1 << thread_row_idx; + thread_adj |= thread_mask; + } + } + return thread_adj; + } + + DI void reset_accumulator() + { + // Reset accumulator registers to zero. 
+#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = BaseClass::Zero; + } + } + } + + DI void accumulate() + { +#pragma unroll + for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { + this->ldsXY(ki); +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { +#pragma unroll + for (int v = 0; v < P::Veclen; ++v) { + core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); + } + } + } + } + } + + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + IdxT end_n, + DataT (®xn)[P::AccRowsPerTh], + DataT (®yn)[P::AccColsPerTh]) + { + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; + } + + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < end_n ? yn[idx] : 0; + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } +#pragma unroll + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; + } + } +}; // struct MaskedDistances + +}; // namespace detail +}; // namespace distance +}; // namespace cuvs diff --git a/cpp/src/distance/detail/masked_nn.cuh b/cpp/src/distance/detail/masked_nn.cuh new file mode 100644 index 000000000..6520b1e2e --- /dev/null +++ b/cpp/src/distance/detail/masked_nn.cuh @@ -0,0 +1,329 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "../compress_to_bits.cuh" +#include "../fused_distance_nn/fused_l2_nn.cuh" +#include "../masked_distance_base.cuh" +#include +#include +#include +#include + +#include + +#include + +#include + +namespace cuvs { +namespace distance { +namespace detail { + +template +__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL masked_l2_nn_kernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const uint64_t* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + bool sqrt, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + CoreLambda core_op, + FinalLambda fin_op) +{ + extern __shared__ char smem[]; + + typedef raft::KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [pairRedOp, &val, maxVal, sqrt] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + int thread_adj, + DataT* regxn, + DataT* regyn, + IdxT tile_idx_n, + IdxT tile_idx_m, + IdxT tile_end_n) { + KVPReduceOpT pairRed_op(pairRedOp); + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + } + } + if (sqrt) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = raft::sqrt(acc[i][j]); + } + } + } + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + // thread_adj is a bitfield that contains a 1 at location i iff we must + // compute row i of acc (the accumulator register tile). It is described in + // more detail in the maskedDistances.run() method. + const bool ignore = (thread_adj & (1 << i)) == 0; + if (ignore) { continue; } +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + tile_idx_n; + if (tile_end_n <= tmpkey) { + // Do not process beyond end of tile. + continue; + } + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < tile_end_n) { + val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); + } + } + } + }; + + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT tile_idx_m) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + auto tmpkey = raft::shfl(val[i].key, lid + j); + auto tmpvalue = raft::shfl(val[i].value, lid + j); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, m, tile_idx_m); + + // reset the val array. 
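+ // (so the next tile along m starts again from the sentinel {-1, maxVal} pairs)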
+#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + MaskedDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + xn, + yn, + adj, + group_idxs, + num_groups, + smem, + core_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +} + +/** + * @brief Wrapper for masked_l2_nn_kernel + * + * Responsibilities: + * - Allocate (and initialize) workspace memory for: + * - mutexes used in nearest neighbor update step + * - adjacency matrix bitfield + * - Compress adjacency matrix to bitfield + * - Initialize output buffer (conditional on `initOutBuffer`) + * - Specify core and final operations for the L2 norm + * - Determine optimal launch configuration for kernel. + * - Launch kernel and check for errors. + * + * @tparam DataT Input data-type (for x and y matrices). + * @tparam OutT Output data-type (for key-value pairs). + * @tparam IdxT Index data-type. + * @tparam ReduceOpT A struct to perform the final needed reduction + * operation and also to initialize the output array + * elements with the appropriate initial value needed for + * reduction. + * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. + * + * @param handle RAFT handle for managing expensive resources + * @param[out] out Will contain reduced output (nn key-value pairs) + * @param[in] x First matrix. Row major. Dim = `m x k`. (on device) + * @param[in] y Second matrix. Row major. Dim = `n x k`. (on device) + * @param[in] xn L2 squared norm of `x`. Length = `m`. + * @param[in] yn L2 squared norm of `y`. Length = `n`. + * @param[in] adj A boolean adjacency matrix indicating for each + * row of `x` and each group in `y` whether to compute the + * distance. Dim = `m x num_groups`. + * @param[in] group_idxs An array containing the *end* indices of each group + * in `y`. The value of group_idxs[j] indicates the + * start of group j + 1, i.e., it is the inclusive + * scan of the group lengths. The first group is + * always assumed to start at index 0 and the last + * group typically ends at index `n`. Length = + * `num_groups`. + * @param[in] num_groups Length of `group_idxs`. + * @param m Rows of `x`. + * @param n Rows of `y`. + * @param k Cols of `x` and `y`. + * @param redOp Reduction operator in the epilogue + * @param pairRedOp Reduction operation on key value pairs + * @param sqrt Whether to compute the squared or actual (i.e. sqrt) L2 norm. 
+ * @param initOutBuffer Whether to initialize the output buffer + * + * + */ +template +void masked_l2_nn_impl(raft::resources const& handle, + OutT* out, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const bool* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer) +{ + typedef typename linalg::Policy4x4::Policy P; + + static_assert(P::Mblk == 64, "masked_l2_nn_impl only supports a policy with 64 rows per block."); + + // Get stream and workspace memory resource + rmm::mr::device_memory_resource* ws_mr = + dynamic_cast(resource::get_workspace_resource(handle)); + auto stream = resource::get_cuda_stream(handle); + + // Acquire temporary buffers and initialize to zero: + // 1) Adjacency matrix bitfield + // 2) Workspace for fused nearest neighbor operation + size_t m_div_64 = raft::ceildiv(m, IdxT(64)); + rmm::device_uvector ws_adj64{m_div_64 * num_groups, stream, ws_mr}; + rmm::device_uvector ws_fused_nn{size_t(m), stream, ws_mr}; + RAFT_CUDA_TRY(cudaMemsetAsync(ws_adj64.data(), 0, ws_adj64.size() * sizeof(uint64_t), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(ws_fused_nn.data(), 0, ws_fused_nn.size() * sizeof(int), stream)); + + // Compress boolean adjacency matrix to bitfield. + auto adj_view = raft::make_device_matrix_view(adj, m, num_groups); + auto adj64_view = + raft::make_device_matrix_view(ws_adj64.data(), m_div_64, num_groups); + compress_to_bits(handle, adj_view, adj64_view); + + // Initialize output buffer with keyvalue pairs as determined by the reduction + // operator (it will be called with maxVal). + constexpr auto maxVal = std::numeric_limits::max(); + if (initOutBuffer) { + dim3 grid(raft::ceildiv(m, P::Nthreads)); + dim3 block(P::Nthreads); + + initKernel<<>>(out, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + // Accumulation operation lambda + auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; + auto fin_op = raft::identity_op{}; + + auto kernel = masked_l2_nn_kernel; + constexpr size_t smemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 block(P::Nthreads); + dim3 grid = launchConfigGenerator
<P>
(m, n, smemSize, kernel); + + kernel<<>>(out, + x, + y, + xn, + yn, + ws_adj64.data(), + group_idxs, + num_groups, + m, + n, + k, + sqrt, + maxVal, + ws_fused_nn.data(), + redOp, + pairRedOp, + core_lambda, + fin_op); + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +} // namespace detail +} // namespace distance +} // namespace cuvs diff --git a/cpp/src/distance/detail/pairwise_distance_base.cuh b/cpp/src/distance/detail/pairwise_distance_base.cuh new file mode 100644 index 000000000..990f845fd --- /dev/null +++ b/cpp/src/distance/detail/pairwise_distance_base.cuh @@ -0,0 +1,326 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include // raft::linalg::Contractions_NT +#include // ceildiv +#include // RAFT_CUDA_TRY + +#include // size_t + +namespace cuvs { +namespace distance { +namespace detail { + +/** + * @brief Device class for L1, L2 and cosine distance metrics. + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam Policy struct which tunes the Contraction kernel + * @tparam OpT A distance operation, e.g., cosine_distance_op. + * @tparam EpilogueLambda applies an elementwise function to compute final + values. Its signature is: + template void epilogue_lambda + (AccT acc[][], DataT* regxn, DataT* regyn); + * @tparam FinalLambda the final lambda called on final distance value + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of A and C/D + * @param[in] n number of columns of B and C/D + * @param[in] k number of cols of A and rows of B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine + * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine + * @param[output] pD output matrix + * @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn. + * @param distance_op the distance operation, e.g. 
cosine_distance_op + * @param epilog_op the epilog operation lambda + * @param fin_op the final gemm epilogue lambda + * @param rowEpilog_op epilog lambda that executes when a full row has been processed + */ + +template > +struct PairwiseDistances : public BaseClass { + // Get accumulation type from distance_op + using AccT = typename OpT::AccT; + + private: + typedef Policy P; + const DataT* xn; + const DataT* yn; + const DataT* const yBase; + OutT* dOutput; + char* smem; + OpT distance_op; + EpilogueLambda epilog_op; + FinalLambda fin_op; + rowEpilogueLambda rowEpilog_op; + + const IdxT grid_stride_m; + const IdxT grid_stride_n; + const IdxT grid_offset_m; + const IdxT grid_offset_n; + + AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; + + public: + // Constructor + DI PairwiseDistances(const DataT* _x, + const DataT* _y, + IdxT _m, + IdxT _n, + IdxT _k, + IdxT _lda, + IdxT _ldb, + IdxT _ldd, + const DataT* _xn, + const DataT* _yn, + OutT* _dOutput, + char* _smem, + OpT _distance_op, + EpilogueLambda _epilog_op, + FinalLambda _fin_op, + rowEpilogueLambda _rowEpilog_op) + : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), + xn(_xn), + yn(_yn), + yBase(_y), + dOutput(_dOutput), + smem(_smem), + distance_op(_distance_op), + epilog_op(_epilog_op), + fin_op(_fin_op), + rowEpilog_op(_rowEpilog_op), + grid_stride_m(P::Mblk * gridDim.y), + grid_stride_n(P::Nblk * gridDim.x), + grid_offset_m(P::Mblk * blockIdx.y), + grid_offset_n(P::Nblk * blockIdx.x) + { + } + + DI void run() + { + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + this->ldgXY(tile_idx_m, grid_offset_n, 0); + for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { + // Prolog: + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + // Main loop: + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + accumulate(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + accumulate(); // last iteration + // The pre-condition for the loop over tile_idx_n is that write_buffer + // and read_buffer point to the same buffer. This flips read_buffer back + // so that it satisfies the pre-condition of this loop. + this->switch_read_buffer(); + + // Epilog: + if (distance_op.use_norms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + load_norms(tile_idx_m, tile_idx_n, regxn, regyn); + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); + // And any possible additional epilogs + epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); + } else { + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. 
+ // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + // And any possible additional epilogs + epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + } + if (writeOut) { store_output(tile_idx_m, tile_idx_n); } + } + rowEpilog_op(tile_idx_m); + } + } + + private: + DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n) + { + // Fetch next grid stride ldg if within range + const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; + const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m; + if ((next_tile_tile_idx_n) < this->n) { + this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0); + } else if ((next_tile_tile_idx_m) < this->m) { + this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0); + } + } + + DI void reset_accumulator() + { + // Reset accumulator registers to zero. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = BaseClass::Zero; + } + } + } + + DI void accumulate_reg_tile(DataT (®_x)[P::AccRowsPerTh][P::Veclen], + DataT (®_y)[P::AccColsPerTh][P::Veclen]) + { +#pragma unroll + for (int v = 0; v < P::Veclen; ++v) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]); + } + } + } + } + + DI void accumulate() + { + // We have a separate ldsXY and accumulate_reg_tile outside the loop body, + // so that these separated calls can be interspersed with preceding and + // following instructions, thereby hiding latency. + this->ldsXY(0); + + // If expensive inner loop, do not unroll loop. + constexpr int num_iterations = P::Kblk / P::Veclen - 1; + constexpr int unroll_count = decltype(distance_op)::expensive_inner_loop ? 1 : num_iterations; +#pragma unroll unroll_count + for (int ki = P::Veclen; ki < P::Kblk; ki += P::Veclen) { + accumulate_reg_tile(this->regx, this->regy); + this->ldsXY(ki); + } + + // Accumulate last loaded tile. + accumulate_reg_tile(this->regx, this->regy); + } + + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + DataT (®xn)[P::AccRowsPerTh], + DataT (®yn)[P::AccColsPerTh]) + { + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + if (tile_idx_n == blockIdx.x * P::Nblk) { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; + } + } + + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < this->n ? 
yn[idx] : 0; + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } +#pragma unroll + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; + } + } + + DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n) + { + IdxT starty = tile_idx_m + this->accrowid; + IdxT startx = tile_idx_n + this->acccolid; + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rowId = starty + i * P::AccThRows; +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto colId = startx + j * P::AccThCols; + if (rowId < this->m && colId < this->n) { + // Promote to 64 bit index for final write, as output array can be > 2^31 + dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); + } + } + } + } +}; // struct PairwiseDistances + +template +dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) +{ + int devId; + RAFT_CUDA_TRY(cudaGetDevice(&devId)); + int numSMs; + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId)); + + int numBlocksPerSm = 0; + dim3 grid; + + RAFT_CUDA_TRY( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, func, P::Nthreads, sMemSize)); + std::size_t minGridSize = numSMs * numBlocksPerSm; + std::size_t yChunks = raft::ceildiv(m, P::Mblk); + std::size_t xChunks = raft::ceildiv(n, P::Nblk); + grid.y = yChunks > minGridSize ? minGridSize : yChunks; + grid.x = (minGridSize - grid.y) <= 0 ? 1 : xChunks; + if (grid.x != 1) { + std::size_t i = 1; + while (grid.y * i < minGridSize) { + i++; + } + grid.x = i >= xChunks ? xChunks : i; + } + + return grid; +} + +}; // namespace detail +}; // namespace distance +}; // namespace cuvs diff --git a/cpp/src/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/src/distance/detail/pairwise_distance_cutlass_base.cuh new file mode 100644 index 000000000..da9a1ac4e --- /dev/null +++ b/cpp/src/distance/detail/pairwise_distance_cutlass_base.cuh @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wtautological-compare" + +// We define CUTLASS_NAMESPACE in case +// RAFT cmake is not used +#ifndef CUTLASS_NAMESPACE +#define cutlass raft_cutlass +#endif + +#include "pairwise_distance_epilogue_elementwise.h" +#include "pairwise_distance_gemm.h" + +#include "distance_ops/cutlass.cuh" +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace cuvs { +namespace distance { +namespace detail { + +template +std::enable_if_t::value> cutlassDistanceKernel(const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + FinalLambda fin_op, + OpT distance_op, + cudaStream_t stream) +{ + static_assert(!(std::is_same::value), + "OutType bool is not supported use uint8_t instead"); + + auto dist_op = distance_op.get_cutlass_op(); + using DistanceFn = decltype(dist_op); + using EpilogueOutputOp = + cutlass::epilogue::thread::PairwiseDistanceEpilogueElementwise; + constexpr int batch_count = 1; + + constexpr auto mode = cutlass::gemm::GemmUniversalMode::kGemm; + + typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op); + + // Number of pipelines you want to use + constexpr int NumStages = 3; + // Alignment + constexpr int Alignment = VecLen; + + using cutlassDistKernel = + typename cutlass::gemm::kernel::PairwiseDistanceGemm::GemmKernel; + + using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; + + constexpr uint32_t gridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + constexpr uint32_t max_batch_size = gridYZMax * cutlassDistKernel::ThreadblockShape::kN; + IdxT numNbatches = (n - 1 + max_batch_size) / max_batch_size; + + for (IdxT i = 0; i < numNbatches; i++) { + const DataT *a, *b; + IdxT gemm_lda, gemm_ldb; + size_t offsetN = i * max_batch_size; + + if constexpr (isRowMajor) { + gemm_lda = ldb; + gemm_ldb = lda; + a = y + offsetN * gemm_lda; + b = x; + } else { + gemm_lda = lda; + gemm_ldb = ldb; + a = x; + b = y + offsetN; + } + IdxT chunkN = (i + 1) * max_batch_size; + IdxT currentN = (chunkN < n) ? max_batch_size : (n - offsetN); + + // default initialize problem size with row major inputs + auto problem_size = isRowMajor ? 
cutlass::gemm::GemmCoord(currentN, m, k) + : cutlass::gemm::GemmCoord(m, currentN, k); + + typename cutlassDist::Arguments arguments{ + mode, + problem_size, + batch_count, + epilog_op_param, + a, + b, + xn, // C matrix eq vector param, which here is A norm + nullptr, // tensor_Z, + (DataT*)yn + offsetN, // this is broadcast vec, which is required to be non-const param + dOutput + offsetN, // Output distance matrix + (int64_t)0, // batch stride A + (int64_t)0, // batch stride B + (int64_t)0, // batch stride Norm A + (int64_t)0, + (int64_t)0, // batch stride Norm B + (int64_t)0, // batch stride Output + gemm_lda, // stride A + gemm_ldb, // stride B + 1, // stride A norm + 0, // this is no-op for Z + 0, // This must be zero + ldd // stride Output matrix + }; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = cutlassDist::get_workspace_size(arguments); + // Allocate workspace memory + rmm::device_uvector workspace(workspace_size, stream); + // Instantiate CUTLASS kernel depending on templates + cutlassDist cutlassDist_op; + // Check the problem size is supported or not + RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); + + // Launch initialized CUTLASS kernel + RAFT_CUTLASS_TRY(cutlassDist_op(stream)); + } +} + +}; // namespace detail +}; // namespace distance +}; // namespace cuvs + +#pragma GCC diagnostic pop diff --git a/cpp/src/distance/detail/pairwise_distance_epilogue.h b/cpp/src/distance/detail/pairwise_distance_epilogue.h new file mode 100644 index 000000000..6ead09ed1 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_distance_epilogue.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + +This is adapted from DefaultEpilogueWithBroadcastTensorOp from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h#L75) + +This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec +and EpilogueWithBroadcast used for distances L2/cosine as well as applies user-define elementwise +operation. +-- A norm load is provided PredicatedTileIteratorNormVec +-- B norm load is provided by EpilogueWithBroadcast +-- elementwise operation is provided by OutputOp +*/ + +#pragma once + +#include "./predicated_tile_iterator_normvec.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. 
+template +struct PairwiseDistanceEpilogue { + /// Use defaults related to the existing epilogue + using Base = + DefaultEpilogueTensorOp; + + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using OutputTileIterator = cutlass::epilogue::threadblock:: + PredicatedTileIteratorNormVec; + + // + // Additional tensor tile iterator - stores t = Elementwise(z) + // + using TensorTileIterator = + cutlass::epilogue::threadblock::PredicatedTileIterator; + + /// Define the epilogue + using Epilogue = EpilogueWithBroadcast; +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/src/distance/detail/pairwise_distance_epilogue_elementwise.h b/cpp/src/distance/detail/pairwise_distance_epilogue_elementwise.h new file mode 100644 index 000000000..2b2c04b9d --- /dev/null +++ b/cpp/src/distance/detail/pairwise_distance_epilogue_elementwise.h @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +/*! \file + \brief Functor performing distance operations used by epilogues of pairwise distance + * kernels. +* This is adapted from LinearCombinationBiasElementwise from CUTLASS 2.9.0 +* customized for applying elementwise distance formula on accumulated GEMM value +* and applying user-defined final custom operation on the distance value. +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This base class is meant to define the concept required of the +/// EpilogueWithBroadcast::OutputOp +template +class PairwiseDistanceEpilogueElementwise { + public: + using ElementOutput = ElementC_; + using ElementC = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementZ = ElementZ_; + using ElementT = ElementT_; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + + using DistanceOp = DistanceOp_; + using FinalOp = FinalOp_; + + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentZ = Array; + using FragmentT = Array; + + using FragmentOutput = FragmentZ; + + static bool const kIsHeavy = false; // ElementwiseOp::kIsHeavy; + + /// If true, the 'Z' tensor is stored + static bool const kStoreZ = false; // We don't store anything in Z, + + /// If true, the 'T' tensor is stored + static bool const kStoreT = true; // this is our final output storage. 
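+
+  // Note (added for clarity): the epilogue computes Z = dist_op(C, V, accum) and
+  // T = final_op(Z, 0); because kStoreZ is false, only T, the finished distance
+  // value, is written back to global memory.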
+ + /// Host-constructable parameters structure + struct Params { + FinalOp_ final_op_; + DistanceOp_ dist_op_; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params(DistanceOp_ dist_op, FinalOp final_op) : final_op_(final_op), dist_op_(dist_op) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + // + // Data members + // + FinalOp_ final_op; + DistanceOp_ elementwise_op; + + public: + // + // Methods + // + + /// Constructor from Params + CUTLASS_HOST_DEVICE + PairwiseDistanceEpilogueElementwise(Params const& params) + : final_op(params.final_op_), elementwise_op(params.dist_op_) + { + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const + { + // we use for making sure C matrix path is used for A mat norm. + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentC const& frag_C, + FragmentCompute const& V) const + { + FragmentCompute tmp_Accum = + NumericArrayConverter()(AB); + FragmentCompute tmp_C = + NumericArrayConverter()(frag_C); + FragmentCompute result_Z; + FragmentCompute result_T; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerAccess; ++i) { + result_Z[i] = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); + result_T[i] = final_op(result_Z[i], 0); + } + + NumericArrayConverter convert_t; + frag_T = convert_t(result_T); + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentCompute const& V) const + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/src/distance/detail/pairwise_distance_gemm.h b/cpp/src/distance/detail/pairwise_distance_gemm.h new file mode 100644 index 000000000..aaf2689da --- /dev/null +++ b/cpp/src/distance/detail/pairwise_distance_gemm.h @@ -0,0 +1,238 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "./pairwise_distance_epilogue.h" + +#include +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Element type for final output + // typename ElementOutT, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct PairwiseDistanceGemm { + // This struct is specialized for fp32/3xTF32 + + /// Threadblock-level tile size (concept: GemmShape) + using ThreadblockShape = + cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 128, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = + cutlass::gemm::GemmShape<16, 8, 4>; // <- MMA Op tile M = 16, N = 8, K = 4 + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastF32; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. 
+ // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementAccumulator, + typename EpilogueOutputOp::ElementT, + ElementAccumulator, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmWithFusedEpilogue; +}; + +template < + /// Layout type for A matrix operand + int kAlignmentA, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct PairwiseDistanceGemm { + // using Transform = cutlass::ComplexTransform::kNone; + // Threadblock-level tile size (concept: GemmShape) + using ThreadblockShape = + cutlass::gemm::GemmShape<64, 64, 16>; // <- threadblock tile M = 64, N = 64, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 32, N = 32, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + // Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAdd; + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. 
+ // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementC_, + typename EpilogueOutputOp::ElementT, + ElementC_, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmWithFusedEpilogue; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh new file mode 100644 index 000000000..3e8402f1f --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch-ext.cuh @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "../distance_ops/cutlass.cuh" // ops::has_cutlass_op +#include "../kernels/rbf_fin_op.cuh" // rbf_fin_op +#include "../pairwise_matrix/params.cuh" // pairwise_matrix_params +#include // raft::identity_op +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace cuvs::distance::detail { + +template +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) RAFT_EXPLICIT; + +}; // namespace cuvs::distance::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + extern template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +/* + * Hierarchy of instantiations: + * + * This file defines extern template instantiations of the distance kernels. The + * instantiation of the public API is handled in raft/distance/distance-ext.cuh. + * + * After adding an instance here, make sure to also add the instance there. + */ + +// The following two instances are used in the RBF kernel object. Note the use of int64_t for the +// index type. 
+instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_unexp_distance_op, + float, + float, + float, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); + +// Rest of instances +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::canberra_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::correlation_distance_op, + float, + float, + float, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::correlation_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::hellinger_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::jensen_shannon_distance_op, + float, + float, + float, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::jensen_shannon_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int); 
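Under RAFT_EXPLICIT_INSTANTIATE_ONLY the macro above expands to extern template declarations: any translation unit that includes dispatch-ext.cuh only sees a promise that the instantiation exists in the library, while the matching template definitions are compiled exactly once in the generated dispatch_*.cu files further down in this diff. A compressed, single-file sketch of that split, with placeholder names, is:

// ---- the -inl header: the full (expensive) template definition ----
template <typename DataT>
void pairwise_dispatch_sketch(const DataT* x, const DataT* y, DataT* out, int n)
{
  for (int i = 0; i < n; ++i) { out[i] = x[i] - y[i]; }  // stand-in for the real kernel launch
}

// ---- the -ext header: callers include only this, so they never instantiate the template ----
extern template void pairwise_dispatch_sketch<float>(const float*, const float*, float*, int);

// ---- one generated .cu file: the single place where the instantiation is compiled ----
template void pairwise_dispatch_sketch<float>(const float*, const float*, float*, int);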
+instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::lp_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::russel_rao_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh new file mode 100644 index 000000000..e64e9e5d7 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +/* This file has two responsibilities: + * + * 1. Dispatch to the correct implementation of a kernel based on the + * architecture of the device on which the kernel will be launched. For + * instance, the cosine distance has a CUTLASS-based implementation that can + * be used on SM80+ and the normal implementation that is used on older + * architectures. + * + * 2. Provide concise function templates that can be instantiated in + * src/distance/detail/pairwise_matrix/. Previously, + * cuvs::distance::detail::distance was instantiated. The function + * necessarily required a large set of include files, which slowed down the + * build. The cuvs::distance::detail::pairwise_matrix_arch_dispatch functions + * do not require as large an include files set, which speeds up the build. + */ + +#include "../distance_ops/cutlass.cuh" // ops::has_cutlass_op +#include "../pairwise_matrix/dispatch_sm60.cuh" // dispatch_sm60 +#include "../pairwise_matrix/params.cuh" // pairwise_matrix_params +#include // raft::util::arch::SM_* + +// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. +// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). +// Therefore, it is the including file's responsibility to include the correct +// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh +// and src/distance/detail/pairwise_matrix/dispatch_*.cu. 
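The dispatch rule described above (CUTLASS kernel on SM80 and newer, legacy kernel otherwise) is resolved at runtime from a pointer to a compiled kernel rather than from the device properties alone, for the reasons discussed in the linked cub issue. A small sketch of that mechanism using cudaFuncGetAttributes follows; the probe kernel and helper name are illustrative, while the actual code goes through the raft::util::arch helpers and the SM60 kernel wrapper shown below.

#include <cuda_runtime.h>

__global__ void sm60_probe_kernel() {}  // stand-in for the SM60 pairwise kernel

// Decide whether the CUTLASS (SM80+) path can be used, based on the architecture
// that the compiled kernel will actually run as on the current device.
inline bool use_cutlass_path()
{
  cudaFuncAttributes attr{};
  cudaError_t err =
    cudaFuncGetAttributes(&attr, reinterpret_cast<const void*>(sm60_probe_kernel));
  if (err != cudaSuccess) { return false; }
  // ptxVersion is encoded as major * 10 + minor, e.g. 80 for compute capability 8.0.
  return attr.ptxVersion >= 80;
}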
+ +namespace cuvs::distance::detail { + +// This forward-declaration ensures that we do not need to include +// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling +// all the non-CUTLASS based distance instantiations faster. For CUTLASS-based +// distances, dispatch_sm80.cuh has to be included by the file including this +// file. +template +void pairwise_matrix_sm80_dispatch(OpT, + pairwise_matrix_params, + SM_compat_t, + cudaStream_t); + +template +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) +{ + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = is_row_major ? k : m; + IdxT ldy = is_row_major ? k : n; + IdxT ld_out = is_row_major ? n : m; + + pairwise_matrix_params params{ + m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; + + if (!params.is_row_major) { params.flip_x_and_y(); } + + // Dispatch rule: + // - execute CUTLASS-based kernel on SM_80 and above + // - execute normal kernel below SM_80 + namespace arch = raft::util::arch; + + constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); + + if constexpr (cutlass_op_unavailable) { + // Always execute legacy kernels when no cutlass op is available + auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); + pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); + } else { + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); + + // Get pointer to SM60 kernel to determine the best compute architecture + // out of all for which the kernel was compiled for that matches closely + // to the current device. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); + void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); + auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); + } else { + // Reuse kernel wrapper that we obtained above. This avoids performing the + // dispatch twice. + sm60_wrapper.launch(distance_op, params, stream); + } + } +} + +}; // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch.cuh new file mode 100644 index 000000000..4a52b7ebe --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "dispatch-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "dispatch-ext.cuh" +#endif diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py new file mode 100644 index 000000000..accd8de9b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py @@ -0,0 +1,194 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: this template is not perfectly formatted. Use pre-commit to get +# everything in shape again. +header = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +""" + + +macro = """ +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \\ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \\ + template void cuvs::distance::detail:: \\ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \\ + OpT distance_op, \\ + IdxT m, \\ + IdxT n, \\ + IdxT k, \\ + const DataT* x, \\ + const DataT* y, \\ + const DataT* x_norm, \\ + const DataT* y_norm, \\ + OutT* out, \\ + FinOpT fin_op, \\ + cudaStream_t stream, \\ + bool is_row_major) +""" + +data_type_instances = [ + dict( + DataT="float", + AccT="float", + OutT="float", + IdxT="int", + ), + dict( + DataT="double", + AccT="double", + OutT="double", + IdxT="int", + ), +] + +op_instances = [ + dict( + path_prefix="canberra", + OpT="cuvs::distance::detail::ops::canberra_distance_op", + archs = [60], + ), + dict( + path_prefix="correlation", + OpT="cuvs::distance::detail::ops::correlation_distance_op", + archs = [60], + ), + dict( + path_prefix="cosine", + OpT="cuvs::distance::detail::ops::cosine_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="hamming_unexpanded", + OpT="cuvs::distance::detail::ops::hamming_distance_op", + archs = [60], + ), + dict( + path_prefix="hellinger_expanded", + OpT="cuvs::distance::detail::ops::hellinger_distance_op", + archs = [60], + ), + # inner product is handled by cublas. 
+ dict( + path_prefix="jensen_shannon", + OpT="cuvs::distance::detail::ops::jensen_shannon_distance_op", + archs = [60], + ), + dict( + path_prefix="kl_divergence", + OpT="cuvs::distance::detail::ops::kl_divergence_op", + archs = [60], + ), + dict( + path_prefix="l1", + OpT="cuvs::distance::detail::ops::l1_distance_op", + archs = [60], + ), + dict( + path_prefix="l2_expanded", + OpT="cuvs::distance::detail::ops::l2_exp_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="l2_unexpanded", + OpT="cuvs::distance::detail::ops::l2_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="l_inf", + OpT="cuvs::distance::detail::ops::l_inf_distance_op", + archs = [60], + ), + dict( + path_prefix="lp_unexpanded", + OpT="cuvs::distance::detail::ops::lp_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="russel_rao", + OpT="cuvs::distance::detail::ops::russel_rao_distance_op", + archs = [60], + ), +] + +def arch_headers(archs): + include_headers ="\n".join([ + f"#include " + for arch in archs + ]) + return include_headers + + + +for op in op_instances: + for dt in data_type_instances: + DataT, AccT, OutT, IdxT = (dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"]); + path = f"dispatch_{op['path_prefix']}_{DataT}_{AccT}_{OutT}_{IdxT}.cu" + with open(path, "w") as f: + f.write(header) + f.write(arch_headers(op["archs"])) + f.write(macro) + + OpT = op['OpT'] + FinOpT = "raft::identity_op" + f.write(f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n") + f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") + print(f"src/distance/detail/pairwise_matrix/{path}") + +# Dispatch kernels for with the RBF fin op. +with open("dispatch_rbf.cu", "w") as f: + OpT="cuvs::distance::detail::ops::l2_unexp_distance_op" + archs = [60] + + f.write(header) + f.write("#include // rbf_fin_op\n") + f.write(arch_headers(archs)) + f.write(macro) + + for dt in data_type_instances: + DataT, AccT, OutT, IdxT = (dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"]); + IdxT = "int64_t" # overwrite IdxT + + FinOpT = f"cuvs::distance::kernels::detail::rbf_fin_op<{DataT}>" + f.write(f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n") + + f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") + +print("src/distance/detail/pairwise_matrix/dispatch_rbf.cu") diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu new file mode 100644 index 000000000..f82df6cc0 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::canberra_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu new file mode 100644 index 000000000..a20ca5f47 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu new file mode 100644 index 000000000..7bb7e4a96 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::correlation_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu new file mode 100644 index 000000000..34fcc4be4 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::correlation_distance_op, + float, + float, + float, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu new file mode 100644 index 000000000..cb23743c1 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include "dispatch_sm80.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu new file mode 100644 index 000000000..ad71ff295 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include "dispatch_sm80.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu new file mode 100644 index 000000000..e81d54411 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu new file mode 100644 index 000000000..ddbdab602 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu new file mode 100644 index 000000000..d2acecaf0 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::hellinger_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu new file mode 100644 index 000000000..034d76679 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu new file mode 100644 index 000000000..030faeecd --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::jensen_shannon_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu new file mode 100644 index 000000000..f7551a566 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::jensen_shannon_distance_op, + float, + float, + float, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu new file mode 100644 index 000000000..6640d3949 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu new file mode 100644 index 000000000..60cafa474 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu new file mode 100644 index 000000000..b5e8a2f68 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op + +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu new file mode 100644 index 000000000..73868a486 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu new file mode 100644 index 000000000..8ac80b77d --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include "dispatch_sm80.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu new file mode 100644 index 000000000..abebb9121 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include "dispatch_sm80.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu new file mode 100644 index 000000000..ffa6bf02b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu new file mode 100644 index 000000000..acef42a4e --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu new file mode 100644 index 000000000..c2bbbf06b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu new file mode 100644 index 000000000..163b9f37b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_layout.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch_layout.cuh new file mode 100644 index 000000000..1f95e8e41 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_layout.cuh @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "params.cuh" // pairwise_matrix_params +#include // RAFT_EXPECTS + +#include // std::min +#include // size_t +#include // std::integral_constant +namespace cuvs::distance::detail { + +/** + * @brief: Computes minimal common alignment of the rows in a 2D array in bytes + * + * The 2D matrix `x` is assumed to be row-major. This function computes the + * minimal alignment in bytes of the first elements of each row. + * Output can be 16, 8, 4, 2, 1. + * + * @param x Base pointer of row-major input matrix + * @param stride Stride in number of element between consecutive rows. + */ +template +size_t alignment_of_2d_array(const DataT* x, size_t stride) +{ + auto base = reinterpret_cast(x); + size_t stride_bytes = sizeof(DataT) * stride; + + for (int align = 16; align >= 0; align /= 2) { + bool base_aligned = base % align == 0; + bool stride_aligned = stride_bytes % align == 0; + if (base_aligned && stride_aligned) { return align; } + } + return 1; +} + +/** + * @brief: Computes the vec_len parameter kernel policy parameter + * + * @param params Kernel parameters + */ +template +int determine_vec_len(pairwise_matrix_params params) +{ + size_t align_x = alignment_of_2d_array(params.x, params.ldx); + size_t align_y = alignment_of_2d_array(params.y, params.ldy); + size_t byte_alignment = min(align_x, align_y); + + // Since alignment is in bytes, it could be smaller than sizeof(DataT). + // Handle this (unlikely) case here. 
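+  // (Editorial illustration, not part of the generated code: for float data,
+  //  a base pointer and row stride that are both 16-byte aligned give
+  //  byte_alignment = 16, which becomes vec_len_aligned = 4 below and allows
+  //  128-bit vectorized loads; a merely 4-byte-aligned input falls back to
+  //  vec_len_aligned = 1. A base pointer aligned to less than sizeof(DataT),
+  //  e.g. a float matrix starting at a 2-byte boundary, trips the check below.)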
+ RAFT_EXPECTS(sizeof(DataT) <= byte_alignment, + "Input matrix must be aligned to size of elements."); + + // Compute number of elements that can be loaded in one instruction + // without causing misalignent errors. + int vec_len_aligned = (byte_alignment % sizeof(DataT) == 0) ? byte_alignment / sizeof(DataT) : 1; + + // In the future, pairwise_matrix might support `int8_t` input. In that case, + // byte_alignment / sizeof(DataT) might exceed 4. We maximize at 4 here, to + // prevent adding more cases in dispatch_layout below (which are expensive to + // compile). + vec_len_aligned = std::min(vec_len_aligned, 4); + + return vec_len_aligned; +} + +template +using vec_len_constant = std::integral_constant; + +/** + * @brief: Converts run-time arguments to compile-time arguments + * + * Converts run-time arguments row_major and vec_len to compile-time arguments + * and dispatches a lambda f with these compile-time arguments. + * + * This is equivalent to copying and pasting the lambda function `f` in each of + * the switch case statements. + * + * @tparam F Type of lambda f. + * @param row_major Boolean indicating whether input arrays have row-major layout. + * @param vec_len Integer value 1, 2, or 4 specifying the Veclen template parameter of + * the KernelPolicy. + * @param f Lambda that takes two std::integral_constant parameters representing + * row_major and vec_len. + */ +template +auto dispatch_layout(bool row_major, int vec_len, F&& f) +{ + if (row_major) { + switch (vec_len) { + case 4: return f(std::true_type(), vec_len_constant<4>()); + case 2: return f(std::true_type(), vec_len_constant<2>()); + default: return f(std::true_type(), vec_len_constant<1>()); + } + } else { + switch (vec_len) { + case 4: return f(std::false_type(), vec_len_constant<4>()); + case 2: return f(std::false_type(), vec_len_constant<2>()); + default: return f(std::false_type(), vec_len_constant<1>()); + } + } +} + +}; // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu new file mode 100644 index 000000000..d13532ac6 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::lp_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu new file mode 100644 index 000000000..65e0163d7 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu new file mode 100644 index 000000000..23f2b34e8 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "../kernels/rbf_fin_op.cuh" // rbf_fin_op +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_unexp_distance_op, + float, + float, + float, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu new file mode 100644 index 000000000..1a5e5cf98 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::russel_rao_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu new file mode 100644 index 000000000..a9b1f6bb4 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include "../distance_ops/all_ops.cuh" // ops::* +#include "dispatch-inl.cuh" // dispatch +#include "dispatch_sm60.cuh" +#include // raft::identity_op +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void cuvs::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + cuvs::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch_sm60.cuh new file mode 100644 index 000000000..2b0ed01ef --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_sm60.cuh @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "dispatch_layout.cuh" // dispatch_layout +#include "kernel_sm60.cuh" // pairwise_matrix_sm60_wrapper +#include // raft::linalg::Policy4x4 + +#include // std::min + +namespace cuvs::distance::detail { + +template +pairwise_matrix_sm60_wrapper pairwise_matrix_sm60_get_wrapper( + OpT distance_op, + pairwise_matrix_params params, + SM_compat_t sm_compat_range) +{ + int vec_len = determine_vec_len(params); + + // f takes compile-time constants row_major and vec_len aligned and returns + // the corresponding kernel wrapper. The wrapper contains the launch + // parameters of the kernel: a pointer to the kernel function, grid size, + // block size, and shared memory size. + auto f = [&](auto row_major, auto vec_len_aligned) { + // row_major and vec_len are std::integral_constants of type bool and int + // respectively. + + // To keep compile times in check, we only specialize on veclen > 1 when + // the inner loop is relatively cheap (< 5 flops). + constexpr int vec_len_op = distance_op.expensive_inner_loop ? 1 : vec_len_aligned(); + + // Prevent double, vec_len=4 combination (this is not supported) + constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); + + using RowPolicy = typename raft::linalg::Policy4x4::Policy; + using ColPolicy = typename raft::linalg::Policy4x4::ColPolicy; + using Policy = typename std::conditional::type; + + auto wrapper = + make_pairwise_matrix_sm60_wrapper(distance_op, params, sm_compat_range); + + return wrapper; + }; + + // Dispatch_layout calls f with appropriate compile time constants based on + // the runtime values of params.is_row_major and vec_len. + return dispatch_layout(params.is_row_major, vec_len, f); +} + +template +void pairwise_matrix_sm60_dispatch(OpT distance_op, + pairwise_matrix_params params, + SM_compat_t sm_compat_range, + cudaStream_t stream) +{ + auto wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, sm_compat_range); + + wrapper.launch(distance_op, params, stream); +} + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_sm80.cuh b/cpp/src/distance/detail/pairwise_matrix/dispatch_sm80.cuh new file mode 100644 index 000000000..d9761545e --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_sm80.cuh @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "../pairwise_distance_cutlass_base.cuh" // cutlassDistanceKernel +#include "dispatch_layout.cuh" // dispatch_layout + +#include // std::min + +namespace cuvs::distance::detail { + +template +void pairwise_matrix_sm80_dispatch(OpT distance_op, + pairwise_matrix_params params, + SM_compat_t sm_compat_range, + cudaStream_t stream) +{ + int vec_len = determine_vec_len(params); + + // f takes compile-time constants row_major and vec_len aligned and runs the + // corresponding cutlass launch code. + auto f = [&](auto row_major, auto vec_len_aligned) { + // row_major and vec_len are std::integral_constants of type bool and int + // respectively. + + // Prevent double, vec_len=4 combination (this is not supported) + constexpr int vec_len = std::min(vec_len_aligned(), static_cast(16 / sizeof(DataT))); + + using AccT = typename OpT::AccT; + cutlassDistanceKernel(params.x, + params.y, + params.x_norm, + params.y_norm, + params.m, + params.n, + params.k, + params.ldx, + params.ldy, + params.ld_out, + params.out, + params.fin_op, + distance_op, + stream); + }; + + // Dispatch_layout calls f with appropriate compile time constants based on + // the runtime values of params.is_row_major and vec_len. + dispatch_layout(params.is_row_major, vec_len, f); +} + +}; // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/src/distance/detail/pairwise_matrix/kernel_sm60.cuh new file mode 100644 index 000000000..b63955422 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "../pairwise_distance_base.cuh" // PairwiseDistances +#include "params.cuh" // pairwise_matrix_params +#include // raft::void_op +#include // raft::util::arch::SM_compute_arch + +#include // assert + +namespace cuvs::distance::detail { + +template +__launch_bounds__(Policy::Nthreads, 2) RAFT_KERNEL + pairwise_matrix_kernel(OpT distance_op, pairwise_matrix_params params) +{ + // Early exit to minimize the size of the kernel when it is not supposed to be compiled. + constexpr SM_compat_t sm_compat_range{}; + if constexpr (!sm_compat_range.contains(raft::util::arch::SM_compute_arch())) { + assert(false); + return; + } + + extern __shared__ char smem[]; + + // The epilog is already provided by distance_op. Do not provide additional + // epilogs. + auto epilog_op = raft::void_op(); + // No support for row_epilog_op. + auto row_epilog_op = raft::void_op(); + + // Always write output + constexpr bool write_out = true; + constexpr bool use_norms = distance_op.use_norms; + PairwiseDistances + obj(params.x, + params.y, + params.m, + params.n, + params.k, + params.ldx, + params.ldy, + params.ld_out, + params.x_norm, + params.y_norm, + params.out, + smem, + distance_op, + epilog_op, + params.fin_op, + row_epilog_op); + obj.run(); +} + +// The type of a pointer to the pairwise matrix kernel. 
The following template +// arguments are type-erased: +// +// - The kernel policy +// - row_major +// - SM_compat_t +template +using pairwise_matrix_kernel_t = void (*)(OpT, pairwise_matrix_params); + +// A wrapper for the pairwise matrix kernel launch. Includes kernel launch +// parameters. +template +struct pairwise_matrix_sm60_wrapper { + dim3 grid; + dim3 block; + int smem_size; + pairwise_matrix_kernel_t kernel_ptr; + + void launch(OpT distance_op, + pairwise_matrix_params params, + cudaStream_t stream) + { + kernel_ptr<<>>(distance_op, params); + RAFT_CUDA_TRY(cudaGetLastError()); + } +}; + +/** @brief: Create kernel launch wrapper for pairwise matrix kernel + * + * This can be used to type-erase the kernel execution policy, row_major, and SM + * compatibility range. + * + * @tparam Policy: Kernel execution policy + * @tparam row_major: Indicates whether input matrices are row major + * @tparam OpT: Type of distance operation + * @tparam IdxT: Index type + * @tparam DataT: Data type + * @tparam OutT: Output data type + * @tparam FinOpT: Final operation type + * @tparam SM_compat_t: Type of the SM architecture compatibility + * + * @param distance_op: Distance operation + * @param params: Parameters + * @param sm_compat_range: Which SM architectures to compile for. + */ +template +pairwise_matrix_sm60_wrapper make_pairwise_matrix_sm60_wrapper( + OpT distance_op, + pairwise_matrix_params params, + SM_compat_t sm_compat_range) +{ + dim3 block(Policy::Nthreads); + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_size = OpT::template shared_mem_size(); + // Obtain function pointer to kernel + auto kernel = + pairwise_matrix_kernel; + dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); + + return pairwise_matrix_sm60_wrapper{ + grid, block, smem_size, kernel}; +} + +}; // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/pairwise_matrix/params.cuh b/cpp/src/distance/detail/pairwise_matrix/params.cuh new file mode 100644 index 000000000..aa419aca0 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/params.cuh @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace cuvs::distance::detail { + +template +struct pairwise_matrix_params { + IdxT m; + IdxT n; + IdxT k; + IdxT ldx; + IdxT ldy; + IdxT ld_out; + const DataT* x; + const DataT* y; + const DataT* x_norm; + const DataT* y_norm; + OutT* out; + FinOpT fin_op; + bool is_row_major; + + /// @brief: Flips the x and y input and corresponding sizes + void flip_x_and_y() + { + // Flip m, n; ldx, ldy; x, y; x_norm, y_norm. 
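+    // (Editorial note: swapping the operands amounts to computing the
+    //  transposed distance matrix. Because a column-major m x n matrix has
+    //  the same memory layout as a row-major n x m matrix, callers can use
+    //  this flip to service column-major requests with the row-major kernels.)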
+ std::swap(m, n); + std::swap(ldx, ldy); + std::swap(x, y); + std::swap(x_norm, y_norm); + } +}; + +} // namespace cuvs::distance::detail diff --git a/cpp/src/distance/detail/predicated_tile_iterator_normvec.h b/cpp/src/distance/detail/predicated_tile_iterator_normvec.h new file mode 100644 index 000000000..951f8a013 --- /dev/null +++ b/cpp/src/distance/detail/predicated_tile_iterator_normvec.h @@ -0,0 +1,585 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + +This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) + +Changes: +- added `Layout_` template param +- Only the row index is used to load the data in load_with_byte_offset(). + This way the same normalization data is used across all columns in a row. + +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template +class PredicatedTileIteratorNormVec { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = Layout_; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
+ PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + + private: + // + // Methods + // + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorNormVec(PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) + { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { mask_.clear(); } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride); + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int 
column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + if (column == 0) { + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[0], + guard); + } else { + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn]; + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + 
CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) row_add_P = 0; + if (output_Q > convolution_Q - 2) row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + 
(void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorNormVec& operator++() + { + ++state_[0]; + + if (!ScatterD) { byte_pointer_ += params_.advance_row; } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/src/distance/distance-ext.cuh b/cpp/src/distance/distance-ext.cuh new file mode 100644 index 000000000..ad45e8405 --- /dev/null +++ b/cpp/src/distance/distance-ext.cuh @@ -0,0 +1,1066 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "detail/kernels/rbf_fin_op.cuh" // rbf_fin_op +#include // cuvs::distance::DistanceType +#include // raft::device_matrix_view +#include // raft::identity_op +#include // raft::resources +#include // RAFT_EXPLICIT + +#include // rmm::device_uvector + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace cuvs { +namespace distance { + +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + FinalLambda fin_op, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) RAFT_EXPLICIT; + +template +size_t getWorkspaceSize(raft::device_matrix_view const& x, + raft::device_matrix_view const& y) RAFT_EXPLICIT; + +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + rmm::device_uvector& workspace, + cuvs::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + cuvs::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void distance(raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void pairwise_distance(raft::resources const& handle, + device_matrix_view const x, + device_matrix_view const y, + device_matrix_view dist, + cuvs::distance::DistanceType metric, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +}; // namespace distance +}; // namespace cuvs + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +/* + * Hierarchy of instantiations: + * + * This file defines the extern template instantiations for the public API of + * cuvs::distance. To improve compile times, the extern template instantiation + * of the distance kernels is handled in + * distance/detail/pairwise_matrix/dispatch-ext.cuh. + * + * After adding an instance here, make sure to also add the instance to + * dispatch-ext.cuh and the corresponding .cu files. + */ + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + extern template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +// The following two instances are used in test/distance/gram.cu. Note the use +// of int64_t for the index type. 
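+// (Editorial note: each invocation of the macro above expands to an
+//  `extern template` declaration, so translation units that include this
+//  header do not implicitly instantiate cuvs::distance::distance for that
+//  parameter combination; the single definition is compiled once in the
+//  corresponding .cu file.)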
+instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, + float, + float, + float, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, + double, + double, + double, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L1, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L1, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + 
cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Linf, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_distance + +// Same, but without raft::identity_op +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + extern template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, float, float, float, int); 
+instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +// Same, but without workspace +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + extern template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, float, float, float, int); 
+instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + extern template size_t cuvs::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::InnerProduct, 
double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ + extern template size_t cuvs::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y) + +// We could consider not taking template parameters for this function. The +// number of instantiations seems a bit excessive.. 
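The remark above about the instantiation count is easiest to see from the caller's side. The following hypothetical sketch assumes the metric and the three value types are the leading explicit template arguments, with the index and layout types deduced from the mdspan views (mirroring the pointer-based overload above); the cuVS header that declares getWorkspaceSize is omitted. Every distinct layout a caller uses requires its own pre-built instantiation, which is why each metric below is listed for both layout_c_contiguous and layout_f_contiguous.

// Hypothetical caller-side sketch (the cuVS header declaring getWorkspaceSize is omitted).
#include <raft/core/device_mdarray.hpp>
#include <raft/core/resources.hpp>

size_t workspace_bytes_example(raft::resources const& handle)
{
  // make_device_matrix defaults to a row-major (layout_c_contiguous) view.
  auto x = raft::make_device_matrix<float, int>(handle, 1024, 64);
  auto y = raft::make_device_matrix<float, int>(handle, 512, 64);

  // Metric and value types are explicit; the index and layout types are taken
  // from the views, so a column-major caller would pull in a second instantiation.
  return cuvs::distance::getWorkspaceSize<cuvs::distance::DistanceType::L2Expanded,
                                          float,
                                          float,
                                          float>(x.view(), y.view());
}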
+instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); 
+instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + 
raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2Unexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + extern template void cuvs::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + rmm::device_uvector& workspace, \ + cuvs::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Same, but without workspace +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + extern template void cuvs::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + cuvs::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Version with mdspan +#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ + extern template void cuvs::distance::distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + DataT metric_arg) + +// Again, we might want to consider reigning in the number of instantiations... 
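The same concern shows up on the runtime API: because the metric is a compile-time template parameter, the runtime pairwise_distance wrapper (defined later in this diff) has to translate each DistanceType value into a separate instantiation through a switch over std::integral_constant. A condensed, stand-alone sketch of that dispatch pattern, with stand-in names:

#include <type_traits>

namespace example {

enum class DistanceType { L1, L2Expanded /* ... */ };

// Stand-in for the templated, compile-time-metric kernel entry point.
template <DistanceType Metric>
void run_distance_kernel()
{
  // Metric is a compile-time constant here, so the right kernel is selected at build time.
}

// Runtime metric -> compile-time template argument, one switch case per metric.
inline void pairwise_dispatch(DistanceType metric)
{
  auto dispatch = [&](auto metric_constant) {
    constexpr auto m = decltype(metric_constant)::value;
    run_distance_kernel<m>();
  };

  switch (metric) {
    case DistanceType::L1:
      dispatch(std::integral_constant<DistanceType, DistanceType::L1>{});
      break;
    case DistanceType::L2Expanded:
      dispatch(std::integral_constant<DistanceType, DistanceType::L2Expanded>{});
      break;
    default: break;  // unknown or unsupported metric
  }
}

}  // namespace example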
+instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + 
cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); 
+instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ + extern template void cuvs::distance::pairwise_distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + cuvs::distance::DistanceType metric, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); + +#undef instantiate_raft_distance_pairwise_distance diff --git a/cpp/src/distance/distance-inl.cuh b/cpp/src/distance/distance-inl.cuh new file mode 100644 index 000000000..5b82f5438 --- /dev/null +++ b/cpp/src/distance/distance-inl.cuh @@ -0,0 +1,478 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "detail/distance.cuh" +#include +#include +#include +#include + +#include + +#include + +namespace cuvs { +namespace distance { + +/** + * @defgroup pairwise_distance pointer-based pairwise distance prims + * @{ + */ + +/** + * @brief Evaluate pairwise distances with the user epilogue lamba allowed + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam FinalLambda user-defined epilogue lamba + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param fin_op the final gemm epilogue lambda + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + * + * @note fin_op: This is a device lambda which is supposed to operate upon the + * input which is AccT and returns the output in OutT. It's signature is + * as follows:

OutT fin_op(AccT in, int g_idx)
. If one needs + * any other parameters, feel free to pass them via closure. + */ +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + FinalLambda fin_op, + bool isRowMajor = true, + DataT metric_arg = 2.0f) +{ + detail::distance( + handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); +} + +/** + * @brief Evaluate pairwise distances for the simple use case + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + bool isRowMajor = true, + DataT metric_arg = 2.0f) +{ + detail::distance( + handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); +} + +/** + * @brief Return the exact workspace size to compute the distance + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param x first set of points + * @param y second set of points + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * + * @note If the specified DistT doesn't need the workspace at all, it + * returns 0. + */ +template +size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) +{ + return detail::getWorkspaceSize(x, y, m, n, k); +} + +/** + * @brief Return the exact workspace size to compute the distance + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param x first set of points (size m*k) + * @param y second set of points (size n*k) + * @return number of bytes needed in workspace + * + * @note If the specified DistT doesn't need the workspace at all, it + * returns 0. 
+ */ +template +size_t getWorkspaceSize(raft::device_matrix_view const& x, + raft::device_matrix_view const& y) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + + return getWorkspaceSize( + x.data_handle(), y.data_handle(), x.extent(0), y.extent(0), x.extent(1)); +} + +/** + * @brief Evaluate pairwise distances for the simple use case + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + bool isRowMajor = true, + DataT metric_arg = 2.0f) +{ + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); + auto worksize = getWorkspaceSize(x, y, m, n, k); + workspace.resize(worksize, stream); + detail::distance( + handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); +} + +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam IdxT indexing type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace buffer which can get resized as per the + * needed workspace size + * @param metric distance metric + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + rmm::device_uvector& workspace, + cuvs::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + auto dispatch = [&](auto distance_type) { + auto worksize = getWorkspaceSize(x, y, m, n, k); + workspace.resize(worksize, stream); + detail::distance( + handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); + }; + + switch (metric) { + case DistanceType::Canberra: + dispatch(std::integral_constant{}); + break; + case DistanceType::CorrelationExpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::CosineExpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::HammingUnexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::HellingerExpanded: + dispatch(std::integral_constant{}); + break; + case cuvs::distance::DistanceType::InnerProduct: + dispatch(std::integral_constant{}); + break; + case DistanceType::JensenShannon: + dispatch(std::integral_constant{}); + break; + case DistanceType::KLDivergence: + dispatch(std::integral_constant{}); + break; + case DistanceType::L1: + 
dispatch(std::integral_constant{}); + break; + case DistanceType::L2Expanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2SqrtExpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2SqrtUnexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2Unexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::Linf: + dispatch(std::integral_constant{}); + break; + case DistanceType::LpUnexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::RusselRaoExpanded: + dispatch(std::integral_constant{}); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + }; +} + +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam IdxT indexing type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param metric distance metric + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + cuvs::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) +{ + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); + pairwise_distance( + handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); +} + +/** @} */ + +/** + * \defgroup distance_mdspan Pairwise distance functions + * @{ + */ + +/** + * @brief Evaluate pairwise distances for the simple use case. + * + * Note: Only contiguous row- or column-major layouts supported currently. 
+ * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * #include + * + * raft::raft::resources handle; + * int n_samples = 5000; + * int n_features = 50; + * + * auto input = raft::make_device_matrix(handle, n_samples, n_features); + * auto labels = raft::make_device_vector(handle, n_samples); + * auto output = raft::make_device_matrix(handle, n_samples, n_samples); + * + * raft::random::make_blobs(handle, input.view(), labels.view()); + * auto metric = cuvs::distance::DistanceType::L2SqrtExpanded; + * cuvs::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); + * @endcode + * + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points (size n*k) + * @param y second set of points (size m*k) + * @param dist output distance matrix (size n*m) + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + DataT metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); + + constexpr auto is_rowmajor = std::is_same_v; + + distance(handle, + x.data_handle(), + y.data_handle(), + dist.data_handle(), + x.extent(0), + y.extent(0), + x.extent(1), + is_rowmajor, + metric_arg); +} + +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam IdxT indexing type + * @param handle raft handle for managing expensive resources + * @param x first matrix of points (size mxk) + * @param y second matrix of points (size nxk) + * @param dist output distance matrix (size mxn) + * @param metric distance metric + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwise_distance(raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + Type metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); + RAFT_EXPECTS(dist.is_exhaustive(), "Output must be contiguous."); + + constexpr auto rowmajor = std::is_same_v; + + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); + + pairwise_distance(handle, + x.data_handle(), + y.data_handle(), + dist.data_handle(), + x.extent(0), + y.extent(0), + x.extent(1), + metric, + 
rowmajor, + metric_arg); +} + +/** @} */ + +}; // namespace distance +}; // namespace cuvs diff --git a/cpp/src/distance/distance.cu b/cpp/src/distance/distance.cu new file mode 100644 index 000000000..02c071d13 --- /dev/null +++ b/cpp/src/distance/distance.cu @@ -0,0 +1,934 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "detail/kernels/rbf_fin_op.cuh" // rbf_fin_op +#include "distance-inl.cuh" + +/* + * Hierarchy of instantiations: + * + * This file defines the template instantiations for the public API of + * cuvs::distance. To improve compile times, the compilation of the distance + * kernels is handled in distance/detail/pairwise_matrix/dispatch_*.cu. + * + */ + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +// The following two instances are used in test/distance/gram.cu. Note the use +// of int64_t for the index type. +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, + float, + float, + float, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, + double, + double, + double, + cuvs::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, 
int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L1, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L1, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Linf, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_distance + +// Same, but without raft::identity_op +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + 
cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +// Same, but without workspace +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + template void cuvs::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + bool isRowMajor, \ + DataT 
metric_arg) + +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::RusselRaoExpanded, 
double, double, double, int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + template size_t cuvs::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Linf, float, float, float, int); 
+instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ + template size_t cuvs::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y) + +// We could consider not taking template parameters for this function. The +// number of instantiations seems a bit excessive.. +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); 
+instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Expanded, double, double, double, int, 
raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(cuvs::distance::DistanceType::L2Unexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + template void cuvs::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + rmm::device_uvector& workspace, \ + cuvs::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Same, but without workspace +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + template void cuvs::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + cuvs::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Version with mdspan +#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ + template void cuvs::distance::distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + DataT metric_arg) + +// Again, we might want to consider reigning in the number of instantiations... 
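(Aside: the explicit-instantiation pattern used throughout this file is easier to see in isolation. The sketch below is not cuVS code; every name in it, e.g. mylib::sum, is hypothetical. It illustrates the compile-time trick described by the "Hierarchy of instantiations" comment above: the header keeps the template definition but declares the supported type combinations as extern templates, and a single .cu translation unit provides the explicit instantiation definitions, so files that include the header never re-compile the heavy kernels themselves.)

// mylib_sum.cuh -- header included by consumers (hypothetical example)
#pragma once
#include <cuda_runtime.h>

namespace mylib {

template <typename T>
__global__ void sum_kernel(const T* in, T* out, int n)
{
  // Single-thread serial sum; purely illustrative, not performant.
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    T acc = T(0);
    for (int i = 0; i < n; ++i) {
      acc += in[i];
    }
    *out = acc;
  }
}

template <typename T>
void sum(const T* in, T* out, int n, cudaStream_t stream)
{
  sum_kernel<T><<<1, 1, 0, stream>>>(in, out, n);
}

// Analogue of RAFT_EXPLICIT_INSTANTIATE_ONLY: consumers link against the
// instantiations provided in mylib_sum.cu instead of instantiating sum<T>
// (and its kernel) in their own translation units.
extern template void sum<float>(const float*, float*, int, cudaStream_t);
extern template void sum<double>(const double*, double*, int, cudaStream_t);

}  // namespace mylib

// mylib_sum.cu -- the only translation unit that actually compiles the kernels
#include "mylib_sum.cuh"

namespace mylib {
template void sum<float>(const float*, float*, int, cudaStream_t);
template void sum<double>(const double*, double*, int, cudaStream_t);
}  // namespace mylib

Each instantiate_raft_distance_* macro call in this file plays the role of those last two lines, only repeated across many (metric, data type, index type, layout) combinations, with the kernel bodies compiled once in the distance/detail/pairwise_matrix/dispatch_*.cu files mentioned above.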
+instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + 
cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); 
+instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + cuvs::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(cuvs::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ + template void cuvs::distance::pairwise_distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + cuvs::distance::DistanceType metric, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); + +#undef instantiate_raft_distance_pairwise_distance diff --git a/cpp/src/distance/distance.cuh b/cpp/src/distance/distance.cuh new file mode 100644 index 000000000..de70cd469 --- /dev/null +++ b/cpp/src/distance/distance.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "distance-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "distance-ext.cuh" +#endif diff --git a/cpp/src/distance/pairwise_distance.cu b/cpp/src/distance/pairwise_distance.cu new file mode 100644 index 000000000..bf4b21669 --- /dev/null +++ b/cpp/src/distance/pairwise_distance.cu @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "distance.cuh" +#include +#include +#include +#include + +namespace cuvs::distance { + +/** + * @defgroup pairwise_distance_runtime Pairwise Distances Runtime API + * @{ + */ +void pairwise_distance( + raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg) +{ + auto x_v = raft::make_device_matrix_view( + x.data_handle(), x.extent(0), x.extent(1)); + auto y_v = raft::make_device_matrix_view( + y.data_handle(), y.extent(0), y.extent(1)); + auto d_v = raft::make_device_matrix_view( + dist.data_handle(), dist.extent(0), dist.extent(1)); + pairwise_distance( + handle, x_v, y_v, d_v, metric, metric_arg); +} + +void pairwise_distance( + raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + double metric_arg) +{ + auto x_v = raft::make_device_matrix_view( + x.data_handle(), x.extent(0), x.extent(1)); + auto y_v = raft::make_device_matrix_view( + y.data_handle(), y.extent(0), y.extent(1)); + auto d_v = raft::make_device_matrix_view( + dist.data_handle(), dist.extent(0), dist.extent(1)); + pairwise_distance( + handle, x_v, y_v, d_v, metric, metric_arg); +} + +void pairwise_distance( + raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + float metric_arg) +{ + auto x_v = raft::make_device_matrix_view( + x.data_handle(), x.extent(0), x.extent(1)); + auto y_v = raft::make_device_matrix_view( + y.data_handle(), y.extent(0), y.extent(1)); + auto d_v = raft::make_device_matrix_view( + dist.data_handle(), dist.extent(0), dist.extent(1)); + pairwise_distance( + handle, x_v, y_v, d_v, metric, metric_arg); +} + +void pairwise_distance( + raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + cuvs::distance::DistanceType metric, + double 
metric_arg) +{ + auto x_v = raft::make_device_matrix_view( + x.data_handle(), x.extent(0), x.extent(1)); + auto y_v = raft::make_device_matrix_view( + y.data_handle(), y.extent(0), y.extent(1)); + auto d_v = raft::make_device_matrix_view( + dist.data_handle(), dist.extent(0), dist.extent(1)); + pairwise_distance( + handle, x_v, y_v, d_v, metric, metric_arg); +} + +/** @} */ // end group pairwise_distance_runtime + +} // namespace cuvs::distance diff --git a/cpp/src/neighbors/brute_force_c.cpp b/cpp/src/neighbors/brute_force_c.cpp index 88349e089..e988ac2f0 100644 --- a/cpp/src/neighbors/brute_force_c.cpp +++ b/cpp/src/neighbors/brute_force_c.cpp @@ -33,7 +33,7 @@ namespace { template void* _build(cuvsResources_t res, DLManagedTensor* dataset_tensor, - enum DistanceType metric, + cuvsDistanceType metric, T metric_arg) { auto res_ptr = reinterpret_cast(res); @@ -97,7 +97,7 @@ extern "C" cuvsError_t cuvsBruteForceIndexDestroy(cuvsBruteForceIndex_t index_c_ extern "C" cuvsError_t cuvsBruteForceBuild(cuvsResources_t res, DLManagedTensor* dataset_tensor, - enum DistanceType metric, + cuvsDistanceType metric, float metric_arg, cuvsBruteForceIndex_t index) { diff --git a/cpp/src/neighbors/cagra_build_float.cpp b/cpp/src/neighbors/cagra_build_float.cpp index f66ae5cfb..5ff1da7f2 100644 --- a/cpp/src/neighbors/cagra_build_float.cpp +++ b/cpp/src/neighbors/cagra_build_float.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 841058f60..131737039 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -128,35 +128,47 @@ if(BUILD_TESTS) PERCENT 100 ) + + ConfigureTest( + NAME + DISTANCE_TEST + PATH + test/distance/dist_canberra.cu + test/distance/dist_correlation.cu + test/distance/dist_cos.cu + test/distance/dist_hamming.cu + test/distance/dist_hellinger.cu + test/distance/dist_inner_product.cu + test/distance/dist_jensen_shannon.cu + test/distance/dist_kl_divergence.cu + test/distance/dist_l1.cu + test/distance/dist_l2_exp.cu + test/distance/dist_l2_sqrt_exp.cu + test/distance/dist_l_inf.cu + test/distance/dist_lp_unexp.cu + test/distance/dist_russell_rao.cu + GPUS + 1 + PERCENT + 100 + ) endif() if(BUILD_C_TESTS) ConfigureTest(NAME INTEROP_TEST PATH test/core/interop.cu C_LIB) ConfigureTest( - NAME - BRUTEFORCE_C_TEST - PATH - test/neighbors/run_brute_force_c.c - test/neighbors/brute_force_c.cu + NAME BRUTEFORCE_C_TEST PATH test/neighbors/run_brute_force_c.c test/neighbors/brute_force_c.cu C_LIB ) ConfigureTest( - NAME - IVF_FLAT_C_TEST - PATH - test/neighbors/run_ivf_flat_c.c - test/neighbors/ann_ivf_flat_c.cu + NAME IVF_FLAT_C_TEST PATH test/neighbors/run_ivf_flat_c.c test/neighbors/ann_ivf_flat_c.cu C_LIB ) ConfigureTest( - NAME - IVF_PQ_C_TEST - PATH - test/neighbors/run_ivf_pq_c.c - test/neighbors/ann_ivf_pq_c.cu C_LIB + NAME IVF_PQ_C_TEST PATH test/neighbors/run_ivf_pq_c.c test/neighbors/ann_ivf_pq_c.cu C_LIB ) ConfigureTest(NAME CAGRA_C_TEST PATH test/neighbors/ann_cagra_c.cu C_LIB) diff --git a/cpp/test/distance/dist_canberra.cu b/cpp/test/distance/dist_canberra.cu new file mode 100644 index 000000000..2bf590601 --- /dev/null +++ b/cpp/test/distance/dist_canberra.cu @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceCanberra : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceCanberra DistanceCanberraF; +TEST_P(DistanceCanberraF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceCanberra DistanceCanberraD; +TEST_P(DistanceCanberraD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraD, ::testing::ValuesIn(inputsd)); + +class BigMatrixCanberra : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixCanberra, Result) {} + +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_correlation.cu b/cpp/test/distance/dist_correlation.cu new file mode 100644 index 000000000..9e061bebc --- /dev/null +++ b/cpp/test/distance/dist_correlation.cu @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceCorrelation + : public DistanceTest {}; + +template +class DistanceCorrelationXequalY + : public DistanceTestSameBuffer {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceCorrelation DistanceCorrelationF; +TEST_P(DistanceCorrelationF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationF, ::testing::ValuesIn(inputsf)); + +typedef DistanceCorrelationXequalY DistanceCorrelationXequalYF; +TEST_P(DistanceCorrelationXequalYF, Result) +{ + int m = params.m; + ASSERT_TRUE(cuvs::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + m, + cuvs::CompareApprox(params.tolerance), + stream)); + ASSERT_TRUE(cuvs::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m / 2, + m, + cuvs::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationXequalYF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceCorrelation DistanceCorrelationD; +TEST_P(DistanceCorrelationD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationD, ::testing::ValuesIn(inputsd)); + +class BigMatrixCorrelation + : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixCorrelation, Result) {} +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_cos.cu b/cpp/test/distance/dist_cos.cu new file mode 100644 index 000000000..e134f045f --- /dev/null +++ b/cpp/test/distance/dist_cos.cu @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceExpCos : public DistanceTest { +}; + +template +class DistanceExpCosXequalY + : public DistanceTestSameBuffer {}; + +const std::vector> inputsf = { + {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; + +const std::vector> inputsXeqYf = { + {0.01f, 1024, 1024, 32, true, 1234ULL}, + {0.01f, 1024, 32, 1024, true, 1234ULL}, + {0.01f, 32, 1024, 1024, true, 1234ULL}, + {0.03f, 1024, 1024, 1024, true, 1234ULL}, + {0.01f, 1024, 1024, 32, false, 1234ULL}, + {0.01f, 1024, 32, 1024, false, 1234ULL}, + {0.01f, 32, 1024, 1024, false, 1234ULL}, + {0.03f, 1024, 1024, 1024, false, 1234ULL}, +}; + +typedef DistanceExpCos DistanceExpCosF; +TEST_P(DistanceExpCosF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosF, ::testing::ValuesIn(inputsf)); + +typedef DistanceExpCosXequalY DistanceExpCosXequalYF; +TEST_P(DistanceExpCosXequalYF, Result) +{ + int m = params.m; + int n = params.m; + ASSERT_TRUE(cuvs::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + n, + cuvs::CompareApprox(params.tolerance), + stream)); + n = params.isRowMajor ? m : m / 2; + m = params.isRowMajor ? m / 2 : m; + + ASSERT_TRUE(cuvs::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m, + n, + cuvs::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosXequalYF, ::testing::ValuesIn(inputsXeqYf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceExpCos DistanceExpCosD; +TEST_P(DistanceExpCosD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosD, ::testing::ValuesIn(inputsd)); + +class BigMatrixCos : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixCos, Result) {} + +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_hamming.cu b/cpp/test/distance/dist_hamming.cu new file mode 100644 index 000000000..0cf753eca --- /dev/null +++ b/cpp/test/distance/dist_hamming.cu @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceHamming + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceHamming DistanceHammingF; +TEST_P(DistanceHammingF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceHamming DistanceHammingD; +TEST_P(DistanceHammingD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingD, ::testing::ValuesIn(inputsd)); + +class BigMatrixHamming + : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixHamming, Result) {} +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_hellinger.cu b/cpp/test/distance/dist_hellinger.cu new file mode 100644 index 000000000..3998a60ab --- /dev/null +++ b/cpp/test/distance/dist_hellinger.cu @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceHellingerExp + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceHellingerExp DistanceHellingerExpF; +TEST_P(DistanceHellingerExpF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceHellingerExp DistanceHellingerExpD; +TEST_P(DistanceHellingerExpD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpD, ::testing::ValuesIn(inputsd)); + +class BigMatrixHellingerExp + : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixHellingerExp, Result) {} +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_inner_product.cu b/cpp/test/distance/dist_inner_product.cu new file mode 100644 index 000000000..1d6709d52 --- /dev/null +++ b/cpp/test/distance/dist_inner_product.cu @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceInnerProduct + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 10, 5, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceInnerProduct DistanceInnerProductF; +TEST_P(DistanceInnerProductF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? 
params.n : params.m; + + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceInnerProductF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceInnerProduct DistanceInnerProductD; +TEST_P(DistanceInnerProductD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceInnerProductD, ::testing::ValuesIn(inputsd)); + +class BigMatrixInnerProduct + : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixInnerProduct, Result) {} + +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_jensen_shannon.cu b/cpp/test/distance/dist_jensen_shannon.cu new file mode 100644 index 000000000..43b7b361d --- /dev/null +++ b/cpp/test/distance/dist_jensen_shannon.cu @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceJensenShannon + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceJensenShannon DistanceJensenShannonF; +TEST_P(DistanceJensenShannonF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceJensenShannon DistanceJensenShannonD; +TEST_P(DistanceJensenShannonD, Result) +{ + int m = params.isRowMajor ? 
params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonD, ::testing::ValuesIn(inputsd)); + +class BigMatrixJensenShannon + : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixJensenShannon, Result) {} +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_kl_divergence.cu b/cpp/test/distance/dist_kl_divergence.cu new file mode 100644 index 000000000..5e5692841 --- /dev/null +++ b/cpp/test/distance/dist_kl_divergence.cu @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceKLDivergence + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceKLDivergence DistanceKLDivergenceF; +TEST_P(DistanceKLDivergenceF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceKLDivergence DistanceKLDivergenceD; +TEST_P(DistanceKLDivergenceD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceD, ::testing::ValuesIn(inputsd)); + +class BigMatrixKLDivergence + : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixKLDivergence, Result) {} +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_l1.cu b/cpp/test/distance/dist_l1.cu new file mode 100644 index 000000000..a3ecd21fe --- /dev/null +++ b/cpp/test/distance/dist_l1.cu @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceUnexpL1 : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceUnexpL1 DistanceUnexpL1F; +TEST_P(DistanceUnexpL1F, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceUnexpL1F, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceUnexpL1 DistanceUnexpL1D; +TEST_P(DistanceUnexpL1D, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceUnexpL1D, ::testing::ValuesIn(inputsd)); + +class BigMatrixUnexpL1 : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixUnexpL1, Result) {} + +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_l2_exp.cu b/cpp/test/distance/dist_l2_exp.cu new file mode 100644 index 000000000..f3d038cbc --- /dev/null +++ b/cpp/test/distance/dist_l2_exp.cu @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceEucExpTest : public DistanceTest { +}; + +template +class DistanceEucExpTestXequalY + : public DistanceTestSameBuffer {}; + +const std::vector> inputsf = { + {0.001f, 128, (65536 + 128) * 128, 8, true, 1234ULL}, + {0.001f, 2048, 4096, 128, true, 1234ULL}, + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.003f, 1021, 1021, 1021, true, 1234ULL}, + {0.001f, (65536 + 128) * 128, 128, 8, false, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, + {0.003f, 1021, 1021, 1021, false, 1234ULL}, +}; + +const std::vector> inputsXeqYf = { + {0.01f, 2048, 4096, 128, true, 1234ULL}, + {0.01f, 1024, 1024, 32, true, 1234ULL}, + {0.01f, 1024, 32, 1024, true, 1234ULL}, + {0.01f, 32, 1024, 1024, true, 1234ULL}, + {0.03f, 1024, 1024, 1024, true, 1234ULL}, + {0.03f, 1021, 1021, 1021, true, 1234ULL}, + {0.01f, 1024, 1024, 32, false, 1234ULL}, + {0.01f, 1024, 32, 1024, false, 1234ULL}, + {0.01f, 32, 1024, 1024, false, 1234ULL}, + {0.03f, 1024, 1024, 1024, false, 1234ULL}, + {0.03f, 1021, 1021, 1021, false, 1234ULL}, +}; + +typedef DistanceEucExpTest DistanceEucExpTestF; +TEST_P(DistanceEucExpTestF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestF, ::testing::ValuesIn(inputsf)); + +typedef DistanceEucExpTestXequalY DistanceEucExpTestXequalYF; +TEST_P(DistanceEucExpTestXequalYF, Result) +{ + int m = params.m; + ASSERT_TRUE(cuvs::devArrMatch(dist_ref[0].data(), + dist[0].data(), + m, + m, + cuvs::CompareApprox(params.tolerance), + stream)); + ASSERT_TRUE(cuvs::devArrMatch(dist_ref[1].data(), + dist[1].data(), + m / 2, + m, + cuvs::CompareApprox(params.tolerance), + stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, + DistanceEucExpTestXequalYF, + ::testing::ValuesIn(inputsXeqYf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceEucExpTest DistanceEucExpTestD; +TEST_P(DistanceEucExpTestD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestD, ::testing::ValuesIn(inputsd)); + +class BigMatrixEucExp : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixEucExp, Result) {} +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_l2_sqrt_exp.cu b/cpp/test/distance/dist_l2_sqrt_exp.cu new file mode 100644 index 000000000..b24384be8 --- /dev/null +++ b/cpp/test/distance/dist_l2_sqrt_exp.cu @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceEucSqrtExpTest + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 2048, 4096, 128, true, 1234ULL}, + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.003f, 1021, 1021, 1021, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, + {0.003f, 1021, 1021, 1021, false, 1234ULL}, +}; +typedef DistanceEucSqrtExpTest DistanceEucSqrtExpTestF; +TEST_P(DistanceEucSqrtExpTestF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceEucSqrtExpTest DistanceEucSqrtExpTestD; +TEST_P(DistanceEucSqrtExpTestD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucSqrtExpTestD, ::testing::ValuesIn(inputsd)); + +class BigMatrixEucSqrtExp + : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixEucSqrtExp, Result) {} +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_l2_unexp.cu b/cpp/test/distance/dist_l2_unexp.cu new file mode 100644 index 000000000..c057434fa --- /dev/null +++ b/cpp/test/distance/dist_l2_unexp.cu @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceEucUnexpTest + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceEucUnexpTest DistanceEucUnexpTestF; +TEST_P(DistanceEucUnexpTestF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucUnexpTestF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceEucUnexpTest DistanceEucUnexpTestD; +TEST_P(DistanceEucUnexpTestD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucUnexpTestD, ::testing::ValuesIn(inputsd)); + +class BigMatrixEucUnexp : public BigMatrixDistanceTest { +}; +TEST_F(BigMatrixEucUnexp, Result) {} +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_l_inf.cu b/cpp/test/distance/dist_l_inf.cu new file mode 100644 index 000000000..b9ced68f3 --- /dev/null +++ b/cpp/test/distance/dist_l_inf.cu @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceLinf : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceLinf DistanceLinfF; +TEST_P(DistanceLinfF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? 
params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceLinf DistanceLinfD; +TEST_P(DistanceLinfD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfD, ::testing::ValuesIn(inputsd)); + +class BigMatrixLinf : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixLinf, Result) {} + +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_lp_unexp.cu b/cpp/test/distance/dist_lp_unexp.cu new file mode 100644 index 000000000..26620b44b --- /dev/null +++ b/cpp/test/distance/dist_lp_unexp.cu @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceLpUnexp : public DistanceTest { +}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL, 4.0f}, + {0.001f, 1024, 32, 1024, true, 1234ULL, 3.0f}, + {0.001f, 32, 1024, 1024, true, 1234ULL, 4.0f}, + {0.003f, 1024, 1024, 1024, true, 1234ULL, 3.0f}, + {0.001f, 1024, 1024, 32, false, 1234ULL, 4.0f}, + {0.001f, 1024, 32, 1024, false, 1234ULL, 3.0f}, + {0.001f, 32, 1024, 1024, false, 1234ULL, 4.0f}, + {0.003f, 1024, 1024, 1024, false, 1234ULL, 3.0f}, +}; +typedef DistanceLpUnexp DistanceLpUnexpF; +TEST_P(DistanceLpUnexpF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL, 4.0}, + {0.001, 1024, 32, 1024, true, 1234ULL, 3.0}, + {0.001, 32, 1024, 1024, true, 1234ULL, 4.0}, + {0.003, 1024, 1024, 1024, true, 1234ULL, 3.0}, + {0.001, 1024, 1024, 32, false, 1234ULL, 4.0}, + {0.001, 1024, 32, 1024, false, 1234ULL, 3.0}, + {0.001, 32, 1024, 1024, false, 1234ULL, 4.0}, + {0.003, 1024, 1024, 1024, false, 1234ULL, 3.0}, +}; +typedef DistanceLpUnexp DistanceLpUnexpD; +TEST_P(DistanceLpUnexpD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? 
params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpD, ::testing::ValuesIn(inputsd)); + +class BigMatrixLpUnexp : public BigMatrixDistanceTest { +}; +TEST_F(BigMatrixLpUnexp, Result) {} +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/dist_russell_rao.cu b/cpp/test/distance/dist_russell_rao.cu new file mode 100644 index 000000000..46da7f9cd --- /dev/null +++ b/cpp/test/distance/dist_russell_rao.cu @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include "distance_base.cuh" + +namespace cuvs { +namespace distance { + +template +class DistanceRussellRao + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceRussellRao DistanceRussellRaoF; +TEST_P(DistanceRussellRaoF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceRussellRao DistanceRussellRaoD; +TEST_P(DistanceRussellRaoD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(cuvs::devArrMatch( + dist_ref.data(), dist.data(), m, n, cuvs::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoD, ::testing::ValuesIn(inputsd)); + +class BigMatrixRussellRao + : public BigMatrixDistanceTest {}; +TEST_F(BigMatrixRussellRao, Result) {} +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh new file mode 100644 index 000000000..4a8191ae6 --- /dev/null +++ b/cpp/test/distance/distance_base.cuh @@ -0,0 +1,710 @@ +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" + +#include // cuvs::distance::DistanceType +#include +#include // raft::common::nvtx::range +#include //raft::make_device_matrix_view +#include // raft::sqrt +#include +#include // raft::resources +#include + +#include // rmm::device_uvector + +#include + +namespace cuvs { +namespace distance { + +template +RAFT_KERNEL naiveDistanceKernel(DataType* dist, + const DataType* x, + const DataType* y, + std::int64_t m, + std::int64_t n, + std::int64_t k, + cuvs::distance::DistanceType type, + bool isRowMajor) +{ + std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; + std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + DataType acc = DataType(0); + for (std::int64_t i = 0; i < k; ++i) { + std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; + std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto diff = x[xidx] - y[yidx]; + acc += diff * diff; + } + if (type == cuvs::distance::DistanceType::L2SqrtExpanded || + type == cuvs::distance::DistanceType::L2SqrtUnexpanded) + acc = raft::sqrt(acc); + std::int64_t outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + +template +RAFT_KERNEL naiveL1_Linf_CanberraDistanceKernel(DataType* dist, + const DataType* x, + const DataType* y, + std::int64_t m, + std::int64_t n, + std::int64_t k, + cuvs::distance::DistanceType type, + bool isRowMajor) +{ + std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; + std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) { return; } + + DataType acc = DataType(0); + for (std::int64_t i = 0; i < k; ++i) { + std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; + std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + auto diff = (a > b) ? (a - b) : (b - a); + if (type == cuvs::distance::DistanceType::Linf) { + acc = raft::max(acc, diff); + } else if (type == cuvs::distance::DistanceType::Canberra) { + const auto add = raft::abs(a) + raft::abs(b); + // deal with potential for 0 in denominator by + // forcing 1/0 instead + acc += ((add != 0) * diff / (add + (add == 0))); + } else { + acc += diff; + } + } + + std::int64_t outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + +template +RAFT_KERNEL naiveCosineDistanceKernel(DataType* dist, + const DataType* x, + const DataType* y, + std::int64_t m, + std::int64_t n, + std::int64_t k, + bool isRowMajor) +{ + std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; + std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) { return; } + + DataType acc_a = DataType(0); + DataType acc_b = DataType(0); + DataType acc_ab = DataType(0); + + for (std::int64_t i = 0; i < k; ++i) { + std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; + std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + acc_a += a * a; + acc_b += b * b; + acc_ab += a * b; + } + + std::int64_t outidx = isRowMajor ? 
midx * n + nidx : midx + m * nidx; + + // Use 1.0 - (cosine similarity) to calc the distance + dist[outidx] = (DataType)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); +} + +template +RAFT_KERNEL naiveInnerProductKernel(DataType* dist, + const DataType* x, + const DataType* y, + std::int64_t m, + std::int64_t n, + std::int64_t k, + bool isRowMajor) +{ + std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; + std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) { return; } + + DataType acc_ab = DataType(0); + + for (std::int64_t i = 0; i < k; ++i) { + std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; + std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + acc_ab += a * b; + } + + std::int64_t outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc_ab; +} + +template +RAFT_KERNEL naiveHellingerDistanceKernel(DataType* dist, + const DataType* x, + const DataType* y, + std::int64_t m, + std::int64_t n, + std::int64_t k, + bool isRowMajor) +{ + std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; + std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) { return; } + + DataType acc_ab = DataType(0); + + for (std::int64_t i = 0; i < k; ++i) { + std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; + std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + acc_ab += raft::sqrt(a) * raft::sqrt(b); + } + + std::int64_t outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + + // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative + acc_ab = 1 - acc_ab; + auto rectifier = (!signbit(acc_ab)); + dist[outidx] = raft::sqrt(rectifier * acc_ab); +} + +template +RAFT_KERNEL naiveLpUnexpDistanceKernel(DataType* dist, + const DataType* x, + const DataType* y, + std::int64_t m, + std::int64_t n, + std::int64_t k, + bool isRowMajor, + DataType p) +{ + std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; + std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + DataType acc = DataType(0); + for (std::int64_t i = 0; i < k; ++i) { + std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; + std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + auto diff = raft::abs(a - b); + acc += raft::pow(diff, p); + } + auto one_over_p = 1 / p; + acc = raft::pow(acc, one_over_p); + std::int64_t outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + +template +RAFT_KERNEL naiveHammingDistanceKernel(DataType* dist, + const DataType* x, + const DataType* y, + std::int64_t m, + std::int64_t n, + std::int64_t k, + bool isRowMajor) +{ + std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; + std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + DataType acc = DataType(0); + for (std::int64_t i = 0; i < k; ++i) { + std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; + std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + acc += (a != b); + } + acc = acc / k; + std::int64_t outidx = isRowMajor ? 
midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + +template +RAFT_KERNEL naiveJensenShannonDistanceKernel(DataType* dist, + const DataType* x, + const DataType* y, + std::int64_t m, + std::int64_t n, + std::int64_t k, + bool isRowMajor) +{ + std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; + std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + DataType acc = DataType(0); + for (std::int64_t i = 0; i < k; ++i) { + std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; + std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + + DataType m = 0.5f * (a + b); + bool a_zero = a == 0; + bool b_zero = b == 0; + + DataType p = (!a_zero * m) / (a_zero + a); + DataType q = (!b_zero * m) / (b_zero + b); + + bool p_zero = p == 0; + bool q_zero = q == 0; + + acc += (-a * (!p_zero * log(p + p_zero))) + (-b * (!q_zero * log(q + q_zero))); + } + acc = raft::sqrt(0.5f * acc); + std::int64_t outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + +template +RAFT_KERNEL naiveRussellRaoDistanceKernel(OutType* dist, + const DataType* x, + const DataType* y, + std::int64_t m, + std::int64_t n, + std::int64_t k, + bool isRowMajor) +{ + std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; + std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + OutType acc = OutType(0); + for (std::int64_t i = 0; i < k; ++i) { + std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; + std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + acc += (a * b); + } + acc = (k - acc) / k; + std::int64_t outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + +template +RAFT_KERNEL naiveKLDivergenceDistanceKernel(OutType* dist, + const DataType* x, + const DataType* y, + std::int64_t m, + std::int64_t n, + std::int64_t k, + bool isRowMajor) +{ + std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; + std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + OutType acc = OutType(0); + for (std::int64_t i = 0; i < k; ++i) { + std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; + std::int64_t yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + bool b_zero = (b == 0); + bool a_zero = (a == 0); + acc += a * (log(a + a_zero) - log(b + b_zero)); + } + acc = 0.5f * acc; + std::int64_t outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + +template +RAFT_KERNEL naiveCorrelationDistanceKernel(OutType* dist, + const DataType* x, + const DataType* y, + std::int64_t m, + std::int64_t n, + std::int64_t k, + bool isRowMajor) +{ + std::int64_t midx = threadIdx.x + blockIdx.x * blockDim.x; + std::int64_t nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + OutType acc = OutType(0); + auto a_norm = DataType(0); + auto b_norm = DataType(0); + auto a_sq_norm = DataType(0); + auto b_sq_norm = DataType(0); + for (std::int64_t i = 0; i < k; ++i) { + std::int64_t xidx = isRowMajor ? i + midx * k : i * m + midx; + std::int64_t yidx = isRowMajor ? 
i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + a_norm += a; + b_norm += b; + a_sq_norm += (a * a); + b_sq_norm += (b * b); + acc += (a * b); + } + + auto numer = k * acc - (a_norm * b_norm); + auto Q_denom = k * a_sq_norm - (a_norm * a_norm); + auto R_denom = k * b_sq_norm - (b_norm * b_norm); + + acc = 1 - (numer / raft::sqrt(Q_denom * R_denom)); + + std::int64_t outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + +template +void naiveDistance(DataType* dist, + const DataType* x, + const DataType* y, + std::int64_t m, + std::int64_t n, + std::int64_t k, + cuvs::distance::DistanceType type, + bool isRowMajor, + DataType metric_arg = 2.0f, + cudaStream_t stream = 0) +{ + static const dim3 TPB(4, 256, 1); + dim3 nblks(raft::ceildiv(m, (std::int64_t)TPB.x), raft::ceildiv(n, (std::int64_t)TPB.y), 1); + + switch (type) { + case cuvs::distance::DistanceType::Canberra: + case cuvs::distance::DistanceType::Linf: + case cuvs::distance::DistanceType::L1: + naiveL1_Linf_CanberraDistanceKernel + <<>>(dist, x, y, m, n, k, type, isRowMajor); + break; + case cuvs::distance::DistanceType::L2SqrtUnexpanded: + case cuvs::distance::DistanceType::L2Unexpanded: + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Expanded: + naiveDistanceKernel + <<>>(dist, x, y, m, n, k, type, isRowMajor); + break; + case cuvs::distance::DistanceType::CosineExpanded: + naiveCosineDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; + case cuvs::distance::DistanceType::HellingerExpanded: + naiveHellingerDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; + case cuvs::distance::DistanceType::LpUnexpanded: + naiveLpUnexpDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor, metric_arg); + break; + case cuvs::distance::DistanceType::HammingUnexpanded: + naiveHammingDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; + case cuvs::distance::DistanceType::InnerProduct: + naiveInnerProductKernel<<>>(dist, x, y, m, n, k, isRowMajor); + break; + case cuvs::distance::DistanceType::JensenShannon: + naiveJensenShannonDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; + case cuvs::distance::DistanceType::RusselRaoExpanded: + naiveRussellRaoDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; + case cuvs::distance::DistanceType::KLDivergence: + naiveKLDivergenceDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; + case cuvs::distance::DistanceType::CorrelationExpanded: + naiveCorrelationDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; + default: FAIL() << "should be here\n"; + } + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +struct DistanceInputs { + DataType tolerance; + std::int64_t m, n, k; + bool isRowMajor; + unsigned long long int seed; + DataType metric_arg = 2.0f; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const DistanceInputs& dims) +{ + return os; +} + +// TODO: Remove when mdspan-based raft::runtime::distance::pairwise_distance is +// implemented. 
+// +// Context: +// https://github.com/rapidsai/raft/issues/1338 +template +constexpr bool layout_to_row_major(); + +template <> +constexpr bool layout_to_row_major() +{ + return true; +} +template <> +constexpr bool layout_to_row_major() +{ + return false; +} + +template +void distanceLauncher(raft::resources const& handle, + DataType* x, + DataType* y, + DataType* dist, + DataType* dist2, + std::int64_t m, + std::int64_t n, + std::int64_t k, + DistanceInputs& params, + DataType threshold, + DataType metric_arg = 2.0f) +{ + auto x_v = raft::make_device_matrix_view(x, m, k); + auto y_v = raft::make_device_matrix_view(y, n, k); + auto dist_v = raft::make_device_matrix_view(dist, m, n); + + cuvs::distance::pairwise_distance(handle, x_v, y_v, dist_v, distanceType, metric_arg); +} + +template +class DistanceTest : public ::testing::TestWithParam> { + public: + DistanceTest() + : params(::testing::TestWithParam>::GetParam()), + stream(raft::resource::get_cuda_stream(handle)), + x(params.m * params.k, stream), + y(params.n * params.k, stream), + dist_ref(params.m * params.n, stream), + dist(params.m * params.n, stream), + dist2(params.m * params.n, stream) + { + } + + void SetUp() override + { + auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); + raft::common::nvtx::range fun_scope( + "test::%s/%s", testInfo->test_suite_name(), testInfo->name()); + + raft::random::RngState r(params.seed); + std::int64_t m = params.m; + std::int64_t n = params.n; + std::int64_t k = params.k; + DataType metric_arg = params.metric_arg; + bool isRowMajor = params.isRowMajor; + if (distanceType == cuvs::distance::DistanceType::HellingerExpanded || + distanceType == cuvs::distance::DistanceType::JensenShannon || + distanceType == cuvs::distance::DistanceType::KLDivergence) { + // Hellinger works only on positive numbers + uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); + uniform(handle, r, y.data(), n * k, DataType(0.0), DataType(1.0)); + } else if (distanceType == cuvs::distance::DistanceType::RusselRaoExpanded) { + uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); + uniform(handle, r, y.data(), n * k, DataType(0.0), DataType(1.0)); + // Russel rao works on boolean values. + bernoulli(handle, r, x.data(), m * k, 0.5f); + bernoulli(handle, r, y.data(), n * k, 0.5f); + } else { + uniform(handle, r, x.data(), m * k, DataType(-1.0), DataType(1.0)); + uniform(handle, r, y.data(), n * k, DataType(-1.0), DataType(1.0)); + } + naiveDistance( + dist_ref.data(), x.data(), y.data(), m, n, k, distanceType, isRowMajor, metric_arg, stream); + + DataType threshold = -10000.f; + + if (isRowMajor) { + distanceLauncher(handle, + x.data(), + y.data(), + dist.data(), + dist2.data(), + m, + n, + k, + params, + threshold, + metric_arg); + + } else { + distanceLauncher(handle, + x.data(), + y.data(), + dist.data(), + dist2.data(), + m, + n, + k, + params, + threshold, + metric_arg); + } + raft::resource::sync_stream(handle, stream); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + DistanceInputs params; + rmm::device_uvector x, y, dist_ref, dist, dist2; +}; + +/* + * This test suite verifies the path when X and Y are same buffer, + * distance metrics which requires norms like L2 expanded/cosine/correlation + * takes a more optimal path in such case to skip norm calculation for Y buffer. + * It may happen that though both X and Y are same buffer but user passes + * different dimensions for them like in case of tiled_brute_force_knn. 
+ */ +template +class DistanceTestSameBuffer : public ::testing::TestWithParam> { + public: + using dev_vector = rmm::device_uvector; + DistanceTestSameBuffer() + : params(::testing::TestWithParam>::GetParam()), + stream(raft::resource::get_cuda_stream(handle)), + x(params.m * params.k, stream), + dist_ref({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}), + dist({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}), + dist2({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}) + { + } + + void SetUp() override + { + auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); + raft::common::nvtx::range fun_scope( + "test::%s/%s", testInfo->test_suite_name(), testInfo->name()); + + raft::random::RngState r(params.seed); + std::int64_t m = params.m; + std::int64_t n = params.m; + std::int64_t k = params.k; + DataType metric_arg = params.metric_arg; + bool isRowMajor = params.isRowMajor; + if (distanceType == cuvs::distance::DistanceType::HellingerExpanded || + distanceType == cuvs::distance::DistanceType::JensenShannon || + distanceType == cuvs::distance::DistanceType::KLDivergence) { + // Hellinger works only on positive numbers + uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); + } else if (distanceType == cuvs::distance::DistanceType::RusselRaoExpanded) { + uniform(handle, r, x.data(), m * k, DataType(0.0), DataType(1.0)); + // Russel rao works on boolean values. + bernoulli(handle, r, x.data(), m * k, 0.5f); + } else { + uniform(handle, r, x.data(), m * k, DataType(-1.0), DataType(1.0)); + } + + for (std::int64_t i = 0; i < 2; i++) { + // both X and Y are same buffer but when i = 1 + // different dimensions for x & y is passed. 
+ m = m / (i + 1); + naiveDistance(dist_ref[i].data(), + x.data(), + x.data(), + m, + n, + k, + distanceType, + isRowMajor, + metric_arg, + stream); + + DataType threshold = -10000.f; + + if (isRowMajor) { + distanceLauncher(handle, + x.data(), + x.data(), + dist[i].data(), + dist2[i].data(), + m, + n, + k, + params, + threshold, + metric_arg); + + } else { + distanceLauncher(handle, + x.data(), + x.data(), + dist[i].data(), + dist2[i].data(), + m, + n, + k, + params, + threshold, + metric_arg); + } + } + raft::resource::sync_stream(handle, stream); + } + + protected: + raft::resources handle; + cudaStream_t stream; + + DistanceInputs params; + dev_vector x; + static const std::int64_t N = 2; + std::array dist_ref, dist, dist2; +}; + +template +class BigMatrixDistanceTest : public ::testing::Test { + public: + BigMatrixDistanceTest() + : x(m * k, raft::resource::get_cuda_stream(handle)), + dist(std::size_t(m) * m, raft::resource::get_cuda_stream(handle)){}; + void SetUp() override + { + auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); + raft::common::nvtx::range fun_scope( + "test::%s/%s", testInfo->test_suite_name(), testInfo->name()); + + constexpr float metric_arg = 0.0f; + auto x_v = + raft::make_device_matrix_view(x.data(), m, k); + auto dist_v = raft::make_device_matrix_view( + dist.data(), m, n); + + cuvs::distance::pairwise_distance(handle, x_v, x_v, dist_v, distanceType, metric_arg); + raft::resource::sync_stream(handle); + } + + protected: + raft::resources handle; + std::int64_t m = 48000; + std::int64_t n = 48000; + std::int64_t k = 1; + rmm::device_uvector x, dist; +}; +} // end namespace distance +} // namespace cuvs diff --git a/cpp/test/neighbors/ann_ivf_flat_c.cu b/cpp/test/neighbors/ann_ivf_flat_c.cu index e85450494..784418860 100644 --- a/cpp/test/neighbors/ann_ivf_flat_c.cu +++ b/cpp/test/neighbors/ann_ivf_flat_c.cu @@ -30,7 +30,7 @@ extern "C" void run_ivf_flat(int64_t n_rows, float* query_data, float* distances_data, int64_t* neighbors_data, - enum DistanceType metric, + cuvsDistanceType metric, size_t n_probes, size_t n_lists); @@ -51,7 +51,7 @@ void recall_eval(T* query_data, size_t n_rows, size_t n_dim, size_t n_neighbors, - DistanceType metric, + cuvsDistanceType metric, size_t n_probes, size_t n_lists) { @@ -101,9 +101,9 @@ TEST(IvfFlatC, BuildSearch) int64_t n_dim = 32; uint32_t n_neighbors = 8; - enum DistanceType metric = L2Expanded; - size_t n_probes = 20; - size_t n_lists = 1024; + cuvsDistanceType metric = L2Expanded; + size_t n_probes = 20; + size_t n_lists = 1024; float *index_data, *query_data, *distances_data; int64_t* neighbors_data; diff --git a/cpp/test/neighbors/ann_ivf_pq_c.cu b/cpp/test/neighbors/ann_ivf_pq_c.cu index 94d121ce2..88cd1bd93 100644 --- a/cpp/test/neighbors/ann_ivf_pq_c.cu +++ b/cpp/test/neighbors/ann_ivf_pq_c.cu @@ -30,7 +30,7 @@ extern "C" void run_ivf_pq(int64_t n_rows, float* query_data, float* distances_data, int64_t* neighbors_data, - enum DistanceType metric, + cuvsDistanceType metric, size_t n_probes, size_t n_lists); @@ -51,7 +51,7 @@ void recall_eval(T* query_data, size_t n_rows, size_t n_dim, size_t n_neighbors, - DistanceType metric, + cuvsDistanceType metric, size_t n_probes, size_t n_lists) { @@ -101,9 +101,9 @@ TEST(IvfPqC, BuildSearch) int64_t n_dim = 32; uint32_t n_neighbors = 8; - enum DistanceType metric = L2Expanded; - size_t n_probes = 20; - size_t n_lists = 1024; + cuvsDistanceType metric = L2Expanded; + size_t n_probes = 20; + size_t n_lists = 1024; float *index_data, *query_data, 
*distances_data; int64_t* neighbors_data; diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index 079740945..27a0fff7f 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -83,27 +83,35 @@ struct print_metric { inline auto operator<<(std::ostream& os, const print_metric& p) -> std::ostream& { switch (p.value) { - case cuvs::distance::L2Expanded: os << "distance::L2Expanded"; break; - case cuvs::distance::L2SqrtExpanded: os << "distance::L2SqrtExpanded"; break; - case cuvs::distance::CosineExpanded: os << "distance::CosineExpanded"; break; - case cuvs::distance::L1: os << "distance::L1"; break; - case cuvs::distance::L2Unexpanded: os << "distance::L2Unexpanded"; break; - case cuvs::distance::L2SqrtUnexpanded: os << "distance::L2SqrtUnexpanded"; break; - case cuvs::distance::InnerProduct: os << "distance::InnerProduct"; break; - case cuvs::distance::Linf: os << "distance::Linf"; break; - case cuvs::distance::Canberra: os << "distance::Canberra"; break; - case cuvs::distance::LpUnexpanded: os << "distance::LpUnexpanded"; break; - case cuvs::distance::CorrelationExpanded: os << "distance::CorrelationExpanded"; break; - case cuvs::distance::JaccardExpanded: os << "distance::JaccardExpanded"; break; - case cuvs::distance::HellingerExpanded: os << "distance::HellingerExpanded"; break; - case cuvs::distance::Haversine: os << "distance::Haversine"; break; - case cuvs::distance::BrayCurtis: os << "distance::BrayCurtis"; break; - case cuvs::distance::JensenShannon: os << "distance::JensenShannon"; break; - case cuvs::distance::HammingUnexpanded: os << "distance::HammingUnexpanded"; break; - case cuvs::distance::KLDivergence: os << "distance::KLDivergence"; break; - case cuvs::distance::RusselRaoExpanded: os << "distance::RusselRaoExpanded"; break; - case cuvs::distance::DiceExpanded: os << "distance::DiceExpanded"; break; - case cuvs::distance::Precomputed: os << "distance::Precomputed"; break; + case cuvs::distance::DistanceType::L2Expanded: os << "distance::L2Expanded"; break; + case cuvs::distance::DistanceType::L2SqrtExpanded: os << "distance::L2SqrtExpanded"; break; + case cuvs::distance::DistanceType::CosineExpanded: os << "distance::CosineExpanded"; break; + case cuvs::distance::DistanceType::L1: os << "distance::L1"; break; + case cuvs::distance::DistanceType::L2Unexpanded: os << "distance::L2Unexpanded"; break; + case cuvs::distance::DistanceType::L2SqrtUnexpanded: os << "distance::L2SqrtUnexpanded"; break; + case cuvs::distance::DistanceType::InnerProduct: os << "distance::InnerProduct"; break; + case cuvs::distance::DistanceType::Linf: os << "distance::Linf"; break; + case cuvs::distance::DistanceType::Canberra: os << "distance::Canberra"; break; + case cuvs::distance::DistanceType::LpUnexpanded: os << "distance::LpUnexpanded"; break; + case cuvs::distance::DistanceType::CorrelationExpanded: + os << "distance::CorrelationExpanded"; + break; + case cuvs::distance::DistanceType::JaccardExpanded: os << "distance::JaccardExpanded"; break; + case cuvs::distance::DistanceType::HellingerExpanded: + os << "distance::HellingerExpanded"; + break; + case cuvs::distance::DistanceType::Haversine: os << "distance::Haversine"; break; + case cuvs::distance::DistanceType::BrayCurtis: os << "distance::BrayCurtis"; break; + case cuvs::distance::DistanceType::JensenShannon: os << "distance::JensenShannon"; break; + case cuvs::distance::DistanceType::HammingUnexpanded: + os << "distance::HammingUnexpanded"; + break; + case 
cuvs::distance::DistanceType::KLDivergence: os << "distance::KLDivergence"; break; + case cuvs::distance::DistanceType::RusselRaoExpanded: + os << "distance::RusselRaoExpanded"; + break; + case cuvs::distance::DistanceType::DiceExpanded: os << "distance::DiceExpanded"; break; + case cuvs::distance::DistanceType::Precomputed: os << "distance::Precomputed"; break; default: RAFT_FAIL("unreachable code"); } return os; diff --git a/cpp/test/neighbors/brute_force_c.cu b/cpp/test/neighbors/brute_force_c.cu index 7730a98c6..8caf1c9d1 100644 --- a/cpp/test/neighbors/brute_force_c.cu +++ b/cpp/test/neighbors/brute_force_c.cu @@ -30,7 +30,7 @@ extern "C" void run_brute_force(int64_t n_rows, float* query_data, float* distances_data, int64_t* neighbors_data, - enum DistanceType metric); + cuvsDistanceType metric); template void generate_random_data(T* devPtr, size_t size) @@ -49,7 +49,7 @@ void recall_eval(T* query_data, size_t n_rows, size_t n_dim, size_t n_neighbors, - DistanceType metric) + cuvsDistanceType metric) { raft::handle_t handle; auto distances_ref = raft::make_device_matrix(handle, n_queries, n_neighbors); @@ -97,7 +97,7 @@ TEST(BruteForceC, BuildSearch) int64_t n_dim = 32; uint32_t n_neighbors = 8; - enum DistanceType metric = L2Expanded; + cuvsDistanceType metric = L2Expanded; float *index_data, *query_data, *distances_data; int64_t* neighbors_data; diff --git a/cpp/test/neighbors/run_brute_force_c.c b/cpp/test/neighbors/run_brute_force_c.c index ed775a2d6..ed9e99970 100644 --- a/cpp/test/neighbors/run_brute_force_c.c +++ b/cpp/test/neighbors/run_brute_force_c.c @@ -24,7 +24,7 @@ void run_brute_force(int64_t n_rows, float* query_data, float* distances_data, int64_t* neighbors_data, - enum DistanceType metric) + cuvsDistanceType metric) { // create cuvsResources_t cuvsResources_t res; diff --git a/cpp/test/neighbors/run_ivf_flat_c.c b/cpp/test/neighbors/run_ivf_flat_c.c index badb507a5..9ecbd18eb 100644 --- a/cpp/test/neighbors/run_ivf_flat_c.c +++ b/cpp/test/neighbors/run_ivf_flat_c.c @@ -24,7 +24,7 @@ void run_ivf_flat(int64_t n_rows, float* query_data, float* distances_data, int64_t* neighbors_data, - enum DistanceType metric, + cuvsDistanceType metric, size_t n_probes, size_t n_lists) { diff --git a/cpp/test/neighbors/run_ivf_pq_c.c b/cpp/test/neighbors/run_ivf_pq_c.c index fece4a644..332c8a0f8 100644 --- a/cpp/test/neighbors/run_ivf_pq_c.c +++ b/cpp/test/neighbors/run_ivf_pq_c.c @@ -24,7 +24,7 @@ void run_ivf_pq(int64_t n_rows, float* query_data, float* distances_data, int64_t* neighbors_data, - enum DistanceType metric, + cuvsDistanceType metric, size_t n_probes, size_t n_lists) { diff --git a/dependencies.yaml b/dependencies.yaml index eb1f99d4a..0c72b9b9a 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -220,6 +220,7 @@ dependencies: - cuda-nvtx-dev - cuda-cudart-dev - cuda-profiler-api + - libnvjitlink - libcublas-dev - libcurand-dev - libcusolver-dev diff --git a/docs/source/cpp_api/distance.rst b/docs/source/cpp_api/distance.rst index 21f4558ec..c1b8c619d 100644 --- a/docs/source/cpp_api/distance.rst +++ b/docs/source/cpp_api/distance.rst @@ -17,3 +17,16 @@ namespace *cuvs::distance* .. doxygenenum:: cuvs::distance::DistanceType :project: cuvs + + +Pairwise Distances +------------------ + +``include `` + +namespace *cuvs::distance* + +.. 
doxygengroup:: pairwise_distance + :project: cuvs + :members: + :content-only: \ No newline at end of file diff --git a/python/cuvs/cuvs/distance_type.pxd b/python/cuvs/cuvs/distance_type.pxd index a1f0366a5..b85ee4b36 100644 --- a/python/cuvs/cuvs/distance_type.pxd +++ b/python/cuvs/cuvs/distance_type.pxd @@ -17,7 +17,7 @@ cdef extern from "cuvs/distance/distance_types.h" nogil: - ctypedef enum DistanceType: + ctypedef enum cuvsDistanceType: L2Expanded L2SqrtExpanded CosineExpanded diff --git a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pxd b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pxd index 77e484fba..c57fa9e8d 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pxd +++ b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pxd @@ -19,7 +19,7 @@ from libc.stdint cimport uintptr_t from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t from cuvs.common.cydlpack cimport DLDataType, DLManagedTensor -from cuvs.distance_type cimport DistanceType +from cuvs.distance_type cimport cuvsDistanceType cdef extern from "cuvs/neighbors/brute_force.h" nogil: @@ -36,7 +36,7 @@ cdef extern from "cuvs/neighbors/brute_force.h" nogil: cuvsError_t cuvsBruteForceBuild(cuvsResources_t res, DLManagedTensor* dataset, - DistanceType metric, + cuvsDistanceType metric, float metric_arg, cuvsBruteForceIndex_t index) except + diff --git a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx index ccb10e305..6af3c920c 100644 --- a/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx +++ b/python/cuvs/cuvs/neighbors/brute_force/brute_force.pyx @@ -26,7 +26,7 @@ from libc.stdint cimport uint32_t from libcpp cimport bool from cuvs.common cimport cydlpack -from cuvs.distance_type cimport DistanceType +from cuvs.distance_type cimport cuvsDistanceType from pylibraft.common import auto_convert_output, cai_wrapper, device_ndarray from pylibraft.common.cai_wrapper import wrap_array @@ -105,7 +105,7 @@ def build(dataset, metric="sqeuclidean", metric_arg=2.0, resources=None): cdef cuvsResources_t res = resources.get_c_obj() - cdef DistanceType c_metric = DISTANCE_TYPES[metric] + cdef cuvsDistanceType c_metric = DISTANCE_TYPES[metric] cdef Index idx = Index() cdef cydlpack.DLManagedTensor* dataset_dlpack = \ cydlpack.dlpack_c(dataset_ai) diff --git a/rust/cuvs/src/distance_type.rs b/rust/cuvs/src/distance_type.rs index 4ac3e9164..a0cbcc86e 100644 --- a/rust/cuvs/src/distance_type.rs +++ b/rust/cuvs/src/distance_type.rs @@ -14,4 +14,4 @@ * limitations under the License. */ -pub type DistanceType = ffi::DistanceType; +pub type DistanceType = ffi::cuvsDistanceType;
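
Reviewer note (not part of the diff): the tests above exercise the new mdspan-based pairwise distance entry point through the distanceLauncher helper in cpp/test/distance/distance_base.cuh. Below is a minimal sketch of how that entry point might be called directly. The header path, the exact view template parameters, and the function/buffer names (example_l2_pairwise, x, y, dist) are assumptions for illustration and are not defined by this change.

// Sketch only: mirrors distanceLauncher in cpp/test/distance/distance_base.cuh.
#include <cstdint>

#include <cuvs/distance/distance.hpp>   // assumed header for cuvs::distance::pairwise_distance
#include <raft/core/device_mdspan.hpp>
#include <raft/core/resources.hpp>

// Computes the m x n matrix of L2 (expanded) distances between two row-major
// device buffers x (m x k) and y (n x k), writing the result to dist (m x n).
void example_l2_pairwise(raft::resources const& handle,
                         float* x,
                         float* y,
                         float* dist,
                         std::int64_t m,
                         std::int64_t n,
                         std::int64_t k)
{
  auto x_v    = raft::make_device_matrix_view<float, std::int64_t>(x, m, k);
  auto y_v    = raft::make_device_matrix_view<float, std::int64_t>(y, n, k);
  auto dist_v = raft::make_device_matrix_view<float, std::int64_t>(dist, m, n);

  // metric_arg is only meaningful for Minkowski (LpUnexpanded); it is ignored for L2.
  cuvs::distance::pairwise_distance(
    handle, x_v, y_v, dist_v, cuvs::distance::DistanceType::L2Expanded, 2.0f);
}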