Skip to content

Commit

Permalink
Add async thread pool for generating diskann cache
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Nov 28, 2023
1 parent d63c403 commit 90c9ada
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 76 deletions.
33 changes: 20 additions & 13 deletions knowhere/index/vector_index/IndexDiskANN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "index/vector_index/IndexDiskANN.h"

#include <math.h>
#include <omp.h>

#include <limits>
Expand Down Expand Up @@ -56,6 +57,7 @@ IndexDiskANN<T>::IndexDiskANN(std::string index_prefix, MetricType metric_type,
namespace {
static constexpr float kCacheExpansionRate = 1.2;
static constexpr uint32_t kLinuxAioMaxnrLimit = 65536;
static std::shared_ptr<ThreadPool> async_pool;
void
CheckPreparation(bool is_prepared) {
if (!is_prepared) {
Expand Down Expand Up @@ -153,6 +155,16 @@ TryDiskANNCallAndThrow(std::function<T()>&& diskann_call) {
KNOWHERE_THROW_MSG("DiskANN Other Exception: " + std::string(e.what()));
}
}

static std::shared_ptr<ThreadPool>
GetGlobalAsyncThreadPool() {
auto glb_pool = ThreadPool::GetGlobalThreadPool();
auto glb_pool_size = glb_pool->size();
uint32_t async_thread_pool_size = int(std::ceil(glb_pool_size / 2.0));
LOG_KNOWHERE_WARNING_ << "async thread pool size with thread number:" << async_thread_pool_size;
static auto async_pool = std::make_shared<ThreadPool>(async_thread_pool_size);
return async_pool;
}
} // namespace

template <typename T>
Expand Down Expand Up @@ -312,15 +324,11 @@ IndexDiskANN<T>::Prepare(const Config& config) {
return false;
}
} else {
pq_flash_index_->set_async_cache_flag(true);
pool_->push([&, cache_num = num_nodes_to_cache,
sample_nodes_file = warmup_query_file]() {
try {
pq_flash_index_->generate_cache_list_from_sample_queries(
sample_nodes_file, 15, 6, cache_num);
} catch (const std::exception& e) {
LOG_KNOWHERE_ERROR_ << "DiskANN Exception: " << e.what();
}
auto aysnc_pool_ = GetGlobalAsyncThreadPool();

pq_flash_index_->setup_cache_sync_task();
aysnc_pool_->push([&, cache_num = num_nodes_to_cache, sample_nodes_file = warmup_query_file]() {
pq_flash_index_->generate_cache_list_from_sample_queries(sample_nodes_file, 15, 6, cache_num);
});
}
}
Expand Down Expand Up @@ -445,10 +453,9 @@ IndexDiskANN<T>::QueryByRange(const DatasetPtr& dataset_ptr, const Config& confi
std::vector<int64_t> indices;
std::vector<float> distances;

auto res_count = pq_flash_index_->range_search(query + (index * dim), radius, query_conf.min_k,
query_conf.max_k, result_id_array[index],
result_dist_array[index], query_conf.beamwidth,
query_conf.search_list_and_k_ratio, bitset);
auto res_count = pq_flash_index_->range_search(
query + (index * dim), radius, query_conf.min_k, query_conf.max_k, result_id_array[index],
result_dist_array[index], query_conf.beamwidth, query_conf.search_list_and_k_ratio, bitset);

// filter range search result
if (range_filter_exist) {
Expand Down
5 changes: 2 additions & 3 deletions knowhere/index/vector_index/IndexDiskANN.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,16 @@

#pragma once

#include <atomic>
#include <memory>
#include <string>
#include <atomic>

#include "DiskANN/include/pq_flash_index.h"
#include "knowhere/common/FileManager.h"
#include "knowhere/index/VecIndex.h"
#include "knowhere/common/ThreadPool.h"
#include "knowhere/index/VecIndex.h"

namespace knowhere {

template <typename T>
class IndexDiskANN : public VecIndex {
static_assert(std::is_same_v<T, float>, "DiskANN only support float");
Expand Down
2 changes: 1 addition & 1 deletion thirdparty/DiskANN/include/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ namespace diskann {

DISKANN_DLLEXPORT diskann::Metric get_metric() const noexcept;

DISKANN_DLLEXPORT void set_async_cache_flag(const bool flag);
DISKANN_DLLEXPORT void setup_cache_sync_task();

protected:
DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
Expand Down
129 changes: 70 additions & 59 deletions thirdparty/DiskANN/src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,76 +307,78 @@ namespace diskann {
std::string sample_bin, _u64 l_search, _u64 beamwidth,
_u64 num_nodes_to_cache) {
#endif
auto s = std::chrono::high_resolution_clock::now();
this->search_counter.store(0);
this->node_visit_counter.clear();
this->node_visit_counter.resize(this->num_points);
this->count_visited_nodes.store(true);

for (_u32 i = 0; i < node_visit_counter.size(); i++) {
this->node_visit_counter[i].first = i;
this->node_visit_counter[i].second = 0;
}

_u64 sample_num, sample_dim, sample_aligned_dim;
T * samples;
try {
auto s = std::chrono::high_resolution_clock::now();

_u64 sample_num, sample_dim, sample_aligned_dim;
std::stringstream stream;

#ifdef EXEC_ENV_OLS
if (files.fileExists(sample_bin)) {
diskann::load_aligned_bin<T>(files, sample_bin, samples, sample_num,
sample_dim, sample_aligned_dim);
}
if (files.fileExists(sample_bin)) {
diskann::load_aligned_bin<T>(files, sample_bin, samples, sample_num,
sample_dim, sample_aligned_dim);
}
#else
if (file_exists(sample_bin)) {
diskann::load_aligned_bin<T>(sample_bin, samples, sample_num, sample_dim,
sample_aligned_dim);
}
if (file_exists(sample_bin)) {
diskann::load_aligned_bin<T>(sample_bin, samples, sample_num, sample_dim,
sample_aligned_dim);
}
#endif
else {
diskann::cerr << "Sample bin file not found. Not generating cache."
<< std::endl;
return;
}
else {
stream << "Sample bin file not found. Not generating cache."
<< std::endl;
throw diskann::ANNException(stream.str(), -1);
}

int64_t tmp_result_ids_64;
float tmp_result_dists;
int64_t tmp_result_ids_64;
float tmp_result_dists;

auto id = 0;
while (this->search_counter.load() < sample_num && id < sample_num &&
!this->semaph.IsWaitting()) {
cached_beam_search(samples + (id * sample_aligned_dim), 1, l_search,
&tmp_result_ids_64, &tmp_result_dists, beamwidth);
id++;
}
auto id = 0;
while (this->search_counter.load() < sample_num && id < sample_num &&
!this->semaph.IsWaitting()) {
cached_beam_search(samples + (id * sample_aligned_dim), 1, l_search,
&tmp_result_ids_64, &tmp_result_dists, beamwidth);
id++;
}

if (this->semaph.IsWaitting()) {
this->semaph.Signal();
return;
}
if (this->semaph.IsWaitting()) {
stream << "pq_flash_index is destoried, async thread should be exit."
<< std::endl;
throw diskann::ANNException(stream.str(), -1);
}

this->count_visited_nodes.store(false);
std::sort(this->node_visit_counter.begin(), node_visit_counter.end(),
[](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) {
return left.second > right.second;
});

std::vector<uint32_t> node_list;
node_list.clear();
node_list.shrink_to_fit();
node_list.reserve(num_nodes_to_cache);
for (_u64 i = 0; i < num_nodes_to_cache; i++) {
node_list.push_back(this->node_visit_counter[i].first);
this->count_visited_nodes.store(false);
std::sort(this->node_visit_counter.begin(), node_visit_counter.end(),
[](std::pair<_u32, _u32> &left, std::pair<_u32, _u32> &right) {
return left.second > right.second;
});

std::vector<uint32_t> node_list;
node_list.clear();
node_list.reserve(num_nodes_to_cache);
for (_u64 i = 0; i < num_nodes_to_cache; i++) {
node_list.push_back(this->node_visit_counter[i].first);
}

this->load_cache_list(node_list);
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
LOG(INFO) << "Using sample queries to generate cache, cost: " << diff.count() << "s";
} catch (const std::exception& e) {
LOG(ERROR) << "DiskANN Exception: " << e.what();
}
// clear up
if (this->count_visited_nodes.load() == true) {
this->count_visited_nodes.store(false);
}
this->node_visit_counter.clear();
this->node_visit_counter.shrink_to_fit();
this->search_counter.store(0);

diskann::aligned_free(samples);
this->load_cache_list(node_list);
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
LOG(INFO) << "Using sample queries to generate cache, cost: " << diff.count() << "s";

// free samples
if (samples != nullptr) {
diskann::aligned_free(samples);
}
this->semaph.Signal();
return;
}
Expand Down Expand Up @@ -1574,8 +1576,17 @@ namespace diskann {
}

template<typename T>
void PQFlashIndex<T>::set_async_cache_flag(const bool flag) {
this->async_generate_cache.exchange(flag);
void PQFlashIndex<T>::setup_cache_sync_task() {
this->async_generate_cache.exchange(true);
this->search_counter.store(0);
this->node_visit_counter.clear();
this->node_visit_counter.resize(this->num_points);
this->count_visited_nodes.store(true);

for (_u32 i = 0; i < node_visit_counter.size(); i++) {
this->node_visit_counter[i].first = i;
this->node_visit_counter[i].second = 0;
}
}

#ifdef EXEC_ENV_OLS
Expand Down

0 comments on commit 90c9ada

Please sign in to comment.