From e82d6a53235d4f335c2f0104f722187e02648e44 Mon Sep 17 00:00:00 2001 From: chasingegg Date: Mon, 27 Nov 2023 15:25:05 +0800 Subject: [PATCH] Switch to brute force for hnsw when topk is super large Signed-off-by: chasingegg --- tests/ut/test_iterator.cc | 4 ++-- tests/ut/test_mmap.cc | 2 +- tests/ut/test_search.cc | 27 ++++++++++++++++++++++++++- thirdparty/hnswlib/hnswlib/hnswalg.h | 27 +++++++++++++++++++-------- 4 files changed, 48 insertions(+), 12 deletions(-) diff --git a/tests/ut/test_iterator.cc b/tests/ut/test_iterator.cc index 1f440f886..7a92e7e0a 100644 --- a/tests/ut/test_iterator.cc +++ b/tests/ut/test_iterator.cc @@ -112,7 +112,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { SECTION("Test Search with Bitset using iterator") { using std::make_tuple; auto [name, gen, threshold] = GENERATE_REF(table, float>({ - make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); @@ -142,7 +142,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { SECTION("Test Search with Bitset using iterator insufficient results") { using std::make_tuple; auto [name, gen, threshold] = GENERATE_REF(table, float>({ - make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); diff --git a/tests/ut/test_mmap.cc b/tests/ut/test_mmap.cc index 5b635bc3b..8e9d5bb08 100644 --- a/tests/ut/test_mmap.cc +++ b/tests/ut/test_mmap.cc @@ -176,7 +176,7 @@ TEST_CASE("Search mmap", "[float metrics]") { SECTION("Test Search with Bitset") { using std::make_tuple; auto [name, gen, threshold] = GENERATE_REF(table, float>({ - make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index f5e7e2ab1..cc3665a77 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -187,10 +187,35 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { } } + SECTION("Test Search with super large topk") { + using std::make_tuple; + auto hnsw_gen_ = [base_gen]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::HNSW_M] = 12; + json[knowhere::indexparam::EFCONSTRUCTION] = 30; + json[knowhere::meta::TOPK] = GENERATE(as{}, 600); + return json; + }; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen_), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + + auto results = idx.Search(*query_ds, json, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, nullptr); + float recall = GetKNNRecall(*gt.value(), *results.value()); + REQUIRE(recall > kBruteForceRecallThreshold); + } + SECTION("Test Search with Bitset") { using std::make_tuple; auto [name, gen, threshold] = GENERATE_REF(table, float>({ - make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); diff --git a/thirdparty/hnswlib/hnswlib/hnswalg.h b/thirdparty/hnswlib/hnswlib/hnswalg.h index c71ed51bd..2f6891afe 100644 --- a/thirdparty/hnswlib/hnswlib/hnswalg.h +++ b/thirdparty/hnswlib/hnswlib/hnswalg.h @@ -40,8 +40,9 @@ namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; -constexpr float kHnswSearchKnnBFThreshold = 0.93f; -constexpr float kHnswSearchRangeBFThreshold = 0.97f; +constexpr float kHnswSearchKnnBFFilterThreshold = 0.93f; +constexpr float kHnswSearchRangeBFFilterThreshold = 0.97f; +constexpr float kHnswSearchBFTopkThreshold = 0.5f; constexpr float kAlpha = 0.15f; enum Metric { @@ -1155,7 +1156,7 @@ class HierarchicalNSW : public AlgorithmInterface { searchKnnBF(const void* query_data, size_t k, const knowhere::BitsetView bitset) const { knowhere::ResultMaxHeap max_heap(k); for (labeltype id = 0; id < cur_element_count; ++id) { - if (!bitset.test(id)) { + if (bitset.empty() || !bitset.test(id)) { dist_t dist = calcDistance(query_data, id); max_heap.Push(dist, id); } @@ -1240,11 +1241,16 @@ class HierarchicalNSW : public AlgorithmInterface { // do bruteforce search when delete rate high if (!bitset.empty()) { - if (bitset.count() >= (cur_element_count * kHnswSearchKnnBFThreshold)) { + if (bitset.count() >= (cur_element_count * kHnswSearchKnnBFFilterThreshold)) { return searchKnnBF(query_data, k, bitset); } } + // do bruteforce search when topk is super large + if (k >= (cur_element_count * kHnswSearchBFTopkThreshold)) { + return searchKnnBF(query_data, k, bitset); + } + auto [currObj, vec_hash] = searchTopLayers(query_data, param, feder_result); NeighborSet retset; size_t ef = param ? param->ef_ : this->ef_; @@ -1270,7 +1276,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::unique_ptr getIteratorWorkspace(const void* query_data, const size_t seed_ef, const bool for_tuning, const knowhere::BitsetView& bitset) const { - auto accumulative_alpha = (bitset.count() >= (cur_element_count * kHnswSearchKnnBFThreshold)) + auto accumulative_alpha = (bitset.count() >= (cur_element_count * kHnswSearchKnnBFFilterThreshold)) ? std::numeric_limits::max() : 0.0f; if (metric_type_ == Metric::COSINE) { @@ -1344,7 +1350,7 @@ class HierarchicalNSW : public AlgorithmInterface { searchRangeBF(const void* query_data, float radius, const knowhere::BitsetView bitset) const { std::vector> result; for (labeltype id = 0; id < cur_element_count; ++id) { - if (!bitset.test(id)) { + if (bitset.empty() || !bitset.test(id)) { dist_t dist = calcDistance(query_data, id); if (dist < radius) { result.emplace_back(dist, id); @@ -1371,14 +1377,19 @@ class HierarchicalNSW : public AlgorithmInterface { // do bruteforce range search when delete rate high if (!bitset.empty()) { - if (bitset.count() >= (cur_element_count * kHnswSearchRangeBFThreshold)) { + if (bitset.count() >= (cur_element_count * kHnswSearchRangeBFFilterThreshold)) { return searchRangeBF(query_data, radius, bitset); } } + // do bruteforce range search when ef is super large + size_t ef = param ? param->ef_ : this->ef_; + if (ef >= (cur_element_count * kHnswSearchBFTopkThreshold)) { + return searchRangeBF(query_data, radius, bitset); + } + auto [currObj, vec_hash] = searchTopLayers(query_data, param, feder_result); NeighborSet retset; - size_t ef = param ? param->ef_ : this->ef_; auto visited = visited_list_pool_->getFreeVisitedList(); if (!bitset.empty()) { retset = searchBaseLayerST(currObj, query_data, ef, visited, bitset, feder_result);