Skip to content

Commit

Permalink
async generate diskann cache after diskann deserialize
Browse files Browse the repository at this point in the history
Signed-off-by: [email protected] <cqy123456>
  • Loading branch information
[email protected] committed Nov 27, 2023
1 parent aeb6931 commit 0fa0627
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 71 deletions.
37 changes: 21 additions & 16 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,9 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
std::string warmup_query_file = diskann::get_sample_data_filename(index_prefix_);
// load cache
auto cached_nodes_file = diskann::get_cached_nodes_file(index_prefix_);
std::vector<uint32_t> node_list;
if (file_exists(cached_nodes_file)) {
// get cached nodes id from file.
std::vector<uint32_t> node_list;
LOG_KNOWHERE_INFO_ << "Reading cached nodes from file.";
size_t num_nodes, nodes_id_dim;
uint32_t* cached_nodes_ids = nullptr;
Expand All @@ -421,7 +422,14 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
if (cached_nodes_ids != nullptr) {
delete[] cached_nodes_ids;
}
if (node_list.size() > 0) {
if (TryDiskANNCall([&]() { pq_flash_index_->load_cache_list(node_list); }) != Status::success) {
LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN.";
return Status::diskann_inner_error;
}
}
} else {
// TODO: generate cache in Deserialize function will be remove later.
auto num_nodes_to_cache = GetCachedNodeNum(prep_conf.search_cache_budget_gb.value(),
pq_flash_index_->get_data_dim(), pq_flash_index_->get_max_degree());
if (num_nodes_to_cache > pq_flash_index_->get_num_points() / 3) {
Expand All @@ -433,32 +441,29 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
LOG_KNOWHERE_INFO_ << "Caching " << num_nodes_to_cache << " sample nodes around medoid(s).";
if (prep_conf.use_bfs_cache.value()) {
LOG_KNOWHERE_INFO_ << "Use bfs to generate cache list";
std::vector<uint32_t> node_list;
if (TryDiskANNCall([&]() { pq_flash_index_->cache_bfs_levels(num_nodes_to_cache, node_list); }) !=
Status::success) {
LOG_KNOWHERE_ERROR_ << "Failed to generate bfs cache for DiskANN.";
return Status::diskann_inner_error;
}
} else {
LOG_KNOWHERE_INFO_ << "Use sample_queries to generate cache list";
if (TryDiskANNCall([&]() {
pq_flash_index_->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6,
num_nodes_to_cache, node_list);
}) != Status::success) {
LOG_KNOWHERE_ERROR_ << "Failed to generate cache from sample queries for DiskANN.";
return Status::diskann_inner_error;
if (node_list.size() > 0) {
if (TryDiskANNCall([&]() { pq_flash_index_->load_cache_list(node_list); }) != Status::success) {
LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN.";
return Status::diskann_inner_error;
}
}
} else {
LOG_KNOWHERE_INFO_ << "Use sample_queries and input queries to generate cache asynchronously.";
pq_flash_index_->set_async_cache_flag(true);
search_pool_->push([&, sample_file = warmup_query_file, cache_num = num_nodes_to_cache]() {
pq_flash_index_->async_generate_cache_list_from_sample_queries(sample_file, 15, 6, cache_num);
});
}
}
LOG_KNOWHERE_INFO_ << "End of preparing diskann index.";
}

if (node_list.size() > 0) {
if (TryDiskANNCall([&]() { pq_flash_index_->load_cache_list(node_list); }) != Status::success) {
LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN.";
return Status::diskann_inner_error;
}
}

