-
Notifications
You must be signed in to change notification settings - Fork 83
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generate diskann cache asynchronously. #191
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#pragma once | ||
#include <mutex> | ||
#include <condition_variable> | ||
|
||
namespace diskann { | ||
class Semaphore { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd use the syntax from https://en.cppreference.com/w/cpp/thread/counting_semaphore, so that we could transition the code to C++20 more easily |
||
public: | ||
Semaphore(long count = 0) : count(count) {} | ||
void Signal() | ||
{ | ||
std::unique_lock<std::mutex> unique(mt); | ||
++count; | ||
if (count <= 0) { | ||
cond.notify_one(); | ||
} | ||
} | ||
void Wait() | ||
{ | ||
std::unique_lock<std::mutex> unique(mt); | ||
--count; | ||
if (count < 0) { | ||
cond.wait(unique); | ||
} | ||
} | ||
bool IsWaitting() { | ||
std::unique_lock<std::mutex> unique(mt); | ||
return count < 0; | ||
} | ||
|
||
private: | ||
std::mutex mt; | ||
std::condition_variable cond; | ||
long count; | ||
}; | ||
} // namespace diskann |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
#include <iterator> | ||
#include <random> | ||
#include <thread> | ||
#include <mutex> | ||
#include "distance.h" | ||
#include "exceptions.h" | ||
#include "parameters.h" | ||
|
@@ -98,7 +99,7 @@ namespace diskann { | |
template<typename T> | ||
PQFlashIndex<T>::PQFlashIndex( | ||
std::shared_ptr<AlignedFileReader> fileReader, diskann::Metric m) | ||
: reader(fileReader), metric(m) { | ||
: reader(fileReader), metric(m), semaph(0) { | ||
if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) { | ||
if (std::is_floating_point<T>::value) { | ||
LOG(INFO) << "Cosine metric chosen for (normalized) float data." | ||
|
@@ -119,6 +120,9 @@ namespace diskann { | |
|
||
template<typename T> | ||
PQFlashIndex<T>::~PQFlashIndex() { | ||
if (this->async_generate_cache) { | ||
this->semaph.Wait(); | ||
} | ||
#ifndef EXEC_ENV_OLS | ||
if (data != nullptr) { | ||
delete[] data; | ||
|
@@ -216,6 +220,8 @@ namespace diskann { | |
template<typename T> | ||
void PQFlashIndex<T>::load_cache_list(std::vector<uint32_t> &node_list) { | ||
LOG(DEBUG) << "Loading the cache list into memory..."; | ||
assert(this->nhood_cache_buf == nullptr && "nhoodc_cache_buf is not null"); | ||
assert(this->coord_cache_buf == nullptr && "coord_cache_buf is not null"); | ||
_u64 num_cached_nodes = node_list.size(); | ||
|
||
// borrow thread data | ||
|
@@ -293,20 +299,20 @@ namespace diskann { | |
#ifdef EXEC_ENV_OLS | ||
template<typename T> | ||
void PQFlashIndex<T>::generate_cache_list_from_sample_queries( | ||
MemoryMappedFiles &files, std::string sample_bin, _u64 l_search, | ||
_u64 beamwidth, _u64 num_nodes_to_cache, uint32_t nthreads, | ||
std::vector<uint32_t> &node_list) { | ||
MemoryMappedFiles files, std::string sample_bin, _u64 l_search, | ||
_u64 beamwidth, _u64 num_nodes_to_cache) { | ||
#else | ||
template<typename T> | ||
void PQFlashIndex<T>::generate_cache_list_from_sample_queries( | ||
std::string sample_bin, _u64 l_search, _u64 beamwidth, | ||
_u64 num_nodes_to_cache, uint32_t nthreads, | ||
std::vector<uint32_t> &node_list) { | ||
_u64 num_nodes_to_cache) { | ||
#endif | ||
auto s = std::chrono::high_resolution_clock::now(); | ||
this->count_visited_nodes = true; | ||
this->search_counter.store(0); | ||
this->node_visit_counter.clear(); | ||
this->node_visit_counter.resize(this->num_points); | ||
this->count_visited_nodes.store(true); | ||
|
||
for (_u32 i = 0; i < node_visit_counter.size(); i++) { | ||
this->node_visit_counter[i].first = i; | ||
this->node_visit_counter[i].second = 0; | ||
|
@@ -332,32 +338,47 @@ namespace diskann { | |
return; | ||
} | ||
|
||
std::vector<int64_t> tmp_result_ids_64(sample_num, 0); | ||
std::vector<float> tmp_result_dists(sample_num, 0); | ||
int64_t tmp_result_ids_64; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any race condition risks? |
||
float tmp_result_dists; | ||
|
||
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthreads) | ||
for (_s64 i = 0; i < (int64_t) sample_num; i++) { | ||
cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, | ||
tmp_result_ids_64.data() + (i * 1), | ||
tmp_result_dists.data() + (i * 1), beamwidth); | ||
auto id = 0; | ||
while (this->search_counter.load() < sample_num && id < sample_num && | ||
!this->semaph.IsWaitting()) { | ||
cached_beam_search(samples + (id * sample_aligned_dim), 1, l_search, | ||
&tmp_result_ids_64, &tmp_result_dists, beamwidth); | ||
id++; | ||
} | ||
|
||
if (this->semaph.IsWaitting()) { | ||
this->semaph.Signal(); | ||
return; | ||
} | ||
|
||
this->count_visited_nodes.store(false); | ||
std::sort(this->node_visit_counter.begin(), node_visit_counter.end(), | ||
[](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) { | ||
return left.second > right.second; | ||
}); | ||
|
||
std::vector<uint32_t> node_list; | ||
node_list.clear(); | ||
node_list.shrink_to_fit(); | ||
node_list.reserve(num_nodes_to_cache); | ||
for (_u64 i = 0; i < num_nodes_to_cache; i++) { | ||
node_list.push_back(this->node_visit_counter[i].first); | ||
} | ||
this->count_visited_nodes = false; | ||
this->node_visit_counter.clear(); | ||
liliu-z marked this conversation as resolved.
Show resolved
Hide resolved
|
||
this->node_visit_counter.shrink_to_fit(); | ||
this->search_counter.store(0); | ||
|
||
diskann::aligned_free(samples); | ||
this->load_cache_list(node_list); | ||
auto e = std::chrono::high_resolution_clock::now(); | ||
std::chrono::duration<double> diff = e - s; | ||
LOG(INFO) << "Using sample queries to generate cache, cost: " << diff.count() << "s"; | ||
|
||
this->semaph.Signal(); | ||
return; | ||
} | ||
|
||
template<typename T> | ||
|
@@ -1464,6 +1485,9 @@ namespace diskann { | |
if (stats != nullptr) { | ||
stats->total_us = (double) query_timer.elapsed(); | ||
} | ||
if (this->count_visited_nodes) { | ||
this->search_counter.fetch_add(1); | ||
} | ||
} | ||
|
||
// range search returns results of all neighbors within distance of range. | ||
|
@@ -1548,6 +1572,11 @@ namespace diskann { | |
diskann::Metric PQFlashIndex<T>::get_metric() const noexcept { | ||
return metric; | ||
} | ||
|
||
template<typename T> | ||
void PQFlashIndex<T>::set_async_cache_flag(const bool flag) { | ||
this->async_generate_cache.exchange(flag); | ||
} | ||
|
||
#ifdef EXEC_ENV_OLS | ||
template<typename T> | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: incorrect indentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated