diff --git a/cpp/include/cuvs/neighbors/nn_descent.hpp b/cpp/include/cuvs/neighbors/nn_descent.hpp index 347ccf889..64900ea11 100644 --- a/cpp/include/cuvs/neighbors/nn_descent.hpp +++ b/cpp/include/cuvs/neighbors/nn_descent.hpp @@ -55,6 +55,8 @@ struct index_params : cuvs::neighbors::index_params { size_t intermediate_graph_degree = 128; // Degree of input graph for pruning. size_t max_iterations = 20; // Number of nn-descent iterations. float termination_threshold = 0.0001; // Termination threshold of nn-descent. + bool return_distances = false; // return distances if true + size_t n_clusters = 1; // defaults to not using any batching /** @brief Construct NN descent parameters for a specific kNN graph degree * @@ -100,14 +102,20 @@ struct index : cuvs::neighbors::index { * @param res raft::resources is an object mangaging resources * @param n_rows number of rows in knn-graph * @param n_cols number of cols in knn-graph + * @param return_distances whether to return distances */ - index(raft::resources const& res, int64_t n_rows, int64_t n_cols) + index(raft::resources const& res, int64_t n_rows, int64_t n_cols, bool return_distances = false) : cuvs::neighbors::index(), res_{res}, metric_{cuvs::distance::DistanceType::L2Expanded}, graph_{raft::make_host_matrix(n_rows, n_cols)}, - graph_view_{graph_.view()} + graph_view_{graph_.view()}, + return_distances_{return_distances} { + if (return_distances) { + distances_ = raft::make_device_matrix(res_, n_rows, n_cols); + distances_view_ = distances_.value().view(); + } } /** @@ -119,14 +127,22 @@ struct index : cuvs::neighbors::index { * * @param res raft::resources is an object mangaging resources * @param graph_view raft::host_matrix_view for storing knn-graph + * @param distances_view optional raft::device_matrix_view for storing + * distances + * @param return_distances whether to return distances */ index(raft::resources const& res, - raft::host_matrix_view graph_view) + raft::host_matrix_view graph_view, + 
std::optional> distances_view = + std::nullopt, + bool return_distances = false) : cuvs::neighbors::index(), res_{res}, metric_{cuvs::distance::DistanceType::L2Expanded}, graph_{raft::make_host_matrix(0, 0)}, - graph_view_{graph_view} + graph_view_{graph_view}, + distances_view_{distances_view}, + return_distances_{return_distances} { } @@ -155,6 +171,13 @@ struct index : cuvs::neighbors::index { return graph_view_; } + /** neighborhood graph distances [size, graph-degree] */ + [[nodiscard]] inline auto distances() noexcept + -> std::optional> + { + return distances_view_; + } + // Don't allow copying the index for performance reasons (try avoiding copying data) index(const index&) = delete; index(index&&) = default; @@ -166,8 +189,11 @@ struct index : cuvs::neighbors::index { raft::resources const& res_; cuvs::distance::DistanceType metric_; raft::host_matrix graph_; // graph to return for non-int IdxT + std::optional> distances_; raft::host_matrix_view graph_view_; // view of graph for user provided matrix + std::optional> distances_view_; + bool return_distances_; }; /** @} */ @@ -393,8 +419,6 @@ auto build(raft::resources const& res, raft::device_matrix_view dataset) -> cuvs::neighbors::nn_descent::index; -/** @} */ - /** * @brief Build nn-descent Index with dataset in host memory * @@ -427,6 +451,270 @@ auto build(raft::resources const& res, raft::host_matrix_view dataset) -> cuvs::neighbors::nn_descent::index; +/** + * @brief Build nn-descent Index with dataset in device memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::device_matrix_view dataset + * auto index = nn_descent::index(res, index_params, N, D); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * 
// dataset + * nn_descent::build(res, index_params, dataset, index); + * @endcode + * + * @param[in] res raft::resources is an object managing resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::device_matrix_view input dataset expected to be located + * in device memory + * @param[out] index index containing all-neighbors knn graph in host memory + */ +void build(raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset, + index& index); + +/** + * @brief Build nn-descent Index with dataset in host memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::host_matrix_view dataset + * auto index = nn_descent::index(res, index_params, N, D); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * nn_descent::build(res, index_params, dataset, index); + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param res raft::resources is an object managing resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::host_matrix_view input dataset expected to be located + * in host memory + * @param[out] index index containing all-neighbors knn graph in host memory + */ +void build(raft::resources const& res, + index_params const& params, + raft::host_matrix_view dataset, + index& index); + +/** + * @brief Build nn-descent Index with dataset in device memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace 
cuvs::neighbors; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::device_matrix_view dataset + * auto index = nn_descent::index(res, index_params, N, D); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * nn_descent::build(res, index_params, dataset, index); + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param res raft::resources is an object managing resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::device_matrix_view input dataset expected to be located + * in device memory + * @param[out] index index containing all-neighbors knn graph in host memory + */ +void build(raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset, + index& index); + +/** + * @brief Build nn-descent Index with dataset in host memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::host_matrix_view dataset + * auto index = nn_descent::index(res, index_params, N, D); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * nn_descent::build(res, index_params, dataset, index); + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param res raft::resources is an object managing resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset 
raft::host_matrix_view input dataset expected to be located + * in host memory + * @param[out] index index containing all-neighbors knn graph in host memory + */ +void build(raft::resources const& res, + index_params const& params, + raft::host_matrix_view dataset, + index& index); + +/** + * @brief Build nn-descent Index with dataset in device memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::device_matrix_view dataset + * auto index = nn_descent::index(res, index_params, N, D); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * nn_descent::build(res, index_params, dataset, index); + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param res raft::resources is an object managing resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::device_matrix_view input dataset expected to be located + * in device memory + * @param[out] index index containing all-neighbors knn graph in host memory + */ +void build(raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset, + index& index); + +/** + * @brief Build nn-descent Index with dataset in host memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::host_matrix_view dataset + * auto index = nn_descent::index(res, index_params, N, D); + * // index.graph() provides a raft::host_matrix_view of an 
+ * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * nn_descent::build(res, index_params, dataset, index); + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param res raft::resources is an object managing resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::host_matrix_view input dataset expected to be located + * in host memory + * @param[out] index index containing all-neighbors knn graph in host memory + */ +void build(raft::resources const& res, + index_params const& params, + raft::host_matrix_view dataset, + index& index); + +/** + * @brief Build nn-descent Index with dataset in device memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::device_matrix_view dataset + * auto index = nn_descent::index(res, index_params, N, D); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * nn_descent::build(res, index_params, dataset, index); + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param res raft::resources is an object managing resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::device_matrix_view input dataset expected to be located + * in device memory + * @param[out] index index containing all-neighbors knn graph in host memory + */ +void build(raft::resources const& res, + index_params const& params, + raft::device_matrix_view dataset, + index& index); + +/** + * @brief Build nn-descent 
Index with dataset in host memory + * + * The following distance metrics are supported: + * - L2 + * + * Usage example: + * @code{.cpp} + * using namespace cuvs::neighbors; + * // use default index parameters + * nn_descent::index_params index_params; + * // create and fill the index from a [N, D] raft::host_matrix_view dataset + * auto index = nn_descent::index(res, index_params, N, D); + * // index.graph() provides a raft::host_matrix_view of an + * // all-neighbors knn graph of dimensions [N, k] of the input + * // dataset + * nn_descent::build(res, index_params, dataset, index); + * @endcode + * + * @tparam T data-type of the input dataset + * @tparam IdxT data-type for the output index + * @param res raft::resources is an object managing resources + * @param[in] params an instance of nn_descent::index_params that are parameters + * to run the nn-descent algorithm + * @param[in] dataset raft::host_matrix_view input dataset expected to be located + * in host memory + * @param[out] index index containing all-neighbors knn graph in host memory + */ +void build(raft::resources const& res, + index_params const& params, + raft::host_matrix_view dataset, + index& index); + +/** @} */ + /** * @brief Test if we have enough GPU memory to run NN descent algorithm. 
* diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 9e4d453e3..e234450d2 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -33,8 +33,7 @@ #include #include -// TODO: Fixme- this needs to be migrated -#include "../../nn_descent.cuh" +#include // TODO: This shouldn't be calling spatial/knn APIs #include "../ann_utils.cuh" @@ -357,7 +356,7 @@ void build_knn_graph( cuvs::neighbors::nn_descent::index_params build_params) { auto nn_descent_idx = cuvs::neighbors::nn_descent::index(res, knn_graph); - cuvs::neighbors::nn_descent::build(res, build_params, dataset, nn_descent_idx); + cuvs::neighbors::nn_descent::build(res, build_params, dataset, nn_descent_idx); using internal_IdxT = typename std::make_unsigned::type; using g_accessor = typename decltype(nn_descent_idx.graph())::accessor_type; diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh index 8c5767c50..5a11515c5 100644 --- a/cpp/src/neighbors/detail/nn_descent.cuh +++ b/cpp/src/neighbors/detail/nn_descent.cuh @@ -18,20 +18,26 @@ #include -#include "ann_utils.cuh" -#include "cagra/device_common.hpp" #include +#include #include #include +#include +#include #include #include - +#include +#include +#include +#include #include // raft::util::arch::SM_* #include #include #include #include +#include + #include #include #include @@ -44,12 +50,13 @@ #include #include +#include #include #include namespace cuvs::neighbors::nn_descent::detail { -static const std::string RAFT_NAME = "raft"; -using pinned_memory_resource = thrust::universal_host_pinned_memory_resource; + +using pinned_memory_resource = thrust::universal_host_pinned_memory_resource; template using pinned_memory_allocator = thrust::mr::stateless_resource_allocator; @@ -146,7 +153,7 @@ using align32 = raft::Pow2<32>; template int get_batch_size(const int it_now, const T nrow, const int 
batch_size) { - int it_total = raft::ceildiv(nrow, batch_size); + int it_total = ceildiv(nrow, batch_size); return (it_now == it_total - 1) ? nrow - it_now * batch_size : batch_size; } @@ -156,7 +163,7 @@ constexpr __host__ __device__ __forceinline__ int skew_dim(int ndim) { // all "4"s are for alignment if constexpr (std::is_same::value) { - ndim = raft::ceildiv(ndim, 4) * 4; + ndim = ceildiv(ndim, 4) * 4; return ndim + (ndim % 32 == 0) * 4; } } @@ -216,6 +223,7 @@ struct BuildConfig { // If internal_node_degree == 0, the value of node_degree will be assigned to it size_t max_iterations{50}; float termination_threshold{0.0001}; + size_t output_graph_degree{32}; }; template @@ -344,9 +352,14 @@ class GNND { GNND(const GNND&) = delete; GNND& operator=(const GNND&) = delete; - void build(Data_t* data, const Index_t nrow, Index_t* output_graph); + void build(Data_t* data, + const Index_t nrow, + Index_t* output_graph, + bool return_distances, + DistData_t* output_distances); ~GNND() = default; using ID_t = InternalID_t; + void reset(raft::resources const& res); private: void add_reverse_edges(Index_t* graph_ptr, @@ -409,7 +422,7 @@ __device__ __forceinline__ void load_vec(Data_t* vec_buffer, if constexpr (std::is_same_v or std::is_same_v or std::is_same_v) { constexpr int num_load_elems_per_warp = raft::warp_size(); - for (int step = 0; step < raft::ceildiv(padding_dims, num_load_elems_per_warp); step++) { + for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { int idx = step * num_load_elems_per_warp + lane_id; if (idx < load_dims) { vec_buffer[idx] = d_vec[idx]; @@ -423,7 +436,7 @@ __device__ __forceinline__ void load_vec(Data_t* vec_buffer, load_dims % 4 == 0 && padding_dims % 4 == 0) { constexpr int num_load_elems_per_warp = raft::warp_size() * 4; #pragma unroll - for (int step = 0; step < raft::ceildiv(padding_dims, num_load_elems_per_warp); step++) { + for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) 
{ int idx_in_vec = step * num_load_elems_per_warp + lane_id * 4; if (idx_in_vec + 4 <= load_dims) { *(float2*)(vec_buffer + idx_in_vec) = *(float2*)(d_vec + idx_in_vec); @@ -433,7 +446,7 @@ __device__ __forceinline__ void load_vec(Data_t* vec_buffer, } } else { constexpr int num_load_elems_per_warp = raft::warp_size(); - for (int step = 0; step < raft::ceildiv(padding_dims, num_load_elems_per_warp); step++) { + for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { int idx = step * num_load_elems_per_warp + lane_id; if (idx < load_dims) { vec_buffer[idx] = d_vec[idx]; @@ -463,7 +476,7 @@ RAFT_KERNEL preprocess_data_kernel(const Data_t* input_data, if (threadIdx.x == 0) { l2_norm = 0; } __syncthreads(); int lane_id = threadIdx.x % raft::warp_size(); - for (int step = 0; step < raft::ceildiv(dim, raft::warp_size()); step++) { + for (int step = 0; step < ceildiv(dim, raft::warp_size()); step++) { int idx = step * raft::warp_size() + lane_id; float part_dist = 0; if (idx < dim) { @@ -478,7 +491,7 @@ RAFT_KERNEL preprocess_data_kernel(const Data_t* input_data, __syncwarp(); } - for (int step = 0; step < raft::ceildiv(dim, raft::warp_size()); step++) { + for (int step = 0; step < ceildiv(dim, raft::warp_size()); step++) { int idx = step * raft::warp_size() + threadIdx.x; if (idx < dim) { if (l2_norms == nullptr) { @@ -527,7 +540,7 @@ __device__ void insert_to_global_graph(ResultItem elem, size_t global_idx_base = list_id * node_degree; if (elem.id() == list_id) return; - const int num_segments = raft::ceildiv(node_degree, raft::warp_size()); + const int num_segments = ceildiv(node_degree, raft::warp_size()); int loop_flag = 0; do { @@ -1020,28 +1033,49 @@ void GnndGraph::sample_graph_new(InternalID_t* new_neighbors, template void GnndGraph::init_random_graph() { - for (size_t seg_idx = 0; seg_idx < static_cast(num_segments); seg_idx++) { - // random sequence (range: 0~nrow) - // segment_x stores neighbors which id % num_segments == x - 
std::vector rand_seq(nrow / num_segments); - std::iota(rand_seq.begin(), rand_seq.end(), 0); - auto gen = std::default_random_engine{seg_idx}; - std::shuffle(rand_seq.begin(), rand_seq.end(), gen); + // for (size_t seg_idx = 0; seg_idx < static_cast(num_segments); seg_idx++) { + // // random sequence (range: 0~nrow) + // // segment_x stores neighbors which id % num_segments == x + // std::vector rand_seq(nrow / num_segments); + // std::iota(rand_seq.begin(), rand_seq.end(), 0); + // auto gen = std::default_random_engine{seg_idx}; + // std::shuffle(rand_seq.begin(), rand_seq.end(), gen); + + // #pragma omp parallel for + // for (size_t i = 0; i < nrow; i++) { + // size_t base_idx = i * node_degree + seg_idx * segment_size; + // auto h_neighbor_list = h_graph + base_idx; + // auto h_dist_list = h_dists.data_handle() + base_idx; + // for (size_t j = 0; j < static_cast(segment_size); j++) { + // size_t idx = base_idx + j; + // Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx; + // if ((size_t)id == i) { + // id = rand_seq[(idx + segment_size) % rand_seq.size()] * num_segments + seg_idx; + // } + // h_neighbor_list[j].id_with_flag() = id; + // h_dist_list[j] = std::numeric_limits::max(); + // } + // } + // } + // random sequence (range: 0~nrow) + std::vector rand_seq(nrow); + std::iota(rand_seq.begin(), rand_seq.end(), 0); + std::random_shuffle(rand_seq.begin(), rand_seq.end()); + auto h_dists_ptr = h_dists.data_handle(); #pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - size_t base_idx = i * node_degree + seg_idx * segment_size; - auto h_neighbor_list = h_graph + base_idx; - auto h_dist_list = h_dists.data_handle() + base_idx; - for (size_t j = 0; j < static_cast(segment_size); j++) { - size_t idx = base_idx + j; - Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx; - if ((size_t)id == i) { - id = rand_seq[(idx + segment_size) % rand_seq.size()] * num_segments + seg_idx; - } - h_neighbor_list[j].id_with_flag() = 
id; - h_dist_list[j] = std::numeric_limits::max(); - } + for (size_t i = 0; i < nrow; i++) { + for (size_t j = 0; j < NUM_SAMPLES; j++) { + size_t idx = i * NUM_SAMPLES + j; + Index_t id = rand_seq[idx % nrow]; + if ((size_t)id == i) { id = rand_seq[(idx + NUM_SAMPLES) % nrow]; } + h_graph[i * node_degree + j].id_with_flag() = id; + } + for (size_t j = NUM_SAMPLES; j < node_degree; j++) { + h_graph[i * node_degree + j].id_with_flag() = std::numeric_limits::max(); + } + for (size_t j = 0; j < node_degree; j++) { + h_dists_ptr[i * node_degree + j] = std::numeric_limits::max(); } } } @@ -1161,18 +1195,23 @@ GNND::GNND(raft::resources const& res, const BuildConfig& build d_list_sizes_old_{raft::make_device_vector(res, nrow_)} { static_assert(NUM_SAMPLES <= 32); - - thrust::fill(thrust::device, - dists_buffer_.data_handle(), - dists_buffer_.data_handle() + dists_buffer_.size(), - std::numeric_limits::max()); - thrust::fill(thrust::device, - reinterpret_cast(graph_buffer_.data_handle()), - reinterpret_cast(graph_buffer_.data_handle()) + graph_buffer_.size(), - std::numeric_limits::max()); - thrust::fill(thrust::device, d_locks_.data_handle(), d_locks_.data_handle() + d_locks_.size(), 0); + raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits::max()); + auto graph_buffer_view = raft::make_device_matrix_view( + reinterpret_cast(graph_buffer_.data_handle()), nrow_, DEGREE_ON_DEVICE); + raft::matrix::fill(res, graph_buffer_view, std::numeric_limits::max()); + raft::matrix::fill(res, d_locks_.view(), 0); }; +template +void GNND::reset(raft::resources const& res) +{ + raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits::max()); + auto graph_buffer_view = raft::make_device_matrix_view( + reinterpret_cast(graph_buffer_.data_handle()), nrow_, DEGREE_ON_DEVICE); + raft::matrix::fill(res, graph_buffer_view, std::numeric_limits::max()); + raft::matrix::fill(res, d_locks_.view(), 0); +} + template void GNND::add_reverse_edges(Index_t* graph_ptr, Index_t* 
h_rev_graph_ptr, @@ -1211,32 +1250,36 @@ void GNND::local_join(cudaStream_t stream) } template -void GNND::build(Data_t* data, const Index_t nrow, Index_t* output_graph) +void GNND::build(Data_t* data, + const Index_t nrow, + Index_t* output_graph, + bool return_distances, + DistData_t* output_distances) { using input_t = typename std::remove_const::type; cudaStream_t stream = raft::resource::get_cuda_stream(res); nrow_ = nrow; + graph_.nrow = nrow; graph_.h_graph = (InternalID_t*)output_graph; cudaPointerAttributes data_ptr_attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); size_t batch_size = (data_ptr_attr.devicePointer == nullptr) ? 100000 : nrow_; - cuvs::spatial::knn::detail::utils::batch_load_iterator vec_batches{ + raft::spatial::knn::detail::utils::batch_load_iterator vec_batches{ data, static_cast(nrow_), build_config_.dataset_dim, batch_size, stream}; for (auto const& batch : vec_batches) { - preprocess_data_kernel<<(raft::warp_size())) * - raft::warp_size(), - stream>>>(batch.data(), - d_data_.data_handle(), - build_config_.dataset_dim, - l2_norms_.data_handle(), - batch.offset()); + preprocess_data_kernel<<< + batch.size(), + raft::warp_size(), + sizeof(Data_t) * ceildiv(build_config_.dataset_dim, static_cast(raft::warp_size())) * + raft::warp_size(), + stream>>>(batch.data(), + d_data_.data_handle(), + build_config_.dataset_dim, + l2_norms_.data_handle(), + batch.offset()); } thrust::fill(thrust::device.on(stream), @@ -1338,6 +1381,27 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out // Reuse graph_.h_dists as the buffer for shrink the lists in graph static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t)); + + if (return_distances) { + auto graph_d_dists = raft::make_device_matrix( + res, nrow_, build_config_.node_degree); + raft::copy(graph_d_dists.data_handle(), + graph_.h_dists.data_handle(), + nrow_ * build_config_.node_degree, + raft::resource::get_cuda_stream(res)); + + auto 
output_dist_view = raft::make_device_matrix_view( + output_distances, nrow_, build_config_.output_graph_degree); + + raft::matrix::slice_coordinates coords{static_cast(0), + static_cast(0), + static_cast(nrow_), + static_cast(build_config_.output_graph_degree)}; + raft::matrix::slice( + res, raft::make_const_mdspan(graph_d_dists.view()), output_dist_view, coords); + raft::resource::sync_stream(res); + } + Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle(); #pragma omp parallel for @@ -1349,7 +1413,7 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out graph_shrink_buffer[i * build_config_.node_degree + j] = id; } else { graph_shrink_buffer[i * build_config_.node_degree + j] = - cuvs::neighbors::cagra::detail::device::xorshift64(idx) % nrow_; + raft::neighbors::cagra::detail::device::xorshift64(idx) % nrow_; } } } @@ -1365,12 +1429,12 @@ void GNND::build(Data_t* data, const Index_t nrow, Index_t* out } template , - raft::memory_type::host>> + typename IdxT = uint32_t, + typename Accessor = + host_device_accessor, memory_type::host>> void build(raft::resources const& res, const index_params& params, - raft::mdspan, raft::row_major, Accessor> dataset, + mdspan, row_major, Accessor> dataset, index& idx) { RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits::max() - 1, @@ -1402,7 +1466,7 @@ void build(raft::resources const& res, size_t extended_intermediate_degree = align32::roundUp( static_cast(intermediate_degree * (intermediate_degree <= 32 ? 
1.0 : 1.3))); - auto int_graph = raft::make_host_matrix( + auto int_graph = raft::make_host_matrix( dataset.extent(0), static_cast(extended_graph_degree)); BuildConfig build_config{.max_dataset_size = static_cast(dataset.extent(0)), @@ -1410,10 +1474,24 @@ void build(raft::resources const& res, .node_degree = extended_graph_degree, .internal_node_degree = extended_intermediate_degree, .max_iterations = params.max_iterations, - .termination_threshold = params.termination_threshold}; + .termination_threshold = params.termination_threshold, + .output_graph_degree = params.graph_degree}; GNND nnd(res, build_config); - nnd.build(dataset.data_handle(), dataset.extent(0), int_graph.data_handle()); + + if (idx.distances().has_value() || !params.return_distances) { + nnd.build(dataset.data_handle(), + dataset.extent(0), + int_graph.data_handle(), + params.return_distances, + idx.distances() + .value_or(raft::make_device_matrix(res, 0, 0).view()) + .data_handle()); + } else { + RAFT_EXPECTS(!params.return_distances, + "Distance view not allocated. 
Using return_distances set to true requires " + "distance view to be allocated."); + } #pragma omp parallel for for (size_t i = 0; i < static_cast(dataset.extent(0)); i++) { @@ -1425,13 +1503,12 @@ void build(raft::resources const& res, } template , - raft::memory_type::host>> -index build( - raft::resources const& res, - const index_params& params, - raft::mdspan, raft::row_major, Accessor> dataset) + typename IdxT = uint32_t, + typename Accessor = + host_device_accessor, memory_type::host>> +index build(raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset) { size_t intermediate_degree = params.intermediate_graph_degree; size_t graph_degree = params.graph_degree; @@ -1445,11 +1522,12 @@ index build( graph_degree = intermediate_degree; } - index idx{res, dataset.extent(0), static_cast(graph_degree)}; + index idx{ + res, dataset.extent(0), static_cast(graph_degree), params.return_distances}; build(res, params, dataset, idx); return idx; } -} // namespace cuvs::neighbors::nn_descent::detail +} // namespace cuvs::neighbors::nn_descent::detail diff --git a/cpp/src/neighbors/detail/nn_descent_batch.cuh b/cpp/src/neighbors/detail/nn_descent_batch.cuh new file mode 100644 index 000000000..dc0ecd703 --- /dev/null +++ b/cpp/src/neighbors/detail/nn_descent_batch.cuh @@ -0,0 +1,740 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + +#include "nn_descent.cuh" +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::nn_descent::detail::experimental { + +// +// Run balanced kmeans on a subsample of the dataset to get centroids +// +template , memory_type::host>> +void get_balanced_kmeans_centroids( + raft::resources const& res, + cuvs::distance::DistanceType metric, + mdspan, row_major, Accessor> dataset, + raft::device_matrix_view centroids) +{ + size_t num_rows = static_cast(dataset.extent(0)); + size_t num_cols = static_cast(dataset.extent(1)); + size_t n_clusters = centroids.extent(0); + size_t num_subsamples = + std::min(static_cast(num_rows / n_clusters), static_cast(num_rows * 0.1)); + + auto d_subsample_dataset = + raft::make_device_matrix(res, num_subsamples, num_cols); + raft::matrix::sample_rows( + res, raft::random::RngState{0}, dataset, d_subsample_dataset.view()); + + cuvs::cluster::kmeans::balanced_params kmeans_params; + kmeans_params.metric = metric; + + auto d_subsample_dataset_const_view = + raft::make_device_matrix_view( + d_subsample_dataset.data_handle(), num_subsamples, num_cols); + auto centroids_view = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, num_cols); + cuvs::cluster::kmeans::fit(res, kmeans_params, d_subsample_dataset_const_view, centroids_view); +} + +// +// Get the top k closest centroid indices for each data point +// Loads the data in batches onto device if data is on host for memory efficiency +// +template +void get_global_nearest_k( + raft::resources const& res, + size_t k, + size_t num_rows, + size_t n_clusters, + const T* dataset, + raft::host_matrix_view global_nearest_cluster, + raft::device_matrix_view centroids, + cuvs::distance::DistanceType metric) +{ + size_t num_cols = 
centroids.extent(1); + auto centroids_view = raft::make_device_matrix_view( + centroids.data_handle(), n_clusters, num_cols); + + cudaPointerAttributes attr; + RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, dataset)); + float* ptr = reinterpret_cast(attr.devicePointer); + + size_t num_batches = n_clusters; + size_t batch_size = (num_rows + n_clusters) / n_clusters; + if (ptr == nullptr) { // data on host + + auto d_dataset_batch = + raft::make_device_matrix(res, batch_size, num_cols); + + auto nearest_clusters_idx = + raft::make_device_matrix(res, batch_size, k); + auto nearest_clusters_idxt = + raft::make_device_matrix(res, batch_size, k); + auto nearest_clusters_dist = + raft::make_device_matrix(res, batch_size, k); + + for (size_t i = 0; i < num_batches; i++) { + size_t batch_size_ = batch_size; + + if (i == num_batches - 1) { batch_size_ = num_rows - batch_size * i; } + raft::copy(d_dataset_batch.data_handle(), + dataset + i * batch_size * num_cols, + batch_size_ * num_cols, + resource::get_cuda_stream(res)); + + std::optional> norms_view; + cuvs::neighbors::brute_force::index brute_force_index( + res, centroids_view, norms_view, metric); + cuvs::neighbors::brute_force::search(res, + brute_force_index, + raft::make_const_mdspan(d_dataset_batch.view()), + nearest_clusters_idx.view(), + nearest_clusters_dist.view()); + + thrust::copy(raft::resource::get_thrust_policy(res), + nearest_clusters_idx.data_handle(), + nearest_clusters_idx.data_handle() + nearest_clusters_idx.size(), + nearest_clusters_idxt.data_handle()); + raft::copy(global_nearest_cluster.data_handle() + i * batch_size * k, + nearest_clusters_idxt.data_handle(), + batch_size_ * k, + resource::get_cuda_stream(res)); + } + } else { // data on device + auto nearest_clusters_idx = + raft::make_device_matrix(res, num_rows, k); + auto nearest_clusters_dist = + raft::make_device_matrix(res, num_rows, k); + + std::optional> norms_view; + cuvs::neighbors::brute_force::index brute_force_index( + res, 
centroids_view, norms_view, metric); + auto dataset_view = + raft::make_device_matrix_view(dataset, num_rows, num_cols); + cuvs::neighbors::brute_force::search(res, + brute_force_index, + dataset_view, + nearest_clusters_idx.view(), + nearest_clusters_dist.view()); + + auto nearest_clusters_idxt = + raft::make_device_matrix(res, batch_size, k); + for (size_t i = 0; i < num_batches; i++) { + size_t batch_size_ = batch_size; + + if (i == num_batches - 1) { batch_size_ = num_rows - batch_size * i; } + // Every batch starts at a multiple of the full batch_size; only the row count of the + // final batch (batch_size_) is smaller, so it must not be used to compute the offset. + size_t batch_offset = i * batch_size * k; + thrust::copy(raft::resource::get_thrust_policy(res), + nearest_clusters_idx.data_handle() + batch_offset, + nearest_clusters_idx.data_handle() + batch_offset + batch_size_ * k, + nearest_clusters_idxt.data_handle()); + raft::copy(global_nearest_cluster.data_handle() + batch_offset, + nearest_clusters_idxt.data_handle(), + batch_size_ * k, + resource::get_cuda_stream(res)); + } + } +} + +// +// global_nearest_cluster [num_rows X k=2] : top 2 closest clusters for each data point +// inverted_indices [num_rows x k vector] : sparse vector for data indices for each cluster +// cluster_size [n_cluster] : cluster size for each cluster +// offset [n_cluster] : offset in inverted_indices for each cluster +// Loads the data in batches onto device if data is on host for memory efficiency +// +template +void get_inverted_indices(raft::resources const& res, + size_t n_clusters, + size_t& max_cluster_size, + size_t& min_cluster_size, + raft::host_matrix_view global_nearest_cluster, + raft::host_vector_view inverted_indices, + raft::host_vector_view cluster_size, + raft::host_vector_view offset) +{ + // build sparse inverted indices and get number of data points for each cluster + size_t num_rows = global_nearest_cluster.extent(0); + size_t k = global_nearest_cluster.extent(1); + + auto local_offset = raft::make_host_vector(n_clusters); + + max_cluster_size = 0; + min_cluster_size = std::numeric_limits::max(); + + thrust::fill( + thrust::host, cluster_size.data_handle(), 
cluster_size.data_handle() + n_clusters, 0); + thrust::fill( + thrust::host, local_offset.data_handle(), local_offset.data_handle() + n_clusters, 0); + + // TODO: this part isn't really a bottleneck but maybe worth trying omp parallel + // for with atomic add + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < k; j++) { + IdxT cluster_id = global_nearest_cluster(i, j); + cluster_size(cluster_id) += 1; + } + } + + offset(0) = 0; + for (size_t i = 1; i < n_clusters; i++) { + offset(i) = offset(i - 1) + cluster_size(i - 1); + } + for (size_t i = 0; i < num_rows; i++) { + for (size_t j = 0; j < k; j++) { + IdxT cluster_id = global_nearest_cluster(i, j); + inverted_indices(offset(cluster_id) + local_offset(cluster_id)) = i; + local_offset(cluster_id) += 1; + } + } + + max_cluster_size = static_cast( + *std::max_element(cluster_size.data_handle(), cluster_size.data_handle() + n_clusters)); + min_cluster_size = static_cast( + *std::min_element(cluster_size.data_handle(), cluster_size.data_handle() + n_clusters)); +} + +template +struct KeyValuePair { + KeyType key; + ValueType value; +}; + +template +struct CustomKeyComparator { + __device__ bool operator()(const KeyValuePair& a, + const KeyValuePair& b) const + { + if (a.key == b.key) { return a.value < b.value; } + return a.key < b.key; + } +}; + +template +RAFT_KERNEL merge_subgraphs(IdxT* cluster_data_indices, + size_t graph_degree, + size_t num_cluster_in_batch, + float* global_distances, + float* batch_distances, + IdxT* global_indices, + IdxT* batch_indices) +{ + size_t batch_row = blockIdx.x; + typedef cub::BlockMergeSort, BLOCK_SIZE, ITEMS_PER_THREAD> + BlockMergeSortType; + __shared__ typename cub::BlockMergeSort, BLOCK_SIZE, ITEMS_PER_THREAD>:: + TempStorage tmpSmem; + + extern __shared__ char sharedMem[]; + float* blockKeys = reinterpret_cast(sharedMem); + IdxT* blockValues = reinterpret_cast(&sharedMem[graph_degree * 2 * sizeof(float)]); + int16_t* uniqueMask = + 
reinterpret_cast(&sharedMem[graph_degree * 2 * (sizeof(float) + sizeof(IdxT))]); + + if (batch_row < num_cluster_in_batch) { + // load batch or global depending on threadIdx + size_t global_row = cluster_data_indices[batch_row]; + + KeyValuePair threadKeyValuePair[ITEMS_PER_THREAD]; + + size_t halfway = BLOCK_SIZE / 2; + size_t do_global = threadIdx.x < halfway; + + float* distances; + IdxT* indices; + + if (do_global) { + distances = global_distances; + indices = global_indices; + } else { + distances = batch_distances; + indices = batch_indices; + } + + size_t idxBase = (threadIdx.x * do_global + (threadIdx.x - halfway) * (1lu - do_global)) * + static_cast(ITEMS_PER_THREAD); + size_t arrIdxBase = (global_row * do_global + batch_row * (1lu - do_global)) * graph_degree; + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + i; + if (colId < graph_degree) { + threadKeyValuePair[i].key = distances[arrIdxBase + colId]; + threadKeyValuePair[i].value = indices[arrIdxBase + colId]; + } else { + threadKeyValuePair[i].key = std::numeric_limits::max(); + threadKeyValuePair[i].value = std::numeric_limits::max(); + } + } + + __syncthreads(); + + BlockMergeSortType(tmpSmem).Sort(threadKeyValuePair, CustomKeyComparator{}); + + // load sorted result into shared memory to get unique values + idxBase = threadIdx.x * ITEMS_PER_THREAD; + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + i; + if (colId < 2 * graph_degree) { + blockKeys[colId] = threadKeyValuePair[i].key; + blockValues[colId] = threadKeyValuePair[i].value; + } + } + + __syncthreads(); + + // get unique mask + if (threadIdx.x == 0) { uniqueMask[0] = 1; } + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + i; + if (colId > 0 && colId < 2 * graph_degree) { + uniqueMask[colId] = static_cast(blockValues[colId] != blockValues[colId - 1]); + } + } + + __syncthreads(); + + // prefix sum + if (threadIdx.x == 0) { + for (int i = 1; i < 2 * graph_degree; i++) { + 
uniqueMask[i] += uniqueMask[i - 1]; + } + } + + __syncthreads(); + // load unique values to global memory + if (threadIdx.x == 0) { + global_distances[global_row * graph_degree] = blockKeys[0]; + global_indices[global_row * graph_degree] = blockValues[0]; + } + + for (int i = 0; i < ITEMS_PER_THREAD; i++) { + size_t colId = idxBase + i; + if (colId > 0 && colId < 2 * graph_degree) { + bool is_unique = uniqueMask[colId] != uniqueMask[colId - 1]; + int16_t global_colId = uniqueMask[colId] - 1; + if (is_unique && static_cast(global_colId) < graph_degree) { + global_distances[global_row * graph_degree + global_colId] = blockKeys[colId]; + global_indices[global_row * graph_degree + global_colId] = blockValues[colId]; + } + } + } + } +} + +// +// builds knn graph using NN Descent and merge with global graph +// +template , memory_type::host>> +void build_and_merge(raft::resources const& res, + const index_params& params, + size_t num_data_in_cluster, + size_t graph_degree, + size_t int_graph_node_degree, + T* cluster_data, + IdxT* cluster_data_indices, + int* int_graph, + IdxT* inverted_indices, + IdxT* global_indices_d, + float* global_distances_d, + IdxT* batch_indices_h, + IdxT* batch_indices_d, + float* batch_distances_d, + GNND& nnd) +{ + nnd.build(cluster_data, num_data_in_cluster, int_graph, true, batch_distances_d); + + // remap indices +#pragma omp parallel for + for (size_t i = 0; i < num_data_in_cluster; i++) { + for (size_t j = 0; j < graph_degree; j++) { + size_t local_idx = int_graph[i * int_graph_node_degree + j]; + batch_indices_h[i * graph_degree + j] = inverted_indices[local_idx]; + } + } + + raft::copy(batch_indices_d, + batch_indices_h, + num_data_in_cluster * graph_degree, + raft::resource::get_cuda_stream(res)); + + size_t num_elems = graph_degree * 2; + size_t sharedMemSize = num_elems * (sizeof(float) + sizeof(IdxT) + sizeof(int16_t)); + + if (num_elems <= 128) { + merge_subgraphs + <<>>( + cluster_data_indices, + graph_degree, + 
num_data_in_cluster, + global_distances_d, + batch_distances_d, + global_indices_d, + batch_indices_d); + } else if (num_elems <= 512) { + merge_subgraphs + <<>>( + cluster_data_indices, + graph_degree, + num_data_in_cluster, + global_distances_d, + batch_distances_d, + global_indices_d, + batch_indices_d); + } else if (num_elems <= 1024) { + merge_subgraphs + <<>>( + cluster_data_indices, + graph_degree, + num_data_in_cluster, + global_distances_d, + batch_distances_d, + global_indices_d, + batch_indices_d); + } else if (num_elems <= 2048) { + merge_subgraphs + <<>>( + cluster_data_indices, + graph_degree, + num_data_in_cluster, + global_distances_d, + batch_distances_d, + global_indices_d, + batch_indices_d); + } else { + // this is as far as we can get due to the shared mem usage of cub::BlockMergeSort; + // num_elems is 2 * graph_degree, so a degree of exactly 1024 is still accepted above + RAFT_FAIL("The degree of knn is too large (%lu). It must not be larger than 1024", graph_degree); + } + // surface kernel launch failures (e.g. an invalid launch configuration) before synchronizing + RAFT_CUDA_TRY(cudaPeekAtLastError()); + raft::resource::sync_stream(res); +} + +// +// For each cluster, gather the data samples that belong to that cluster, and +// call build_and_merge +// +template +void cluster_nnd(raft::resources const& res, + const index_params& params, + size_t graph_degree, + size_t extended_graph_degree, + size_t max_cluster_size, + raft::host_matrix_view dataset, + IdxT* offsets, + IdxT* cluster_size, + IdxT* cluster_data_indices, + int* int_graph, + IdxT* inverted_indices, + IdxT* global_indices_h, + float* global_distances_h, + IdxT* batch_indices_h, + IdxT* batch_indices_d, + float* batch_distances_d, + const BuildConfig& build_config) +{ + size_t num_rows = dataset.extent(0); + size_t num_cols = dataset.extent(1); + + GNND nnd(res, build_config); + + auto cluster_data_matrix = + raft::make_host_matrix(max_cluster_size, num_cols); + + for (size_t cluster_id = 0; cluster_id < params.n_clusters; cluster_id++) { + RAFT_LOG_DEBUG( + "# Data on host. 
Running clusters: %lu / %lu", cluster_id + 1, params.n_clusters); + size_t num_data_in_cluster = cluster_size[cluster_id]; + size_t offset = offsets[cluster_id]; + +#pragma omp parallel for + for (size_t i = 0; i < num_data_in_cluster; i++) { + for (size_t j = 0; j < num_cols; j++) { + size_t global_row = (inverted_indices + offset)[i]; + cluster_data_matrix(i, j) = dataset(global_row, j); + } + } + + build_and_merge(res, + params, + num_data_in_cluster, + graph_degree, + extended_graph_degree, + cluster_data_matrix.data_handle(), + cluster_data_indices + offset, + int_graph, + inverted_indices + offset, + global_indices_h, + global_distances_h, + batch_indices_h, + batch_indices_d, + batch_distances_d, + nnd); + nnd.reset(res); + } +} + +template +void cluster_nnd(raft::resources const& res, + const index_params& params, + size_t graph_degree, + size_t extended_graph_degree, + size_t max_cluster_size, + raft::device_matrix_view dataset, + IdxT* offsets, + IdxT* cluster_size, + IdxT* cluster_data_indices, + int* int_graph, + IdxT* inverted_indices, + IdxT* global_indices_h, + float* global_distances_h, + IdxT* batch_indices_h, + IdxT* batch_indices_d, + float* batch_distances_d, + const BuildConfig& build_config) +{ + size_t num_rows = dataset.extent(0); + size_t num_cols = dataset.extent(1); + + GNND nnd(res, build_config); + + auto cluster_data_matrix = + raft::make_device_matrix(res, max_cluster_size, num_cols); + + for (size_t cluster_id = 0; cluster_id < params.n_clusters; cluster_id++) { + RAFT_LOG_DEBUG( + "# Data on device. 
Running clusters: %lu / %lu", cluster_id + 1, params.n_clusters); + size_t num_data_in_cluster = cluster_size[cluster_id]; + size_t offset = offsets[cluster_id]; + + auto cluster_data_view = raft::make_device_matrix_view( + cluster_data_matrix.data_handle(), num_data_in_cluster, num_cols); + auto cluster_data_indices_view = raft::make_device_vector_view( + cluster_data_indices + offset, num_data_in_cluster); + + auto dataset_IdxT = + raft::make_device_matrix_view(dataset.data_handle(), num_rows, num_cols); + raft::matrix::gather(res, dataset_IdxT, cluster_data_indices_view, cluster_data_view); + + build_and_merge(res, + params, + num_data_in_cluster, + graph_degree, + extended_graph_degree, + cluster_data_view.data_handle(), + cluster_data_indices + offset, + int_graph, + inverted_indices + offset, + global_indices_h, + global_distances_h, + batch_indices_h, + batch_indices_d, + batch_distances_d, + nnd); + nnd.reset(res); + } +} + +template , memory_type::host>> +void batch_build(raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset, + index& global_idx) +{ + size_t graph_degree = params.graph_degree; + size_t intermediate_degree = params.intermediate_graph_degree; + + size_t num_rows = static_cast(dataset.extent(0)); + size_t num_cols = static_cast(dataset.extent(1)); + + auto centroids = + raft::make_device_matrix(res, params.n_clusters, num_cols); + get_balanced_kmeans_centroids(res, params.metric, dataset, centroids.view()); + + size_t k = 2; + auto global_nearest_cluster = raft::make_host_matrix(num_rows, k); + get_global_nearest_k(res, + k, + num_rows, + params.n_clusters, + dataset.data_handle(), + global_nearest_cluster.view(), + centroids.view(), + params.metric); + + auto inverted_indices = raft::make_host_vector(num_rows * k); + auto cluster_size = raft::make_host_vector(params.n_clusters); + auto offset = raft::make_host_vector(params.n_clusters); + + size_t max_cluster_size, min_cluster_size; + 
get_inverted_indices(res, + params.n_clusters, + max_cluster_size, + min_cluster_size, + global_nearest_cluster.view(), + inverted_indices.view(), + cluster_size.view(), + offset.view()); + + if (intermediate_degree >= min_cluster_size) { + // log the value the degree is actually clamped to (min_cluster_size - 1), not the row count + RAFT_LOG_WARN( + "Intermediate graph degree cannot be larger than minimum cluster size, reducing it to %lu", + min_cluster_size - 1); + intermediate_degree = min_cluster_size - 1; + } + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; + } + + size_t extended_graph_degree = + align32::roundUp(static_cast(graph_degree * (graph_degree <= 32 ? 1.0 : 1.3))); + size_t extended_intermediate_degree = align32::roundUp( + static_cast(intermediate_degree * (intermediate_degree <= 32 ? 1.0 : 1.3))); + + auto int_graph = raft::make_host_matrix( + max_cluster_size, static_cast(extended_graph_degree)); + + BuildConfig build_config{.max_dataset_size = max_cluster_size, + .dataset_dim = num_cols, + .node_degree = extended_graph_degree, + .internal_node_degree = extended_intermediate_degree, + .max_iterations = params.max_iterations, + .termination_threshold = params.termination_threshold, + .output_graph_degree = graph_degree}; + + auto global_indices_h = raft::make_managed_matrix(res, num_rows, graph_degree); + auto global_distances_h = raft::make_managed_matrix(res, num_rows, graph_degree); + + thrust::fill(thrust::host, + global_indices_h.data_handle(), + global_indices_h.data_handle() + num_rows * graph_degree, + std::numeric_limits::max()); + thrust::fill(thrust::host, + global_distances_h.data_handle(), + global_distances_h.data_handle() + num_rows * graph_degree, + std::numeric_limits::max()); + + auto batch_indices_h = + raft::make_host_matrix(max_cluster_size, graph_degree); + auto batch_indices_d = + raft::make_device_matrix(res, max_cluster_size, 
graph_degree); + auto batch_distances_d = + raft::make_device_matrix(res, max_cluster_size, graph_degree); + + auto cluster_data_indices = raft::make_device_vector(res, num_rows * k); + raft::copy(cluster_data_indices.data_handle(), + inverted_indices.data_handle(), + num_rows * k, + resource::get_cuda_stream(res)); + + cluster_nnd(res, + params, + graph_degree, + extended_graph_degree, + max_cluster_size, + dataset, + offset.data_handle(), + cluster_size.data_handle(), + cluster_data_indices.data_handle(), + int_graph.data_handle(), + inverted_indices.data_handle(), + global_indices_h.data_handle(), + global_distances_h.data_handle(), + batch_indices_h.data_handle(), + batch_indices_d.data_handle(), + batch_distances_d.data_handle(), + build_config); + + raft::copy(global_idx.graph().data_handle(), + global_indices_h.data_handle(), + num_rows * graph_degree, + raft::resource::get_cuda_stream(res)); + if (params.return_distances && global_idx.distances().has_value()) { + raft::copy(global_idx.distances().value().data_handle(), + global_distances_h.data_handle(), + num_rows * graph_degree, + raft::resource::get_cuda_stream(res)); + } +} + +template , memory_type::host>> +index batch_build(raft::resources const& res, + const index_params& params, + mdspan, row_major, Accessor> dataset) +{ + size_t intermediate_degree = params.intermediate_graph_degree; + size_t graph_degree = params.graph_degree; + + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; + } + + index idx{ + res, dataset.extent(0), static_cast(graph_degree), params.return_distances}; + + batch_build(res, params, dataset, idx); + + return idx; +} + +} // namespace cuvs::neighbors::nn_descent::detail::experimental diff --git a/cpp/src/neighbors/nn_descent.cuh b/cpp/src/neighbors/nn_descent.cuh index 
582da72c1..ed91dac91 100644 --- a/cpp/src/neighbors/nn_descent.cuh +++ b/cpp/src/neighbors/nn_descent.cuh @@ -17,9 +17,14 @@ #pragma once #include "detail/nn_descent.cuh" +#include "detail/nn_descent_batch.cuh" + +#include +#include #include #include +#include #include namespace cuvs::neighbors::nn_descent { @@ -61,7 +66,15 @@ auto build(raft::resources const& res, index_params const& params, raft::device_matrix_view dataset) -> index { - return detail::build(res, params, dataset); + if (params.n_clusters > 1) { + if constexpr (std::is_same_v) { + return detail::experimental::batch_build(res, params, dataset); + } else { + RAFT_FAIL("Batched nn-descent is only supported for float precision"); + } + } else { + return detail::build(res, params, dataset); + } } /** @@ -100,7 +113,15 @@ void build(raft::resources const& res, raft::device_matrix_view dataset, index& idx) { - detail::build(res, params, dataset, idx); + if (params.n_clusters > 1) { + if constexpr (std::is_same_v) { + detail::experimental::batch_build(res, params, dataset, idx); + } else { + RAFT_FAIL("Batched nn-descent is only supported for float precision"); + } + } else { + detail::build(res, params, dataset, idx); + } } /** @@ -135,7 +156,15 @@ auto build(raft::resources const& res, index_params const& params, raft::host_matrix_view dataset) -> index { - return detail::build(res, params, dataset); + if (params.n_clusters > 1) { + if constexpr (std::is_same_v) { + return detail::experimental::batch_build(res, params, dataset); + } else { + RAFT_FAIL("Batched nn-descent is only supported for float precision"); + } + } else { + return detail::build(res, params, dataset); + } } /** @@ -174,7 +203,15 @@ void build(raft::resources const& res, raft::host_matrix_view dataset, index& idx) { - detail::build(res, params, dataset, idx); + if (params.n_clusters > 1) { + if constexpr (std::is_same_v) { + detail::experimental::batch_build(res, params, dataset, idx); + } else { + RAFT_FAIL("Batched nn-descent is only 
supported for float precision"); + } + } else { + detail::build(res, params, dataset, idx); + } } /** @} */ // end group nn-descent diff --git a/cpp/src/neighbors/nn_descent_float.cu b/cpp/src/neighbors/nn_descent_float.cu index c6d356671..be3182673 100644 --- a/cpp/src/neighbors/nn_descent_float.cu +++ b/cpp/src/neighbors/nn_descent_float.cu @@ -34,6 +34,21 @@ namespace cuvs::neighbors::nn_descent { ->cuvs::neighbors::nn_descent::index \ { \ return cuvs::neighbors::nn_descent::build(handle, params, dataset); \ + }; \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::nn_descent::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::nn_descent::index& idx) \ + { \ + cuvs::neighbors::nn_descent::build(handle, params, dataset, idx); \ + }; \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::nn_descent::index_params& params, \ + raft::host_matrix_view dataset, \ + cuvs::neighbors::nn_descent::index& idx) \ + { \ + cuvs::neighbors::nn_descent::build(handle, params, dataset, idx); \ }; CUVS_INST_NN_DESCENT_BUILD(float, uint32_t); diff --git a/cpp/src/neighbors/nn_descent_half.cu b/cpp/src/neighbors/nn_descent_half.cu index 587993031..f11542db5 100644 --- a/cpp/src/neighbors/nn_descent_half.cu +++ b/cpp/src/neighbors/nn_descent_half.cu @@ -34,6 +34,21 @@ namespace cuvs::neighbors::nn_descent { ->cuvs::neighbors::nn_descent::index \ { \ return cuvs::neighbors::nn_descent::build(handle, params, dataset); \ + }; \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::nn_descent::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::nn_descent::index& idx) \ + { \ + cuvs::neighbors::nn_descent::build(handle, params, dataset, idx); \ + }; \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::nn_descent::index_params& params, \ + raft::host_matrix_view dataset, \ + cuvs::neighbors::nn_descent::index& idx) \ + { \ + 
cuvs::neighbors::nn_descent::build(handle, params, dataset, idx); \ }; CUVS_INST_NN_DESCENT_BUILD(half, uint32_t); diff --git a/cpp/src/neighbors/nn_descent_int8.cu b/cpp/src/neighbors/nn_descent_int8.cu index 813a01746..a52e068d6 100644 --- a/cpp/src/neighbors/nn_descent_int8.cu +++ b/cpp/src/neighbors/nn_descent_int8.cu @@ -34,6 +34,21 @@ namespace cuvs::neighbors::nn_descent { ->cuvs::neighbors::nn_descent::index \ { \ return cuvs::neighbors::nn_descent::build(handle, params, dataset); \ + }; \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::nn_descent::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::nn_descent::index& idx) \ + { \ + cuvs::neighbors::nn_descent::build(handle, params, dataset, idx); \ + }; \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::nn_descent::index_params& params, \ + raft::host_matrix_view dataset, \ + cuvs::neighbors::nn_descent::index& idx) \ + { \ + cuvs::neighbors::nn_descent::build(handle, params, dataset, idx); \ }; CUVS_INST_NN_DESCENT_BUILD(int8_t, uint32_t); diff --git a/cpp/src/neighbors/nn_descent_uint8.cu b/cpp/src/neighbors/nn_descent_uint8.cu index 9d73dd90f..8fb38a870 100644 --- a/cpp/src/neighbors/nn_descent_uint8.cu +++ b/cpp/src/neighbors/nn_descent_uint8.cu @@ -34,6 +34,21 @@ namespace cuvs::neighbors::nn_descent { ->cuvs::neighbors::nn_descent::index \ { \ return cuvs::neighbors::nn_descent::build(handle, params, dataset); \ + }; \ + \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::nn_descent::index_params& params, \ + raft::device_matrix_view dataset, \ + cuvs::neighbors::nn_descent::index& idx) \ + { \ + cuvs::neighbors::nn_descent::build(handle, params, dataset, idx); \ + }; \ + void build(raft::resources const& handle, \ + const cuvs::neighbors::nn_descent::index_params& params, \ + raft::host_matrix_view dataset, \ + cuvs::neighbors::nn_descent::index& idx) \ + { \ + 
cuvs::neighbors::nn_descent::build(handle, params, dataset, idx); \ }; CUVS_INST_NN_DESCENT_BUILD(uint8_t, uint32_t); diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh index bce0f9899..7d2575c2b 100644 --- a/cpp/test/neighbors/ann_nn_descent.cuh +++ b/cpp/test/neighbors/ann_nn_descent.cuh @@ -18,9 +18,13 @@ #include "../test_utils.cuh" #include "ann_utils.cuh" +#include #include + #include +#include #include +#include #include "naive_knn.cuh" @@ -42,6 +46,15 @@ struct AnnNNDescentInputs { double min_recall; }; +struct AnnNNDescentBatchInputs { + std::pair recall_cluster; + int n_rows; + int dim; + int graph_degree; + cuvs::distance::DistanceType metric; + bool host_dataset; +}; + inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentInputs& p) { os << "dataset shape=" << p.n_rows << "x" << p.dim << ", graph_degree=" << p.graph_degree @@ -50,6 +63,14 @@ inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentInputs& return os; } +inline ::std::ostream& operator<<(::std::ostream& os, const AnnNNDescentBatchInputs& p) +{ + os << "dataset shape=" << p.n_rows << "x" << p.dim << ", graph_degree=" << p.graph_degree + << ", metric=" << static_cast(p.metric) << (p.host_dataset ? 
", host" : ", device") + << ", clusters=" << p.recall_cluster.second << std::endl; + return os; +} + template class AnnNNDescentTest : public ::testing::TestWithParam { public: @@ -65,7 +86,9 @@ class AnnNNDescentTest : public ::testing::TestWithParam { { size_t queries_size = ps.n_rows * ps.graph_degree; std::vector indices_NNDescent(queries_size); + std::vector distances_NNDescent(queries_size); std::vector indices_naive(queries_size); + std::vector distances_naive(queries_size); { rmm::device_uvector distances_naive_dev(queries_size, stream_); @@ -81,16 +104,18 @@ class AnnNNDescentTest : public ::testing::TestWithParam { ps.graph_degree, ps.metric); raft::update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); raft::resource::sync_stream(handle_); } { { - cuvs::neighbors::nn_descent::index_params index_params; + nn_descent::index_params index_params; index_params.metric = ps.metric; index_params.graph_degree = ps.graph_degree; index_params.intermediate_graph_degree = 2 * ps.graph_degree; index_params.max_iterations = 100; + index_params.return_distances = true; auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim); @@ -101,22 +126,40 @@ class AnnNNDescentTest : public ::testing::TestWithParam { raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); auto database_host_view = raft::make_host_matrix_view( (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); - auto index = - cuvs::neighbors::nn_descent::build(handle_, index_params, database_host_view); - raft::update_host( + auto index = nn_descent::build(handle_, index_params, database_host_view); + raft::copy( indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + if (index.distances().has_value()) { + raft::copy(distances_NNDescent.data(), + 
index.distances().value().data_handle(), + queries_size, + stream_); + } + } else { - auto index = cuvs::neighbors::nn_descent::build(handle_, index_params, database_view); - raft::update_host( + auto index = nn_descent::build(handle_, index_params, database_view); + raft::copy( indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + if (index.distances().has_value()) { + raft::copy(distances_NNDescent.data(), + index.distances().value().data_handle(), + queries_size, + stream_); + } }; } raft::resource::sync_stream(handle_); } double min_recall = ps.min_recall; - EXPECT_TRUE(eval_recall( - indices_naive, indices_NNDescent, ps.n_rows, ps.graph_degree, 0.001, min_recall)); + EXPECT_TRUE(eval_neighbours(indices_naive, + indices_NNDescent, + distances_naive, + distances_NNDescent, + ps.n_rows, + ps.graph_degree, + 0.001, + min_recall)); } } @@ -146,6 +189,125 @@ class AnnNNDescentTest : public ::testing::TestWithParam { rmm::device_uvector database; }; +template +class AnnNNDescentBatchTest : public ::testing::TestWithParam { + public: + AnnNNDescentBatchTest() + : stream_(raft::resource::get_cuda_stream(handle_)), + ps(::testing::TestWithParam::GetParam()), + database(0, stream_) + { + } + + void testNNDescentBatch() + { + size_t queries_size = ps.n_rows * ps.graph_degree; + std::vector indices_NNDescent(queries_size); + std::vector distances_NNDescent(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + naive_knn(handle_, + distances_naive_dev.data(), + indices_naive_dev.data(), + database.data(), + database.data(), + ps.n_rows, + ps.n_rows, + ps.dim, + ps.graph_degree, + ps.metric); + raft::update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + raft::update_host(distances_naive.data(), distances_naive_dev.data(), 
queries_size, stream_); + raft::resource::sync_stream(handle_); + } + + { + { + nn_descent::index_params index_params; + index_params.metric = ps.metric; + index_params.graph_degree = ps.graph_degree; + index_params.intermediate_graph_degree = 2 * ps.graph_degree; + index_params.max_iterations = 10; + index_params.return_distances = true; + index_params.n_clusters = ps.recall_cluster.second; + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.n_rows, ps.dim); + + { + if (ps.host_dataset) { + auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + auto database_host_view = raft::make_host_matrix_view( + (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + auto index = nn_descent::build(handle_, index_params, database_host_view); + raft::copy( + indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + if (index.distances().has_value()) { + raft::copy(distances_NNDescent.data(), + index.distances().value().data_handle(), + queries_size, + stream_); + } + + } else { + auto index = nn_descent::build(handle_, index_params, database_view); + raft::copy( + indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); + if (index.distances().has_value()) { + raft::copy(distances_NNDescent.data(), + index.distances().value().data_handle(), + queries_size, + stream_); + } + }; + } + raft::resource::sync_stream(handle_); + } + double min_recall = ps.recall_cluster.first; + EXPECT_TRUE(eval_neighbours(indices_naive, + indices_NNDescent, + distances_naive, + distances_NNDescent, + ps.n_rows, + ps.graph_degree, + 0.01, + min_recall, + true, + static_cast(ps.graph_degree * 0.1))); + } + } + + void SetUp() override + { + database.resize(((size_t)ps.n_rows) * ps.dim, stream_); + raft::random::RngState r(1234ULL); + if constexpr (std::is_same{}) { + raft::random::normal(handle_, r, database.data(), 
ps.n_rows * ps.dim, DataT(0.1), DataT(2.0)); + } else { + raft::random::uniformInt( + handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); + } + raft::resource::sync_stream(handle_); + } + + void TearDown() override + { + raft::resource::sync_stream(handle_); + database.resize(0, stream_); + } + + private: + raft::resources handle_; + rmm::cuda_stream_view stream_; + AnnNNDescentBatchInputs ps; + rmm::device_uvector database; +}; + const std::vector inputs = raft::util::itertools::product( {1000, 2000}, // n_rows {3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim @@ -154,4 +316,15 @@ const std::vector inputs = raft::util::itertools::product inputsBatch = + raft::util::itertools::product( + {std::make_pair(0.9, 3lu), std::make_pair(0.9, 2lu)}, // min_recall, n_clusters + {4000, 5000}, // n_rows + {192, 512}, // dim + {32, 64}, // graph_degree + {cuvs::distance::DistanceType::L2Expanded}, + {false, true}); + +} // namespace cuvs::neighbors::nn_descent diff --git a/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu b/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu index 64c0e0291..ab59b06a3 100644 --- a/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_nn_descent/test_float_uint32_t.cu @@ -23,6 +23,12 @@ namespace cuvs::neighbors::nn_descent { typedef AnnNNDescentTest AnnNNDescentTestF_U32; TEST_P(AnnNNDescentTestF_U32, AnnNNDescent) { this->testNNDescent(); } +typedef AnnNNDescentBatchTest AnnNNDescentBatchTestF_U32; +TEST_P(AnnNNDescentBatchTestF_U32, AnnNNDescentBatch) { this->testNNDescentBatch(); } + INSTANTIATE_TEST_CASE_P(AnnNNDescentTest, AnnNNDescentTestF_U32, ::testing::ValuesIn(inputs)); +INSTANTIATE_TEST_CASE_P(AnnNNDescentBatchTest, + AnnNNDescentBatchTestF_U32, + ::testing::ValuesIn(inputsBatch)); } // namespace cuvs::neighbors::nn_descent diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index b08e1d725..94bccade2 100644 --- 
a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include // raft::make_device_matrix #include @@ -165,9 +166,14 @@ auto calc_recall(const std::vector& expected_idx, /** check uniqueness of indices */ template -auto check_unique_indices(const std::vector& actual_idx, size_t rows, size_t cols) +auto check_unique_indices(const std::vector& actual_idx, + size_t rows, + size_t cols, + size_t max_duplicates = 0) { size_t max_count; + size_t dup_count = 0lu; + std::set unique_indices; for (size_t i = 0; i < rows; ++i) { unique_indices.clear(); @@ -180,8 +186,11 @@ auto check_unique_indices(const std::vector& actual_idx, size_t rows, size_t } else if (unique_indices.find(act_idx) == unique_indices.end()) { unique_indices.insert(act_idx); } else { - return testing::AssertionFailure() - << "Duplicated index " << act_idx << " at k " << k << " for query " << i << "! "; + dup_count++; + if (dup_count > max_duplicates) { + return testing::AssertionFailure() + << "Duplicated index " << act_idx << " at k " << k << " for query " << i << "! "; + } } } } @@ -264,7 +273,8 @@ auto eval_neighbours(const std::vector& expected_idx, size_t cols, double eps, double min_recall, - bool test_unique = true) -> testing::AssertionResult + bool test_unique = true, + size_t max_duplicates = 0) -> testing::AssertionResult { auto [actual_recall, match_count, total_count] = calc_recall(expected_idx, actual_idx, expected_dist, actual_dist, rows, cols, eps); @@ -284,7 +294,7 @@ auto eval_neighbours(const std::vector& expected_idx, << min_recall << "); eps = " << eps << ". "; } if (test_unique) - return check_unique_indices(actual_idx, rows, cols); + return check_unique_indices(actual_idx, rows, cols, max_duplicates); else return testing::AssertionSuccess(); }