Skip to content

Commit

Permalink
Set scoped omp to fix IVF index build degrade (#863)
Browse files Browse the repository at this point in the history
Signed-off-by: Cai Yudong <[email protected]>
  • Loading branch information
cydrain authored Sep 24, 2024
1 parent bf220f4 commit 49882ec
Show file tree
Hide file tree
Showing 16 changed files with 108 additions and 66 deletions.
2 changes: 2 additions & 0 deletions benchmark/hdf5/benchmark_binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ class Benchmark_binary : public Benchmark_knowhere, public ::testing::Test {

cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2);
knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num);
knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num);
printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold());
}

Expand Down
2 changes: 2 additions & 0 deletions benchmark/hdf5/benchmark_binary_range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class Benchmark_binary_range : public Benchmark_knowhere, public ::testing::Test
cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
cfg_[knowhere::meta::RADIUS] = *gt_radius_;
knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2);
knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num);
knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num);
printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold());
}

Expand Down
2 changes: 2 additions & 0 deletions benchmark/hdf5/benchmark_float.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ class Benchmark_float : public Benchmark_knowhere, public ::testing::Test {

cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2);
knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num);
knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num);
printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold());
}

Expand Down
2 changes: 2 additions & 0 deletions benchmark/hdf5/benchmark_float_bitset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ class Benchmark_float_bitset : public Benchmark_knowhere, public ::testing::Test

cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2);
knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num);
knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num);
printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold());

create_golden_index(cfg_);
Expand Down
5 changes: 4 additions & 1 deletion benchmark/hdf5/benchmark_float_qps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,10 @@ class Benchmark_float_qps : public Benchmark_knowhere, public ::testing::Test {
load_hdf5_data<false>();

cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AUTO);
knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2);
knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num);
knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num);
printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold());
#ifdef KNOWHERE_WITH_GPU
knowhere::KnowhereConfig::InitGPUResource(GPU_DEVICE_ID, 2);
cfg_[knowhere::meta::DEVICE_ID] = GPU_DEVICE_ID;
Expand Down
2 changes: 2 additions & 0 deletions benchmark/hdf5/benchmark_float_range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ class Benchmark_float_range : public Benchmark_knowhere, public ::testing::Test
cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
cfg_[knowhere::meta::RADIUS] = *gt_radius_;
knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2);
knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num);
knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num);
printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold());
}

Expand Down
2 changes: 2 additions & 0 deletions benchmark/hdf5/benchmark_float_range_bitset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ class Benchmark_float_range_bitset : public Benchmark_knowhere, public ::testing
cfg_[knowhere::meta::METRIC_TYPE] = metric_type_;
cfg_[knowhere::meta::RADIUS] = *gt_radius_;
knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2);
knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num);
knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num);
printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold());

create_golden_index(cfg_);
Expand Down
3 changes: 3 additions & 0 deletions benchmark/hdf5/benchmark_knowhere.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
#include "knowhere/index/index_factory.h"
#include "knowhere/version.h"

static const size_t default_build_thread_num = 8;
static const size_t default_search_thread_num = 8;

namespace fs = std::filesystem;
std::string kDir = fs::current_path().string() + "/diskann_test";
std::string kRawDataPath = kDir + "/raw_data";
Expand Down
1 change: 1 addition & 0 deletions benchmark/hdf5/ref_logs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ test_binary_range_hnsw:
# Test Knowhere float index
test_float: test_float_idmap test_float_ivf_flat test_float_ivf_sq8 test_float_ivf_pq test_float_hnsw test_float_diskann
test_float_gpu: test_float_ivf_flat test_float_ivf_pq
test_float_ivf: test_float_ivf_flat test_float_ivf_pq

test_float_idmap:
./benchmark_float --gtest_filter="Benchmark_float.TEST_IDMAP" | tee test_float_idmap.log
Expand Down
29 changes: 19 additions & 10 deletions include/knowhere/comp/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,33 +245,42 @@ class ThreadPool {
return search_pool_;
}

