Skip to content

Commit

Permalink
Add async thread pool for generating diskann cache
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Nov 28, 2023
1 parent d63c403 commit 1c4a180
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 71 deletions.
15 changes: 15 additions & 0 deletions knowhere/common/ThreadPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,19 @@ ThreadPool::GetGlobalThreadPool() {
static auto pool = std::make_shared<ThreadPool>(global_thread_pool_size_);
return pool;
}

/**
 * @brief Get the process-wide thread pool used for asynchronous background tasks.
 *
 * The pool is created on the first call and the same instance is returned by
 * every later call. Its size is half of the global thread pool size, rounded
 * up, and never less than one thread. If the global thread pool size has not
 * been set yet (it is 0), it is first initialised from
 * std::thread::hardware_concurrency().
 *
 * @return Shared pointer to the singleton async ThreadPool.
 */
std::shared_ptr<ThreadPool>
ThreadPool::GetGlobalAsyncThreadPool() {
    // Everything needed to build the pool lives in the initializer of the
    // function-local static, so the sizing logic and the "init" log line run
    // exactly once instead of on every call.
    static auto async_pool = [] {
        if (global_thread_pool_size_ == 0) {
            std::lock_guard<std::mutex> lock(global_thread_pool_mutex_);
            if (global_thread_pool_size_ == 0) {
                global_thread_pool_size_ = std::thread::hardware_concurrency();
            }
        }
        auto async_thread_pool_size = static_cast<uint32_t>(std::ceil(global_thread_pool_size_ / 2.0));
        if (async_thread_pool_size == 0) {
            // hardware_concurrency() is allowed to return 0 when the value is
            // not computable; a pool without workers would never run a task.
            async_thread_pool_size = 1;
        }
        LOG_KNOWHERE_WARNING_ << "async thread pool size init with thread number:"
                              << async_thread_pool_size;
        return std::make_shared<ThreadPool>(async_thread_pool_size);
    }();
    return async_pool;
}
} // namespace knowhere
9 changes: 9 additions & 0 deletions knowhere/common/ThreadPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#pragma once

#include <math.h>
#include <memory>
#include <utility>

