From fcff38163182882eb5110c0705b628abdec82886 Mon Sep 17 00:00:00 2001 From: cqy123456 Date: Wed, 22 Nov 2023 04:04:40 -0500 Subject: [PATCH] async generate diskann cache Signed-off-by: cqy123456 --- knowhere/index/vector_index/IndexDiskANN.cpp | 33 ++++++----- thirdparty/DiskANN/include/pq_flash_index.h | 17 ++++-- thirdparty/DiskANN/include/semaphore.h | 35 ++++++++++++ thirdparty/DiskANN/src/pq_flash_index.cpp | 59 +++++++++++++++----- 4 files changed, 106 insertions(+), 38 deletions(-) create mode 100644 thirdparty/DiskANN/include/semaphore.h diff --git a/knowhere/index/vector_index/IndexDiskANN.cpp b/knowhere/index/vector_index/IndexDiskANN.cpp index 654a0214d..fb19dcb5f 100644 --- a/knowhere/index/vector_index/IndexDiskANN.cpp +++ b/knowhere/index/vector_index/IndexDiskANN.cpp @@ -290,9 +290,9 @@ IndexDiskANN::Prepare(const Config& config) { KNOWHERE_THROW_MSG("Failed to generate cache, num_nodes_to_cache is larger than 1/3 of the total data number."); } if (num_nodes_to_cache > 0) { - std::vector node_list; LOG_KNOWHERE_INFO_ << "Caching " << num_nodes_to_cache << " sample nodes around medoid(s)."; if (prep_conf.use_bfs_cache) { + std::vector node_list; auto gen_cache_successful = TryDiskANNCall([&]() -> bool { pq_flash_index_->cache_bfs_levels(num_nodes_to_cache, node_list); return true; @@ -302,29 +302,28 @@ IndexDiskANN::Prepare(const Config& config) { LOG_KNOWHERE_ERROR_ << "Failed to generate bfs cache for DiskANN."; return false; } - } else { - auto gen_cache_successful = TryDiskANNCall([&]() -> bool { - pq_flash_index_->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache, - prep_conf.num_threads, node_list); + auto load_cache_successful = TryDiskANNCall([&]() -> bool { + pq_flash_index_->load_cache_list(node_list); return true; }); - if (!gen_cache_successful.has_value()) { - LOG_KNOWHERE_ERROR_ << "Failed to generate cache from sample queries for DiskANN."; + if (!load_cache_successful.has_value()) { + LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN."; return false; } - } - auto load_cache_successful = TryDiskANNCall([&]() -> bool { - pq_flash_index_->load_cache_list(node_list); - return true; - }); - - if (!load_cache_successful.has_value()) { - LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN."; - return false; + } else { + pq_flash_index_->set_async_cache_flag(true); + pool_->push([&, cache_num = num_nodes_to_cache, + sample_nodes_file = warmup_query_file]() { + try { + pq_flash_index_->generate_cache_list_from_sample_queries( + sample_nodes_file, 15, 6, cache_num); + } catch (const std::exception& e) { + LOG_KNOWHERE_ERROR_ << "DiskANN Exception: " << e.what(); + } + }); } } - // warmup if (prep_conf.warm_up) { LOG_KNOWHERE_DEBUG_ << "Warming up."; diff --git a/thirdparty/DiskANN/include/pq_flash_index.h b/thirdparty/DiskANN/include/pq_flash_index.h index a31c19f26..7278fcc58 100644 --- a/thirdparty/DiskANN/include/pq_flash_index.h +++ b/thirdparty/DiskANN/include/pq_flash_index.h @@ -22,6 +22,7 @@ #include "percentile_stats.h" #include "pq_table.h" #include "utils.h" +#include "semaphore.h" #include "windows_customizations.h" #define MAX_GRAPH_DEGREE 512 @@ -80,16 +81,15 @@ namespace diskann { DISKANN_DLLEXPORT void load_cache_list(std::vector &node_list); + // asynchronously collect the access frequency of each node in the graph #ifdef EXEC_ENV_OLS DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries( - MemoryMappedFiles &files, std::string sample_bin, _u64 l_search, - _u64 beamwidth, _u64 num_nodes_to_cache, uint32_t nthreads, - std::vector &node_list); + MemoryMappedFiles files, std::string sample_bin, _u64 l_search, + _u64 beamwidth, _u64 num_nodes_to_cache); #else DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries( std::string sample_bin, _u64 l_search, _u64 beamwidth, - _u64 num_nodes_to_cache, uint32_t num_threads, - std::vector &node_list); + _u64 num_nodes_to_cache); #endif DISKANN_DLLEXPORT void cache_bfs_levels(_u64 num_nodes_to_cache, @@ -128,6 +128,8 @@ namespace diskann { DISKANN_DLLEXPORT diskann::Metric get_metric() const noexcept; + DISKANN_DLLEXPORT void set_async_cache_flag(const bool flag); + protected: DISKANN_DLLEXPORT void use_medoids_data_as_centroids(); DISKANN_DLLEXPORT void setup_thread_data(_u64 nthreads); @@ -195,6 +197,7 @@ namespace diskann { std::string disk_index_file; std::vector> node_visit_counter; + std::atomic<_u32> search_counter = 0; // PQ data // n_chunks = # of chunks ndims is split into @@ -233,12 +236,14 @@ namespace diskann { // coord_cache T * coord_cache_buf = nullptr; tsl::robin_map<_u32, T *> coord_cache; + Semaphore semaph; + std::atomic async_generate_cache = false; // thread-specific scratch ConcurrentQueue> thread_data; _u64 max_nthreads; bool load_flag = false; - bool count_visited_nodes = false; + std::atomic count_visited_nodes = false; bool reorder_data_exists = false; _u64 reoreder_data_offset = 0; diff --git a/thirdparty/DiskANN/include/semaphore.h b/thirdparty/DiskANN/include/semaphore.h new file mode 100644 index 000000000..a86258be3 --- /dev/null +++ b/thirdparty/DiskANN/include/semaphore.h @@ -0,0 +1,35 @@ +#pragma once +#include +#include + +namespace diskann { +class Semaphore { +public: + Semaphore(long count = 0) : count(count) {} + void Signal() + { + std::unique_lock unique(mt); + ++count; + if (count <= 0) { + cond.notify_one(); + } + } + void Wait() + { + std::unique_lock unique(mt); + --count; + if (count < 0) { + cond.wait(unique); + } + } + bool IsWaitting() { + std::unique_lock unique(mt); + return count < 0; + } + +private: + std::mutex mt; + std::condition_variable cond; + long count; +}; +} // namespace diskann \ No newline at end of file diff --git a/thirdparty/DiskANN/src/pq_flash_index.cpp b/thirdparty/DiskANN/src/pq_flash_index.cpp index ee0c9c670..213fc303c 100644 --- a/thirdparty/DiskANN/src/pq_flash_index.cpp +++ b/thirdparty/DiskANN/src/pq_flash_index.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include "distance.h" #include "exceptions.h" #include "parameters.h" @@ -98,7 +99,7 @@ namespace diskann { template PQFlashIndex::PQFlashIndex( std::shared_ptr fileReader, diskann::Metric m) - : reader(fileReader), metric(m) { + : reader(fileReader), metric(m), semaph(0) { if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) { if (std::is_floating_point::value) { LOG(INFO) << "Cosine metric chosen for (normalized) float data." @@ -119,6 +120,9 @@ namespace diskann { template PQFlashIndex::~PQFlashIndex() { + if (this->async_generate_cache) { + this->semaph.Wait(); + } #ifndef EXEC_ENV_OLS if (data != nullptr) { delete[] data; @@ -216,6 +220,8 @@ namespace diskann { template void PQFlashIndex::load_cache_list(std::vector &node_list) { LOG(DEBUG) << "Loading the cache list into memory..."; + assert(this->nhood_cache_buf == nullptr && "nhoodc_cache_buf is not null"); + assert(this->coord_cache_buf == nullptr && "coord_cache_buf is not null"); _u64 num_cached_nodes = node_list.size(); // borrow thread data @@ -293,20 +299,20 @@ namespace diskann { #ifdef EXEC_ENV_OLS template void PQFlashIndex::generate_cache_list_from_sample_queries( - MemoryMappedFiles &files, std::string sample_bin, _u64 l_search, - _u64 beamwidth, _u64 num_nodes_to_cache, uint32_t nthreads, - std::vector &node_list) { + MemoryMappedFiles files, std::string sample_bin, _u64 l_search, + _u64 beamwidth, _u64 num_nodes_to_cache) { #else template void PQFlashIndex::generate_cache_list_from_sample_queries( std::string sample_bin, _u64 l_search, _u64 beamwidth, - _u64 num_nodes_to_cache, uint32_t nthreads, - std::vector &node_list) { + _u64 num_nodes_to_cache) { #endif auto s = std::chrono::high_resolution_clock::now(); - this->count_visited_nodes = true; + this->search_counter.store(0); this->node_visit_counter.clear(); this->node_visit_counter.resize(this->num_points); + this->count_visited_nodes.store(true); + for (_u32 i = 0; i < node_visit_counter.size(); i++) { this->node_visit_counter[i].first = i; this->node_visit_counter[i].second = 0; @@ -332,32 +338,47 @@ namespace diskann { return; } - std::vector tmp_result_ids_64(sample_num, 0); - std::vector tmp_result_dists(sample_num, 0); + int64_t tmp_result_ids_64; + float tmp_result_dists; -#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) - for (_s64 i = 0; i < (int64_t) sample_num; i++) { - cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, - tmp_result_ids_64.data() + (i * 1), - tmp_result_dists.data() + (i * 1), beamwidth); + auto id = 0; + while (this->search_counter.load() < sample_num && id < sample_num && + !this->semaph.IsWaitting()) { + cached_beam_search(samples + (id * sample_aligned_dim), 1, l_search, + &tmp_result_ids_64, &tmp_result_dists, beamwidth); + id++; } + if (this->semaph.IsWaitting()) { + this->semaph.Signal(); + return; + } + + this->count_visited_nodes.store(false); std::sort(this->node_visit_counter.begin(), node_visit_counter.end(), [](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) { return left.second > right.second; }); + + std::vector node_list; node_list.clear(); node_list.shrink_to_fit(); node_list.reserve(num_nodes_to_cache); for (_u64 i = 0; i < num_nodes_to_cache; i++) { node_list.push_back(this->node_visit_counter[i].first); } - this->count_visited_nodes = false; + this->node_visit_counter.clear(); + this->node_visit_counter.shrink_to_fit(); + this->search_counter.store(0); diskann::aligned_free(samples); + this->load_cache_list(node_list); auto e = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = e - s; LOG(INFO) << "Using sample queries to generate cache, cost: " << diff.count() << "s"; + + this->semaph.Signal(); + return; } template @@ -1464,6 +1485,9 @@ namespace diskann { if (stats != nullptr) { stats->total_us = (double) query_timer.elapsed(); } + if (this->count_visited_nodes) { + this->search_counter.fetch_add(1); + } } // range search returns results of all neighbors within distance of range. @@ -1548,6 +1572,11 @@ namespace diskann { diskann::Metric PQFlashIndex::get_metric() const noexcept { return metric; } + + template + void PQFlashIndex::set_async_cache_flag(const bool flag) { + this->async_generate_cache.exchange(flag); + } #ifdef EXEC_ENV_OLS template