class ScopedOmpSetter {
class ScopedBuildOmpSetter {
int omp_before;
#ifdef OPENBLAS_OS_LINUX
int blas_thread_before;
#endif
public:
explicit ScopedOmpSetter(int num_threads = 1) {
if (num_threads <= 0) {
return;
}

explicit ScopedBuildOmpSetter(int num_threads = 0) {
omp_before = (build_pool_ ? build_pool_->size() : omp_get_max_threads());
#ifdef OPENBLAS_OS_LINUX
// to avoid thread spawn when IVF_PQ build
blas_thread_before = openblas_get_num_threads();
openblas_set_num_threads(num_threads);
openblas_set_num_threads(1);
#endif

omp_set_num_threads(num_threads);
omp_set_num_threads(num_threads <= 0 ? omp_before : num_threads);
}
~ScopedOmpSetter() {
~ScopedBuildOmpSetter() {
omp_set_num_threads(omp_before);
#ifdef OPENBLAS_OS_LINUX
openblas_set_num_threads(blas_thread_before);
#endif
}
};

class ScopedSearchOmpSetter {
int omp_before;

public:
explicit ScopedSearchOmpSetter(int num_threads = 1) {
omp_before = (search_pool_ ? search_pool_->size() : omp_get_max_threads());
omp_set_num_threads(num_threads <= 0 ? omp_before : num_threads);
}
~ScopedSearchOmpSetter() {
omp_set_num_threads(omp_before);
}
};

private:
folly::CPUThreadPoolExecutor pool_;

Expand Down
8 changes: 4 additions & 4 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i, labels_ptr = labels.get(), distances_ptr = distances.get()] {
ThreadPool::ScopedOmpSetter setter(1);
ThreadPool::ScopedSearchOmpSetter setter(1);
auto cur_labels = labels_ptr + topk * index;
auto cur_distances = distances_ptr + topk * index;

Expand Down Expand Up @@ -244,7 +244,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
ThreadPool::ScopedSearchOmpSetter setter(1);
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;

Expand Down Expand Up @@ -420,7 +420,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
return Status::success;
}
// else not sparse:
ThreadPool::ScopedOmpSetter setter(1);
ThreadPool::ScopedSearchOmpSetter setter(1);
faiss::RangeSearchResult res(1);

BitsetViewIDSelector bw_idselector(bitset);
Expand Down Expand Up @@ -667,7 +667,7 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da

for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
ThreadPool::ScopedSearchOmpSetter setter(1);

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;
Expand Down
15 changes: 10 additions & 5 deletions src/common/thread/thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ExecOverSearchThreadPool(std::vector<std::function<void()>>& tasks) {
futures.reserve(tasks.size());
for (auto&& t : tasks) {
futures.emplace_back(pool->push([&t]() {
ThreadPool::ScopedOmpSetter setter(1);
ThreadPool::ScopedSearchOmpSetter setter(1);
t();
}));
}
Expand All @@ -44,7 +44,7 @@ ExecOverBuildThreadPool(std::vector<std::function<void()>>& tasks) {
futures.reserve(tasks.size());
for (auto&& t : tasks) {
futures.emplace_back(pool->push([&t]() {
ThreadPool::ScopedOmpSetter setter(1);
ThreadPool::ScopedBuildOmpSetter setter(1);
t();
}));
}
Expand Down Expand Up @@ -72,9 +72,14 @@ GetBuildThreadPoolSize() {
return ThreadPool::GetGlobalBuildThreadPool()->size();
}

std::unique_ptr<ThreadPool::ScopedOmpSetter>
CreateScopeOmpSetter(int num_threads) {
return std::make_unique<ThreadPool::ScopedOmpSetter>(num_threads);
std::unique_ptr<ThreadPool::ScopedBuildOmpSetter>
CreateScopeBuildOmpSetter(int num_threads) {
return std::make_unique<ThreadPool::ScopedBuildOmpSetter>(num_threads);
}

std::unique_ptr<ThreadPool::ScopedSearchOmpSetter>
CreateScopeSearchOmpSetter(int num_threads) {
return std::make_unique<ThreadPool::ScopedSearchOmpSetter>(num_threads);
}

} // namespace knowhere
4 changes: 2 additions & 2 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class FlatIndexNode : public IndexNode {
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(search_pool_->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
ThreadPool::ScopedSearchOmpSetter setter(1);
auto cur_ids = ids + k * index;
auto cur_dis = distances + k * index;

Expand Down Expand Up @@ -167,7 +167,7 @@ class FlatIndexNode : public IndexNode {
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(search_pool_->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
ThreadPool::ScopedSearchOmpSetter setter(1);
faiss::RangeSearchResult res(1);

BitsetViewIDSelector bw_idselector(bitset);
Expand Down
52 changes: 25 additions & 27 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,18 @@ class BaseFaissIndexNode : public IndexNode {

// use build_pool_ to make sure the OMP threads spawned by index_->train etc
// can inherit the low nice value of threads in build_pool_.
auto tryObj = build_pool
->push([&] {
std::unique_ptr<ThreadPool::ScopedOmpSetter> setter;
if (base_cfg.num_build_thread.has_value()) {
setter =
std::make_unique<ThreadPool::ScopedOmpSetter>(base_cfg.num_build_thread.value());
} else {
setter = std::make_unique<ThreadPool::ScopedOmpSetter>();
}

return TrainInternal(dataset, *cfg);
})
.getTry();
auto tryObj =
build_pool
->push([&] {
std::unique_ptr<ThreadPool::ScopedBuildOmpSetter> setter;
if (base_cfg.num_build_thread.has_value()) {
setter = std::make_unique<ThreadPool::ScopedBuildOmpSetter>(base_cfg.num_build_thread.value());
} else {
setter = std::make_unique<ThreadPool::ScopedBuildOmpSetter>();
}
return TrainInternal(dataset, *cfg);
})
.getTry();

if (!tryObj.hasValue()) {
LOG_KNOWHERE_WARNING_ << "faiss internal error: " << tryObj.exception().what();
Expand All @@ -108,19 +107,18 @@ class BaseFaissIndexNode : public IndexNode {

// use build_pool_ to make sure the OMP threads spawned by index_->train etc
// can inherit the low nice value of threads in build_pool_.
auto tryObj = build_pool
->push([&] {
std::unique_ptr<ThreadPool::ScopedOmpSetter> setter;
if (base_cfg.num_build_thread.has_value()) {
setter =
std::make_unique<ThreadPool::ScopedOmpSetter>(base_cfg.num_build_thread.value());
} else {
setter = std::make_unique<ThreadPool::ScopedOmpSetter>();
}

return AddInternal(dataset, *cfg);
})
.getTry();
auto tryObj =
build_pool
->push([&] {
std::unique_ptr<ThreadPool::ScopedBuildOmpSetter> setter;
if (base_cfg.num_build_thread.has_value()) {
setter = std::make_unique<ThreadPool::ScopedBuildOmpSetter>(base_cfg.num_build_thread.value());
} else {
setter = std::make_unique<ThreadPool::ScopedBuildOmpSetter>();
}
return AddInternal(dataset, *cfg);
})
.getTry();

if (!tryObj.hasValue()) {
LOG_KNOWHERE_WARNING_ << "faiss internal error: " << tryObj.exception().what();
Expand Down Expand Up @@ -869,7 +867,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
for (int64_t i = 0; i < rows; ++i) {
futs.emplace_back(search_pool->push([&, idx = i] {
// 1 thread per element
ThreadPool::ScopedOmpSetter setter(1);
ThreadPool::ScopedSearchOmpSetter setter(1);

// set up a query
// const float* cur_query = (const float*)data + idx * dim;
Expand Down
17 changes: 9 additions & 8 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,11 @@ template <typename DataType, typename IndexType>
Status
IvfIndexNode<DataType, IndexType>::TrainInternal(const DataSetPtr dataset, std::shared_ptr<Config> cfg) {
const BaseConfig& base_cfg = static_cast<const IvfConfig&>(*cfg);
std::unique_ptr<ThreadPool::ScopedOmpSetter> setter;
std::unique_ptr<ThreadPool::ScopedBuildOmpSetter> setter;
if (base_cfg.num_build_thread.has_value()) {
setter = std::make_unique<ThreadPool::ScopedOmpSetter>(base_cfg.num_build_thread.value());
setter = std::make_unique<ThreadPool::ScopedBuildOmpSetter>(base_cfg.num_build_thread.value());
} else {
setter = std::make_unique<ThreadPool::ScopedOmpSetter>();
setter = std::make_unique<ThreadPool::ScopedBuildOmpSetter>();
}

bool is_cosine = IsMetricType(base_cfg.metric_type.value(), knowhere::metric::COSINE);
Expand Down Expand Up @@ -627,11 +627,12 @@ IvfIndexNode<DataType, IndexType>::Add(const DataSetPtr dataset, std::shared_ptr
// can inherit the low nice value of threads in build_pool_.
auto tryObj = build_pool_
->push([&] {
std::unique_ptr<ThreadPool::ScopedOmpSetter> setter;
std::unique_ptr<ThreadPool::ScopedBuildOmpSetter> setter;
if (base_cfg.num_build_thread.has_value()) {
setter = std::make_unique<ThreadPool::ScopedOmpSetter>(base_cfg.num_build_thread.value());
setter =
std::make_unique<ThreadPool::ScopedBuildOmpSetter>(base_cfg.num_build_thread.value());
} else {
setter = std::make_unique<ThreadPool::ScopedOmpSetter>();
setter = std::make_unique<ThreadPool::ScopedBuildOmpSetter>();
}
if constexpr (std::is_same<faiss::IndexBinaryIVF, IndexType>::value) {
index_->add(rows, (const uint8_t*)data);
Expand Down Expand Up @@ -677,7 +678,7 @@ IvfIndexNode<DataType, IndexType>::Search(const DataSetPtr dataset, std::unique_
futs.reserve(rows);
for (int i = 0; i < rows; ++i) {
futs.emplace_back(search_pool_->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
ThreadPool::ScopedSearchOmpSetter setter(1);
auto offset = k * index;
std::unique_ptr<float[]> copied_query = nullptr;

Expand Down Expand Up @@ -802,7 +803,7 @@ IvfIndexNode<DataType, IndexType>::RangeSearch(const DataSetPtr dataset, std::un
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(search_pool_->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
ThreadPool::ScopedSearchOmpSetter setter(1);
faiss::RangeSearchResult res(1);
std::unique_ptr<float[]> copied_query = nullptr;

Expand Down
28 changes: 19 additions & 9 deletions tests/ut/test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,28 @@ TEST_CASE("Test ThreadPool") {
}
}

SECTION("ScopedOmpSetter") {
int prev_num_threads = omp_get_max_threads();
SECTION("ScopedBuildOmpSetter") {
int prev_num_threads = knowhere::ThreadPool::GetGlobalBuildThreadPoolSize();
{
int target_num_threads = (prev_num_threads / 2) > 0 ? (prev_num_threads / 2) : 1;
knowhere::ThreadPool::ScopedOmpSetter setter(target_num_threads);
auto thread_num = omp_get_max_threads();
REQUIRE(thread_num == target_num_threads);
#ifdef OPENBLAS_OS_LINUX
auto openblas_thread_num = openblas_get_num_threads();
REQUIRE(openblas_thread_num == target_num_threads);
#endif
knowhere::ThreadPool::ScopedBuildOmpSetter setter(target_num_threads);
auto thread_num_1 = omp_get_max_threads();
REQUIRE(thread_num_1 == target_num_threads);
}
auto thread_num_2 = omp_get_max_threads();
REQUIRE(thread_num_2 == prev_num_threads);
}

SECTION("ScopedSearchOmpSetter") {
int prev_num_threads = knowhere::ThreadPool::GetGlobalSearchThreadPoolSize();
{
int target_num_threads = (prev_num_threads / 2) > 0 ? (prev_num_threads / 2) : 1;
knowhere::ThreadPool::ScopedSearchOmpSetter setter(target_num_threads);
auto thread_num_1 = omp_get_max_threads();
REQUIRE(thread_num_1 == target_num_threads);
}
auto thread_num_2 = omp_get_max_threads();
REQUIRE(thread_num_2 == prev_num_threads);
}
}

Expand Down

0 comments on commit 49882ec

Please sign in to comment.