Expand Down Expand Up @@ -58,6 +59,14 @@ class ThreadPool {
static std::shared_ptr<ThreadPool>
GetGlobalThreadPool();

/**
* @brief Get the global async thread pool of knowhere.
*
* The pool is a single shared instance that is created on first use; its
* thread count is derived from about half of the global thread pool size.
*
* @return std::shared_ptr<ThreadPool> shared handle to the global async pool.
*/
static std::shared_ptr<ThreadPool>
GetGlobalAsyncThreadPool();

class ScopedOmpSetter {
int omp_before;
public:
Expand Down
14 changes: 6 additions & 8 deletions knowhere/index/vector_index/IndexDiskANN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,15 +312,13 @@ IndexDiskANN<T>::Prepare(const Config& config) {
return false;
}
} else {
pq_flash_index_->set_async_cache_flag(true);
pool_->push([&, cache_num = num_nodes_to_cache,
auto aysnc_pool_ = ThreadPool::GetGlobalAsyncThreadPool();

pq_flash_index_->setup_cache_sync_task();
aysnc_pool_->push([&, cache_num = num_nodes_to_cache,
sample_nodes_file = warmup_query_file]() {
try {
pq_flash_index_->generate_cache_list_from_sample_queries(
sample_nodes_file, 15, 6, cache_num);
} catch (const std::exception& e) {
LOG_KNOWHERE_ERROR_ << "DiskANN Exception: " << e.what();
}
pq_flash_index_->generate_cache_list_from_sample_queries(
sample_nodes_file, 15, 6, cache_num);
});
}
}
Expand Down
2 changes: 1 addition & 1 deletion thirdparty/DiskANN/include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ namespace diskann {

DISKANN_DLLEXPORT diskann::Metric get_metric() const noexcept;

DISKANN_DLLEXPORT void set_async_cache_flag(const bool flag);
DISKANN_DLLEXPORT void setup_cache_sync_task();

protected:
DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
Expand Down
134 changes: 72 additions & 62 deletions thirdparty/DiskANN/src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,77 +307,78 @@ namespace diskann {
std::string sample_bin, _u64 l_search, _u64 beamwidth,
_u64 num_nodes_to_cache) {
#endif
auto s = std::chrono::high_resolution_clock::now();
this->search_counter.store(0);
this->node_visit_counter.clear();
this->node_visit_counter.resize(this->num_points);
this->count_visited_nodes.store(true);

for (_u32 i = 0; i < node_visit_counter.size(); i++) {
this->node_visit_counter[i].first = i;
this->node_visit_counter[i].second = 0;
}

_u64 sample_num, sample_dim, sample_aligned_dim;
T * samples;
try {
auto s = std::chrono::high_resolution_clock::now();

_u64 sample_num, sample_dim, sample_aligned_dim;
std::stringstream stream;

#ifdef EXEC_ENV_OLS
if (files.fileExists(sample_bin)) {
diskann::load_aligned_bin<T>(files, sample_bin, samples, sample_num,
sample_dim, sample_aligned_dim);
}
if (files.fileExists(sample_bin)) {
diskann::load_aligned_bin<T>(files, sample_bin, samples, sample_num,
sample_dim, sample_aligned_dim);
}
#else
if (file_exists(sample_bin)) {
diskann::load_aligned_bin<T>(sample_bin, samples, sample_num, sample_dim,
sample_aligned_dim);
}
if (file_exists(sample_bin)) {
diskann::load_aligned_bin<T>(sample_bin, samples, sample_num, sample_dim,
sample_aligned_dim);
}
#endif
else {
diskann::cerr << "Sample bin file not found. Not generating cache."
<< std::endl;
return;
}

int64_t tmp_result_ids_64;
float tmp_result_dists;
else {
stream << "Sample bin file not found. Not generating cache."
<< std::endl;
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
__LINE__);
}

auto id = 0;
while (this->search_counter.load() < sample_num && id < sample_num &&
!this->semaph.IsWaitting()) {
cached_beam_search(samples + (id * sample_aligned_dim), 1, l_search,
&tmp_result_ids_64, &tmp_result_dists, beamwidth);
id++;
}
int64_t tmp_result_ids_64;
float tmp_result_dists;

if (this->semaph.IsWaitting()) {
this->semaph.Signal();
return;
}
auto id = 0;
while (this->search_counter.load() < sample_num && id < sample_num &&
!this->semaph.IsWaitting()) {
cached_beam_search(samples + (id * sample_aligned_dim), 1, l_search,
&tmp_result_ids_64, &tmp_result_dists, beamwidth);
id++;
}

this->count_visited_nodes.store(false);
std::sort(this->node_visit_counter.begin(), node_visit_counter.end(),
[](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) {
return left.second > right.second;
});

std::vector<uint32_t> node_list;
node_list.clear();
node_list.shrink_to_fit();
node_list.reserve(num_nodes_to_cache);
for (_u64 i = 0; i < num_nodes_to_cache; i++) {
node_list.push_back(this->node_visit_counter[i].first);
}
this->node_visit_counter.clear();
this->node_visit_counter.shrink_to_fit();
this->search_counter.store(0);
if (this->semaph.IsWaitting()) {
stream << "pq_flash_index is destoried, async thread should be exit."
<< std::endl;
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
__LINE__);
}

diskann::aligned_free(samples);
this->load_cache_list(node_list);
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
LOG(INFO) << "Using sample queries to generate cache, cost: " << diff.count() << "s";
this->count_visited_nodes.store(false);
std::sort(this->node_visit_counter.begin(), node_visit_counter.end(),
[](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) {
return left.second > right.second;
});

std::vector<uint32_t> node_list;
node_list.clear();
node_list.shrink_to_fit();
node_list.reserve(num_nodes_to_cache);
for (_u64 i = 0; i < num_nodes_to_cache; i++) {
node_list.push_back(this->node_visit_counter[i].first);
}
this->node_visit_counter.clear();
this->node_visit_counter.shrink_to_fit();
this->search_counter.store(0);

this->load_cache_list(node_list);
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
LOG(INFO) << "Using sample queries to generate cache, cost: " << diff.count() << "s";
} catch (const std::exception& e) {
LOG(ERROR) << "DiskANN Exception: " << e.what();
}
this->semaph.Signal();
// free samples
if (samples != nullptr) {
diskann::aligned_free(samples);
}
return;
}

Expand Down Expand Up @@ -1574,8 +1575,17 @@ namespace diskann {
}

template<typename T>
void PQFlashIndex<T>::set_async_cache_flag(const bool flag) {
this->async_generate_cache.exchange(flag);
void PQFlashIndex<T>::setup_cache_sync_task() {
// Resets the visit-counting state that cache generation from sample queries
// reads: sets the async_generate_cache flag, zeroes the search counter and
// rebuilds node_visit_counter with one entry per point.
// NOTE(review): appears intended to be called before the cache-generation
// task is handed to a thread pool - confirm against the caller.
this->async_generate_cache.exchange(true);
this->search_counter.store(0);
this->node_visit_counter.clear();
this->node_visit_counter.resize(this->num_points);
// Turn on visit counting (the flag is stored as true here).
this->count_visited_nodes.store(true);

// Each entry is (node id, visit count): first holds the entry's own index,
// second starts at 0.
for (_u32 i = 0; i < node_visit_counter.size(); i++) {
this->node_visit_counter[i].first = i;
this->node_visit_counter[i].second = 0;
}
}

#ifdef EXEC_ENV_OLS
Expand Down

0 comments on commit 1c4a180

Please sign in to comment.