// warmup
if (prep_conf.warm_up.value()) {
LOG_KNOWHERE_INFO_ << "Warming up.";
Expand Down
15 changes: 11 additions & 4 deletions thirdparty/DiskANN/include/diskann/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "percentile_stats.h"
#include "pq_table.h"
#include "utils.h"
#include "semaphore.h"
#include "diskann/distance.h"
#include "knowhere/comp/thread_pool.h"

Expand Down Expand Up @@ -73,9 +74,12 @@ namespace diskann {

void load_cache_list(std::vector<uint32_t> &node_list);

void generate_cache_list_from_sample_queries(
// set async cache flag before calling async_generate_cache_list_from_sample_queries,
void set_async_cache_flag(const bool flag);
// asynchronously collect the access frequency of each node in the graph
void async_generate_cache_list_from_sample_queries(
std::string sample_bin, _u64 l_search, _u64 beamwidth,
_u64 num_nodes_to_cache, std::vector<uint32_t> &node_list);
_u64 num_nodes_to_cache);

void cache_bfs_levels(_u64 num_nodes_to_cache,
std::vector<uint32_t> &node_list);
Expand Down Expand Up @@ -192,7 +196,10 @@ namespace diskann {

std::string disk_index_file;
std::vector<std::pair<_u32, _u32>> node_visit_counter;

std::atomic<_u32> search_counter = 0;
// used for async generate cache.
Semaphore semaph;
std::atomic<bool> async_generate_cache = false;
// PQ data
// n_chunks = # of chunks ndims is split into
// data: _u8 * n_chunks
Expand Down Expand Up @@ -253,7 +260,7 @@ namespace diskann {
ConcurrentQueue<ThreadData<T>> thread_data;
_u64 max_nthreads;
bool load_flag = false;
bool count_visited_nodes = false;
std::atomic<bool> count_visited_nodes = false;
bool reorder_data_exists = false;
_u64 reoreder_data_offset = 0;

Expand Down
38 changes: 38 additions & 0 deletions thirdparty/DiskANN/include/diskann/semaphore.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#pragma once
#include <mutex>
#include <condition_variable>

namespace diskann {
class Semaphore {
public:
Semaphore(long count = 0) : count(count) {}
void Signal()
{
std::unique_lock<std::mutex> unique(mt);
++count;
if (count <= 0) {
cond.notify_one();
}
}
void Wait()
{
std::unique_lock<std::mutex> unique(mt);
--count;
if (count < 0) {
cond.wait(unique);
}
}
bool IsWaitting() {
std::unique_lock<std::mutex> unique(mt);
return count < 0;
}

private:
std::mutex mt;
std::condition_variable cond;
long count;
};
} // namespace diskann
126 changes: 75 additions & 51 deletions thirdparty/DiskANN/src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ namespace diskann {
template<typename T>
PQFlashIndex<T>::PQFlashIndex(std::shared_ptr<AlignedFileReader> fileReader,
diskann::Metric m)
: reader(fileReader), metric(m) {
: reader(fileReader), metric(m), semaph(0) {
if (m == diskann::Metric::INNER_PRODUCT || m == diskann::Metric::COSINE) {
if (!std::is_floating_point<T>::value) {
LOG(WARNING) << "Cannot normalize integral data types."
Expand All @@ -80,6 +80,9 @@ namespace diskann {

template<typename T>
PQFlashIndex<T>::~PQFlashIndex() {
if (this->async_generate_cache) {
this->semaph.Wait();
}
if (data != nullptr) {
delete[] data;
}
Expand Down Expand Up @@ -184,6 +187,8 @@ namespace diskann {

template<typename T>
void PQFlashIndex<T>::load_cache_list(std::vector<uint32_t> &node_list) {
assert(this->nhood_cache_buf == nullptr && "nhoodc_cache_buf is not null");
assert(this->coord_cache_buf == nullptr && "coord_cache_buf is not null");
_u64 num_cached_nodes = node_list.size();
LOG_KNOWHERE_DEBUG_ << "Loading the cache list(" << num_cached_nodes
<< " points) into memory...";
Expand Down Expand Up @@ -247,62 +252,78 @@ namespace diskann {
}

template<typename T>
void PQFlashIndex<T>::generate_cache_list_from_sample_queries(
std::string sample_bin, _u64 l_search, _u64 beamwidth,
_u64 num_nodes_to_cache, std::vector<uint32_t> &node_list) {
this->count_visited_nodes = true;
this->node_visit_counter.clear();
this->node_visit_counter.resize(this->num_points);
for (_u32 i = 0; i < node_visit_counter.size(); i++) {
this->node_visit_counter[i].first = i;
this->node_visit_counter[i].second = 0;
}
void PQFlashIndex<T>::set_async_cache_flag(const bool flag) {
this->async_generate_cache.exchange(flag);
}

_u64 sample_num, sample_dim, sample_aligned_dim;
T *samples;
template<typename T>
void PQFlashIndex<T>::async_generate_cache_list_from_sample_queries(
std::string sample_bin, _u64 l_search, _u64 beamwidth,
_u64 num_nodes_to_cache) {
try {
auto s = std::chrono::high_resolution_clock::now();
this->search_counter.store(0);
this->node_visit_counter.clear();
this->node_visit_counter.resize(this->num_points);
for (_u32 i = 0; i < node_visit_counter.size(); i++) {
this->node_visit_counter[i].first = i;
this->node_visit_counter[i].second = 0;
}
this->count_visited_nodes.store(true);

if (file_exists(sample_bin)) {
diskann::load_aligned_bin<T>(sample_bin, samples, sample_num, sample_dim,
sample_aligned_dim);
}
else {
diskann::cerr << "Sample bin file not found. Not generating cache."
<< std::endl;
return;
}
_u64 sample_num, sample_dim, sample_aligned_dim;
T *samples;

std::vector<int64_t> tmp_result_ids_64(sample_num, 0);
std::vector<float> tmp_result_dists(sample_num, 0);

auto thread_pool = knowhere::ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(sample_num);
for (_s64 i = 0; i < (int64_t) sample_num; i++) {
futures.emplace_back(thread_pool->push([&, index = i]() {
cached_beam_search(samples + (index * sample_aligned_dim), 1, l_search,
tmp_result_ids_64.data() + (index * 1),
tmp_result_dists.data() + (index * 1), beamwidth);
}));
}
if (file_exists(sample_bin)) {
diskann::load_aligned_bin<T>(sample_bin, samples, sample_num, sample_dim,
sample_aligned_dim);
} else {
std::stringstream stream;
stream << "Sample bin file not found. Not generating cache."
<< std::endl;
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
__LINE__);
}

for (auto &future : futures) {
future.wait();
}
int64_t tmp_result_ids_64 = 0;
float tmp_result_dists = 0.0;

std::sort(this->node_visit_counter.begin(), node_visit_counter.end(),
[](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) {
return left.second > right.second;
});
node_list.clear();
node_list.shrink_to_fit();
node_list.reserve(num_nodes_to_cache);
for (_u64 i = 0; i < num_nodes_to_cache; i++) {
node_list.push_back(this->node_visit_counter[i].first);
_u64 id = 0;
while (this->search_counter.load() < sample_num && id < sample_num &&
!this->semaph.IsWaitting()) {
cached_beam_search(samples + (id * sample_aligned_dim), 1, l_search,
&tmp_result_ids_64, &tmp_result_dists, beamwidth);
id++;
}
if (this->semaph.IsWaitting()) {
diskann::aligned_free(samples);
this->semaph.Signal();
return;
}
this->count_visited_nodes.store(false);
std::sort(this->node_visit_counter.begin(), node_visit_counter.end(),
[](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) {
return left.second > right.second;
});
std::vector<uint32_t> node_list;
node_list.reserve(num_nodes_to_cache);
for (_u64 i = 0; i < num_nodes_to_cache; i++) {
node_list.push_back(this->node_visit_counter[i].first);
}
std::vector<std::pair<_u32, _u32>>().swap(this->node_visit_counter);

diskann::aligned_free(samples);
this->load_cache_list(node_list);
this->semaph.Signal();
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
LOG(INFO) << "Using sample queries to generate cache, cost: " << diff.count() << "s";
return;
} catch (const std::exception& e) {
LOG(ERROR)<< "DiskANN Other Exception: " << e.what();
this->semaph.Signal();
return;
}
this->count_visited_nodes = false;
std::vector<std::pair<_u32, _u32>>().swap(this->node_visit_counter);

diskann::aligned_free(samples);
}

template<typename T>
Expand Down Expand Up @@ -861,6 +882,9 @@ namespace diskann {
if (stats != nullptr) {
stats->total_us = (double) query_timer.elapsed();
}
if (this->count_visited_nodes) {
this->search_counter.fetch_add(1);
}
return;
}

Expand Down

0 comments on commit 0fa0627

Please sign in to comment.