Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate diskann cache asynchronously. #191

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions knowhere/index/vector_index/IndexDiskANN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ IndexDiskANN<T>::Prepare(const Config& config) {
KNOWHERE_THROW_MSG("Failed to generate cache, num_nodes_to_cache is larger than 1/3 of the total data number.");
}
if (num_nodes_to_cache > 0) {
std::vector<uint32_t> node_list;
LOG_KNOWHERE_INFO_ << "Caching " << num_nodes_to_cache << " sample nodes around medoid(s).";
if (prep_conf.use_bfs_cache) {
std::vector<uint32_t> node_list;
auto gen_cache_successful = TryDiskANNCall<bool>([&]() -> bool {
pq_flash_index_->cache_bfs_levels(num_nodes_to_cache, node_list);
return true;
Expand All @@ -302,29 +302,28 @@ IndexDiskANN<T>::Prepare(const Config& config) {
LOG_KNOWHERE_ERROR_ << "Failed to generate bfs cache for DiskANN.";
return false;
}
} else {
auto gen_cache_successful = TryDiskANNCall<bool>([&]() -> bool {
pq_flash_index_->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache,
prep_conf.num_threads, node_list);
auto load_cache_successful = TryDiskANNCall<bool>([&]() -> bool {
pq_flash_index_->load_cache_list(node_list);
return true;
});

if (!gen_cache_successful.has_value()) {
LOG_KNOWHERE_ERROR_ << "Failed to generate cache from sample queries for DiskANN.";
if (!load_cache_successful.has_value()) {
LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN.";
return false;
}
}
auto load_cache_successful = TryDiskANNCall<bool>([&]() -> bool {
pq_flash_index_->load_cache_list(node_list);
return true;
});

if (!load_cache_successful.has_value()) {
LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN.";
return false;
} else {
pq_flash_index_->set_async_cache_flag(true);
pool_->push([&, cache_num = num_nodes_to_cache,
sample_nodes_file = warmup_query_file]() {
try {
pq_flash_index_->generate_cache_list_from_sample_queries(
sample_nodes_file, 15, 6, cache_num);
} catch (const std::exception& e) {
LOG_KNOWHERE_ERROR_ << "DiskANN Exception: " << e.what();
}
});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: incorrect indentation

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

}
}

// warmup
if (prep_conf.warm_up) {
LOG_KNOWHERE_DEBUG_ << "Warming up.";
Expand Down
17 changes: 11 additions & 6 deletions thirdparty/DiskANN/include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "percentile_stats.h"
#include "pq_table.h"
#include "utils.h"
#include "semaphore.h"
#include "windows_customizations.h"

#define MAX_GRAPH_DEGREE 512
Expand Down Expand Up @@ -80,16 +81,15 @@ namespace diskann {

DISKANN_DLLEXPORT void load_cache_list(std::vector<uint32_t> &node_list);

// asynchronously collect the access frequency of each node in the graph
#ifdef EXEC_ENV_OLS
DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(
MemoryMappedFiles &files, std::string sample_bin, _u64 l_search,
_u64 beamwidth, _u64 num_nodes_to_cache, uint32_t nthreads,
std::vector<uint32_t> &node_list);
MemoryMappedFiles files, std::string sample_bin, _u64 l_search,
_u64 beamwidth, _u64 num_nodes_to_cache);
#else
DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(
std::string sample_bin, _u64 l_search, _u64 beamwidth,
_u64 num_nodes_to_cache, uint32_t num_threads,
std::vector<uint32_t> &node_list);
_u64 num_nodes_to_cache);
#endif

DISKANN_DLLEXPORT void cache_bfs_levels(_u64 num_nodes_to_cache,
Expand Down Expand Up @@ -128,6 +128,8 @@ namespace diskann {

DISKANN_DLLEXPORT diskann::Metric get_metric() const noexcept;

DISKANN_DLLEXPORT void set_async_cache_flag(const bool flag);

protected:
DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
DISKANN_DLLEXPORT void setup_thread_data(_u64 nthreads);
Expand Down Expand Up @@ -195,6 +197,7 @@ namespace diskann {

std::string disk_index_file;
std::vector<std::pair<_u32, _u32>> node_visit_counter;
std::atomic<_u32> search_counter = 0;

// PQ data
// n_chunks = # of chunks ndims is split into
Expand Down Expand Up @@ -233,12 +236,14 @@ namespace diskann {
// coord_cache
T * coord_cache_buf = nullptr;
tsl::robin_map<_u32, T *> coord_cache;
Semaphore semaph;
std::atomic<bool> async_generate_cache = false;

// thread-specific scratch
ConcurrentQueue<ThreadData<T>> thread_data;
_u64 max_nthreads;
bool load_flag = false;
bool count_visited_nodes = false;
std::atomic<bool> count_visited_nodes = false;
bool reorder_data_exists = false;
_u64 reoreder_data_offset = 0;

Expand Down
35 changes: 35 additions & 0 deletions thirdparty/DiskANN/include/semaphore.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once
#include <mutex>
#include <condition_variable>

namespace diskann {
// Counting semaphore built on a mutex and a condition variable.
//
// `count` follows the classic convention:
//   count > 0  -> number of Wait() calls that can return without blocking
//   count < 0  -> -(number of Wait() calls not yet matched by a Signal())
//
// Review note carried over from the pull request: prefer the naming used by
// std::counting_semaphore (acquire/release) so that a later transition to
// C++20 is easier.
class Semaphore {
 public:
  // Creates a semaphore holding `count` initial permits (default: none).
  Semaphore(long count = 0) : count(count) {}

  // Releases one permit. If a caller is blocked in Wait(), grants it a wakeup
  // and notifies one waiting thread.
  void Signal()
  {
    std::lock_guard<std::mutex> guard(mt);
    ++count;
    if (count <= 0) {
      // Record the grant explicitly: the waiter checks this counter, so a
      // wakeup is only honoured when a Signal() actually happened.
      ++granted_wakeups;
      cond.notify_one();
    }
  }

  // Takes one permit, blocking until a Signal() provides one if none is
  // available.
  void Wait()
  {
    std::unique_lock<std::mutex> guard(mt);
    --count;
    if (count < 0) {
      // condition_variable::wait may return spuriously; waiting on a
      // predicate keeps the caller blocked until a real grant exists.
      cond.wait(guard, [this] { return granted_wakeups > 0; });
      --granted_wakeups;
    }
  }

  // Returns true while at least one Wait() call has not been matched by a
  // Signal() (i.e. count is negative).
  bool IsWaitting() {
    std::lock_guard<std::mutex> guard(mt);
    return count < 0;
  }

 private:
  std::mutex mt;
  std::condition_variable cond;
  long count;
  // Number of wakeups issued by Signal() and not yet consumed by a waiter.
  long granted_wakeups = 0;
};
} // namespace diskann
59 changes: 44 additions & 15 deletions thirdparty/DiskANN/src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <iterator>
#include <random>
#include <thread>
#include <mutex>
#include "distance.h"
#include "exceptions.h"
#include "parameters.h"
Expand Down Expand Up @@ -98,7 +99,7 @@ namespace diskann {
template<typename T>
PQFlashIndex<T>::PQFlashIndex(
std::shared_ptr<AlignedFileReader> fileReader, diskann::Metric m)
: reader(fileReader), metric(m) {
: reader(fileReader), metric(m), semaph(0) {
if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) {
if (std::is_floating_point<T>::value) {
LOG(INFO) << "Cosine metric chosen for (normalized) float data."
Expand All @@ -119,6 +120,9 @@ namespace diskann {

template<typename T>
PQFlashIndex<T>::~PQFlashIndex() {
if (this->async_generate_cache) {
this->semaph.Wait();
}
#ifndef EXEC_ENV_OLS
if (data != nullptr) {
delete[] data;
Expand Down Expand Up @@ -216,6 +220,8 @@ namespace diskann {
template<typename T>
void PQFlashIndex<T>::load_cache_list(std::vector<uint32_t> &node_list) {
LOG(DEBUG) << "Loading the cache list into memory...";
assert(this->nhood_cache_buf == nullptr && "nhoodc_cache_buf is not null");
assert(this->coord_cache_buf == nullptr && "coord_cache_buf is not null");
_u64 num_cached_nodes = node_list.size();

// borrow thread data
Expand Down Expand Up @@ -293,20 +299,20 @@ namespace diskann {
#ifdef EXEC_ENV_OLS
template<typename T>
void PQFlashIndex<T>::generate_cache_list_from_sample_queries(
MemoryMappedFiles &files, std::string sample_bin, _u64 l_search,
_u64 beamwidth, _u64 num_nodes_to_cache, uint32_t nthreads,
std::vector<uint32_t> &node_list) {
MemoryMappedFiles files, std::string sample_bin, _u64 l_search,
_u64 beamwidth, _u64 num_nodes_to_cache) {
#else
template<typename T>
void PQFlashIndex<T>::generate_cache_list_from_sample_queries(
std::string sample_bin, _u64 l_search, _u64 beamwidth,
_u64 num_nodes_to_cache, uint32_t nthreads,
std::vector<uint32_t> &node_list) {
_u64 num_nodes_to_cache) {
#endif
auto s = std::chrono::high_resolution_clock::now();
this->count_visited_nodes = true;
this->search_counter.store(0);
this->node_visit_counter.clear();
this->node_visit_counter.resize(this->num_points);
this->count_visited_nodes.store(true);

for (_u32 i = 0; i < node_visit_counter.size(); i++) {
this->node_visit_counter[i].first = i;
this->node_visit_counter[i].second = 0;
Expand All @@ -332,32 +338,47 @@ namespace diskann {
return;
}

std::vector<int64_t> tmp_result_ids_64(sample_num, 0);
std::vector<float> tmp_result_dists(sample_num, 0);
int64_t tmp_result_ids_64;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any race condition risks?

float tmp_result_dists;

#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads)
for (_s64 i = 0; i < (int64_t) sample_num; i++) {
cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search,
tmp_result_ids_64.data() + (i * 1),
tmp_result_dists.data() + (i * 1), beamwidth);
auto id = 0;
while (this->search_counter.load() < sample_num && id < sample_num &&
!this->semaph.IsWaitting()) {
cached_beam_search(samples + (id * sample_aligned_dim), 1, l_search,
&tmp_result_ids_64, &tmp_result_dists, beamwidth);
id++;
}

if (this->semaph.IsWaitting()) {
this->semaph.Signal();
return;
}

this->count_visited_nodes.store(false);
std::sort(this->node_visit_counter.begin(), node_visit_counter.end(),
[](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) {
return left.second > right.second;
});

std::vector<uint32_t> node_list;
node_list.clear();
node_list.shrink_to_fit();
node_list.reserve(num_nodes_to_cache);
for (_u64 i = 0; i < num_nodes_to_cache; i++) {
node_list.push_back(this->node_visit_counter[i].first);
}
this->count_visited_nodes = false;
this->node_visit_counter.clear();
liliu-z marked this conversation as resolved.
Show resolved Hide resolved
this->node_visit_counter.shrink_to_fit();
this->search_counter.store(0);

diskann::aligned_free(samples);
this->load_cache_list(node_list);
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
LOG(INFO) << "Using sample queries to generate cache, cost: " << diff.count() << "s";

this->semaph.Signal();
return;
}

template<typename T>
Expand Down Expand Up @@ -1464,6 +1485,9 @@ namespace diskann {
if (stats != nullptr) {
stats->total_us = (double) query_timer.elapsed();
}
if (this->count_visited_nodes) {
this->search_counter.fetch_add(1);
}
}

// range search returns results of all neighbors within distance of range.
Expand Down Expand Up @@ -1548,6 +1572,11 @@ namespace diskann {
diskann::Metric PQFlashIndex<T>::get_metric() const noexcept {
return metric;
}

// Atomically replaces the `async_generate_cache` flag with `flag`.
// The previous value returned by the exchange is intentionally discarded.
template<typename T>
void PQFlashIndex<T>::set_async_cache_flag(const bool flag) {
  static_cast<void>(async_generate_cache.exchange(flag));
}

#ifdef EXEC_ENV_OLS
template<typename T>
Expand Down
Loading