From 0fa062767cd2f21ea9d605a4e18a48c3c2913813 Mon Sep 17 00:00:00 2001 From: "qianya.cheng@zilliz.com" Date: Mon, 27 Nov 2023 08:35:06 -0500 Subject: [PATCH] async generate diskann cache after diskann deserialize Signed-off-by: qianya.cheng@zilliz.com --- src/index/diskann/diskann.cc | 37 ++--- .../DiskANN/include/diskann/pq_flash_index.h | 15 ++- .../DiskANN/include/diskann/semaphore.h | 38 ++++++ thirdparty/DiskANN/src/pq_flash_index.cpp | 126 +++++++++++------- 4 files changed, 145 insertions(+), 71 deletions(-) create mode 100644 thirdparty/DiskANN/include/diskann/semaphore.h diff --git a/src/index/diskann/diskann.cc b/src/index/diskann/diskann.cc index c1c9f11cc..3f4ee5854 100644 --- a/src/index/diskann/diskann.cc +++ b/src/index/diskann/diskann.cc @@ -411,8 +411,9 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { std::string warmup_query_file = diskann::get_sample_data_filename(index_prefix_); // load cache auto cached_nodes_file = diskann::get_cached_nodes_file(index_prefix_); - std::vector node_list; if (file_exists(cached_nodes_file)) { + // get cached nodes id from file. + std::vector node_list; LOG_KNOWHERE_INFO_ << "Reading cached nodes from file."; size_t num_nodes, nodes_id_dim; uint32_t* cached_nodes_ids = nullptr; @@ -421,7 +422,14 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { if (cached_nodes_ids != nullptr) { delete[] cached_nodes_ids; } + if (node_list.size() > 0) { + if (TryDiskANNCall([&]() { pq_flash_index_->load_cache_list(node_list); }) != Status::success) { + LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN."; + return Status::diskann_inner_error; + } + } } else { + // TODO: cache generation in the Deserialize function will be removed later. 
auto num_nodes_to_cache = GetCachedNodeNum(prep_conf.search_cache_budget_gb.value(), pq_flash_index_->get_data_dim(), pq_flash_index_->get_max_degree()); if (num_nodes_to_cache > pq_flash_index_->get_num_points() / 3) { @@ -433,32 +441,29 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { LOG_KNOWHERE_INFO_ << "Caching " << num_nodes_to_cache << " sample nodes around medoid(s)."; if (prep_conf.use_bfs_cache.value()) { LOG_KNOWHERE_INFO_ << "Use bfs to generate cache list"; + std::vector node_list; if (TryDiskANNCall([&]() { pq_flash_index_->cache_bfs_levels(num_nodes_to_cache, node_list); }) != Status::success) { LOG_KNOWHERE_ERROR_ << "Failed to generate bfs cache for DiskANN."; return Status::diskann_inner_error; } - } else { - LOG_KNOWHERE_INFO_ << "Use sample_queries to generate cache list"; - if (TryDiskANNCall([&]() { - pq_flash_index_->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, - num_nodes_to_cache, node_list); - }) != Status::success) { - LOG_KNOWHERE_ERROR_ << "Failed to generate cache from sample queries for DiskANN."; - return Status::diskann_inner_error; + if (node_list.size() > 0) { + if (TryDiskANNCall([&]() { pq_flash_index_->load_cache_list(node_list); }) != Status::success) { + LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN."; + return Status::diskann_inner_error; + } } + } else { + LOG_KNOWHERE_INFO_ << "Use sample_queries and input queries to generate cache asynchronously."; + pq_flash_index_->set_async_cache_flag(true); + search_pool_->push([&, sample_file = warmup_query_file, cache_num = num_nodes_to_cache]() { + pq_flash_index_->async_generate_cache_list_from_sample_queries(sample_file, 15, 6, cache_num); + }); } } LOG_KNOWHERE_INFO_ << "End of preparing diskann index."; } - if (node_list.size() > 0) { - if (TryDiskANNCall([&]() { pq_flash_index_->load_cache_list(node_list); }) != Status::success) { - LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN."; - return 
Status::diskann_inner_error; - } - } - // warmup if (prep_conf.warm_up.value()) { LOG_KNOWHERE_INFO_ << "Warming up."; diff --git a/thirdparty/DiskANN/include/diskann/pq_flash_index.h b/thirdparty/DiskANN/include/diskann/pq_flash_index.h index 3082442f5..9b4f12dc5 100644 --- a/thirdparty/DiskANN/include/diskann/pq_flash_index.h +++ b/thirdparty/DiskANN/include/diskann/pq_flash_index.h @@ -22,6 +22,7 @@ #include "percentile_stats.h" #include "pq_table.h" #include "utils.h" +#include "semaphore.h" #include "diskann/distance.h" #include "knowhere/comp/thread_pool.h" @@ -73,9 +74,12 @@ namespace diskann { void load_cache_list(std::vector &node_list); - void generate_cache_list_from_sample_queries( + // set async cache flag before calling async_generate_cache_list_from_sample_queries. + void set_async_cache_flag(const bool flag); + // asynchronously collect the access frequency of each node in the graph + void async_generate_cache_list_from_sample_queries( std::string sample_bin, _u64 l_search, _u64 beamwidth, - _u64 num_nodes_to_cache, std::vector &node_list); + _u64 num_nodes_to_cache); void cache_bfs_levels(_u64 num_nodes_to_cache, std::vector &node_list); @@ -192,7 +196,10 @@ namespace diskann { std::string disk_index_file; std::vector> node_visit_counter; - + std::atomic<_u32> search_counter = 0; + // used for async generate cache. 
+ Semaphore semaph; + std::atomic async_generate_cache = false; // PQ data // n_chunks = # of chunks ndims is split into // data: _u8 * n_chunks @@ -253,7 +260,7 @@ namespace diskann { ConcurrentQueue> thread_data; _u64 max_nthreads; bool load_flag = false; - bool count_visited_nodes = false; + std::atomic count_visited_nodes = false; bool reorder_data_exists = false; _u64 reoreder_data_offset = 0; diff --git a/thirdparty/DiskANN/include/diskann/semaphore.h b/thirdparty/DiskANN/include/diskann/semaphore.h new file mode 100644 index 000000000..0a434d110 --- /dev/null +++ b/thirdparty/DiskANN/include/diskann/semaphore.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#pragma once +#include +#include + +namespace diskann { +class Semaphore { +public: + Semaphore(long count = 0) : count(count) {} + void Signal() + { + std::unique_lock unique(mt); + ++count; + if (count <= 0) { + cond.notify_one(); + } + } + void Wait() + { + std::unique_lock unique(mt); + --count; + if (count < 0) { + cond.wait(unique); + } + } + bool IsWaitting() { + std::unique_lock unique(mt); + return count < 0; + } + +private: + std::mutex mt; + std::condition_variable cond; + long count; +}; +} // namespace diskann \ No newline at end of file diff --git a/thirdparty/DiskANN/src/pq_flash_index.cpp b/thirdparty/DiskANN/src/pq_flash_index.cpp index 357fc4ce5..de1e64283 100644 --- a/thirdparty/DiskANN/src/pq_flash_index.cpp +++ b/thirdparty/DiskANN/src/pq_flash_index.cpp @@ -60,7 +60,7 @@ namespace diskann { template PQFlashIndex::PQFlashIndex(std::shared_ptr fileReader, diskann::Metric m) - : reader(fileReader), metric(m) { + : reader(fileReader), metric(m), semaph(0) { if (m == diskann::Metric::INNER_PRODUCT || m == diskann::Metric::COSINE) { if (!std::is_floating_point::value) { LOG(WARNING) << "Cannot normalize integral data types." 
@@ -80,6 +80,9 @@ namespace diskann { template PQFlashIndex::~PQFlashIndex() { + if (this->async_generate_cache) { + this->semaph.Wait(); + } if (data != nullptr) { delete[] data; } @@ -184,6 +187,8 @@ namespace diskann { template void PQFlashIndex::load_cache_list(std::vector &node_list) { + assert(this->nhood_cache_buf == nullptr && "nhoodc_cache_buf is not null"); + assert(this->coord_cache_buf == nullptr && "coord_cache_buf is not null"); _u64 num_cached_nodes = node_list.size(); LOG_KNOWHERE_DEBUG_ << "Loading the cache list(" << num_cached_nodes << " points) into memory..."; @@ -247,62 +252,78 @@ namespace diskann { } template - void PQFlashIndex::generate_cache_list_from_sample_queries( - std::string sample_bin, _u64 l_search, _u64 beamwidth, - _u64 num_nodes_to_cache, std::vector &node_list) { - this->count_visited_nodes = true; - this->node_visit_counter.clear(); - this->node_visit_counter.resize(this->num_points); - for (_u32 i = 0; i < node_visit_counter.size(); i++) { - this->node_visit_counter[i].first = i; - this->node_visit_counter[i].second = 0; - } + void PQFlashIndex::set_async_cache_flag(const bool flag) { + this->async_generate_cache.exchange(flag); + } - _u64 sample_num, sample_dim, sample_aligned_dim; - T *samples; + template + void PQFlashIndex::async_generate_cache_list_from_sample_queries( + std::string sample_bin, _u64 l_search, _u64 beamwidth, + _u64 num_nodes_to_cache) { + try { + auto s = std::chrono::high_resolution_clock::now(); + this->search_counter.store(0); + this->node_visit_counter.clear(); + this->node_visit_counter.resize(this->num_points); + for (_u32 i = 0; i < node_visit_counter.size(); i++) { + this->node_visit_counter[i].first = i; + this->node_visit_counter[i].second = 0; + } + this->count_visited_nodes.store(true); - if (file_exists(sample_bin)) { - diskann::load_aligned_bin(sample_bin, samples, sample_num, sample_dim, - sample_aligned_dim); - } - else { - diskann::cerr << "Sample bin file not found. 
Not generating cache." - << std::endl; - return; - } + _u64 sample_num, sample_dim, sample_aligned_dim; + T *samples; - std::vector tmp_result_ids_64(sample_num, 0); - std::vector tmp_result_dists(sample_num, 0); - - auto thread_pool = knowhere::ThreadPool::GetGlobalSearchThreadPool(); - std::vector> futures; - futures.reserve(sample_num); - for (_s64 i = 0; i < (int64_t) sample_num; i++) { - futures.emplace_back(thread_pool->push([&, index = i]() { - cached_beam_search(samples + (index * sample_aligned_dim), 1, l_search, - tmp_result_ids_64.data() + (index * 1), - tmp_result_dists.data() + (index * 1), beamwidth); - })); - } + if (file_exists(sample_bin)) { + diskann::load_aligned_bin(sample_bin, samples, sample_num, sample_dim, + sample_aligned_dim); + } else { + std::stringstream stream; + stream << "Sample bin file not found. Not generating cache." + << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + } - for (auto &future : futures) { - future.wait(); - } + int64_t tmp_result_ids_64 = 0; + float tmp_result_dists = 0.0; - std::sort(this->node_visit_counter.begin(), node_visit_counter.end(), - [](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) { - return left.second > right.second; - }); - node_list.clear(); - node_list.shrink_to_fit(); - node_list.reserve(num_nodes_to_cache); - for (_u64 i = 0; i < num_nodes_to_cache; i++) { - node_list.push_back(this->node_visit_counter[i].first); + _u64 id = 0; + while (this->search_counter.load() < sample_num && id < sample_num && + !this->semaph.IsWaitting()) { + cached_beam_search(samples + (id * sample_aligned_dim), 1, l_search, + &tmp_result_ids_64, &tmp_result_dists, beamwidth); + id++; + } + if (this->semaph.IsWaitting()) { + diskann::aligned_free(samples); + this->semaph.Signal(); + return; + } + this->count_visited_nodes.store(false); + std::sort(this->node_visit_counter.begin(), node_visit_counter.end(), + [](std::pair<_u32, _u32> &left, std::pair<_u32, 
_u32> &right) { + return left.second > right.second; + }); + std::vector node_list; + node_list.reserve(num_nodes_to_cache); + for (_u64 i = 0; i < num_nodes_to_cache; i++) { + node_list.push_back(this->node_visit_counter[i].first); + } + std::vector>().swap(this->node_visit_counter); + + diskann::aligned_free(samples); + this->load_cache_list(node_list); + this->semaph.Signal(); + auto e = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = e - s; + LOG(INFO) << "Using sample queries to generate cache, cost: " << diff.count() << "s"; + return; + } catch (const std::exception& e) { + LOG(ERROR)<< "DiskANN Other Exception: " << e.what(); + this->semaph.Signal(); + return; } - this->count_visited_nodes = false; - std::vector>().swap(this->node_visit_counter); - - diskann::aligned_free(samples); } template @@ -861,6 +882,9 @@ namespace diskann { if (stats != nullptr) { stats->total_us = (double) query_timer.elapsed(); } + if (this->count_visited_nodes) { + this->search_counter.fetch_add(1); + } return; }