From 125146187b347cc5dd55930c21008046686513be Mon Sep 17 00:00:00 2001 From: chasingegg Date: Mon, 27 Nov 2023 15:25:05 +0800 Subject: [PATCH] Switch to brute force for hnsw when topk is super large Signed-off-by: chasingegg --- tests/ut/test_iterator.cc | 4 ++-- tests/ut/test_mmap.cc | 2 +- tests/ut/test_search.cc | 32 +++++++++++++++++++++++++--- thirdparty/hnswlib/hnswlib/hnswalg.h | 29 ++++++++++++++++++------- 4 files changed, 53 insertions(+), 14 deletions(-) diff --git a/tests/ut/test_iterator.cc b/tests/ut/test_iterator.cc index 1f440f886..7a92e7e0a 100644 --- a/tests/ut/test_iterator.cc +++ b/tests/ut/test_iterator.cc @@ -112,7 +112,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { SECTION("Test Search with Bitset using iterator") { using std::make_tuple; auto [name, gen, threshold] = GENERATE_REF(table, float>({ - make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); @@ -142,7 +142,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { SECTION("Test Search with Bitset using iterator insufficient results") { using std::make_tuple; auto [name, gen, threshold] = GENERATE_REF(table, float>({ - make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); diff --git a/tests/ut/test_mmap.cc b/tests/ut/test_mmap.cc index 5b635bc3b..8e9d5bb08 100644 --- a/tests/ut/test_mmap.cc +++ b/tests/ut/test_mmap.cc @@ -176,7 +176,7 @@ TEST_CASE("Search mmap", "[float metrics]") { SECTION("Test Search with Bitset") { using std::make_tuple; auto [name, gen, threshold] = GENERATE_REF(table, float>({ - make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index f5e7e2ab1..4d177376d 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -187,10 +187,35 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { } } + SECTION("Test Search with super large topk") { + using std::make_tuple; + auto hnsw_gen_ = [base_gen]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::HNSW_M] = 12; + json[knowhere::indexparam::EFCONSTRUCTION] = 30; + json[knowhere::meta::TOPK] = GENERATE(as{}, 600); + return json; + }; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen_), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + + auto results = idx.Search(*query_ds, json, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, nullptr); + float recall = GetKNNRecall(*gt.value(), *results.value()); + REQUIRE(recall > kBruteForceRecallThreshold); + } + SECTION("Test Search with Bitset") { using std::make_tuple; auto [name, gen, threshold] = GENERATE_REF(table, float>({ - make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); @@ -201,7 +226,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { std::vector(size_t, size_t)>> gen_bitset_funcs = { GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet}; - const auto bitset_percentages = {0.4f, 0.98f}; + const auto bitset_percentages = {0.4f, 0.8f, 0.98f}; for (const float percentage : bitset_percentages) { for (const auto& gen_func : gen_bitset_funcs) { auto bitset_data = gen_func(nb, percentage * nb); @@ -209,7 +234,8 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { auto results = idx.Search(*query_ds, json, bitset); auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results.value()); - if (percentage > threshold) { + if (percentage > threshold || + json[knowhere::meta::TOPK] > (1 - percentage) * nb * hnswlib::kHnswSearchBFTopkThreshold) { REQUIRE(recall > kBruteForceRecallThreshold); } else { REQUIRE(recall > kKnnRecallThreshold); diff --git a/thirdparty/hnswlib/hnswlib/hnswalg.h b/thirdparty/hnswlib/hnswlib/hnswalg.h index c71ed51bd..203a341c1 100644 --- a/thirdparty/hnswlib/hnswlib/hnswalg.h +++ b/thirdparty/hnswlib/hnswlib/hnswalg.h @@ -40,8 +40,9 @@ namespace hnswlib { typedef unsigned int tableint; typedef unsigned int linklistsizeint; -constexpr float kHnswSearchKnnBFThreshold = 0.93f; -constexpr float kHnswSearchRangeBFThreshold = 0.97f; +constexpr float kHnswSearchKnnBFFilterThreshold = 0.93f; +constexpr float kHnswSearchRangeBFFilterThreshold = 0.97f; +constexpr float kHnswSearchBFTopkThreshold = 0.5f; constexpr float kAlpha = 0.15f; enum Metric { @@ -1155,7 +1156,7 @@ class HierarchicalNSW : public AlgorithmInterface { searchKnnBF(const void* query_data, size_t k, const knowhere::BitsetView bitset) const { knowhere::ResultMaxHeap max_heap(k); for (labeltype id = 0; id < cur_element_count; ++id) { - if (!bitset.test(id)) { + if (bitset.empty() || !bitset.test(id)) { dist_t dist = calcDistance(query_data, id); max_heap.Push(dist, id); } @@ -1238,9 +1239,15 @@ class HierarchicalNSW : public AlgorithmInterface { query_data = query_data_norm.get(); } + // do bruteforce search when topk is super large + if (k >= (cur_element_count * kHnswSearchBFTopkThreshold)) { + return searchKnnBF(query_data, k, bitset); + } + // do bruteforce search when delete rate high if (!bitset.empty()) { - if (bitset.count() >= (cur_element_count * kHnswSearchKnnBFThreshold)) { + const size_t filtered_out_num = bitset.count(); + if (filtered_out_num >= (cur_element_count * kHnswSearchKnnBFFilterThreshold) || k >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) { return searchKnnBF(query_data, k, bitset); } } @@ -1270,7 +1277,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::unique_ptr getIteratorWorkspace(const void* query_data, const size_t seed_ef, const bool for_tuning, const knowhere::BitsetView& bitset) const { - auto accumulative_alpha = (bitset.count() >= (cur_element_count * kHnswSearchKnnBFThreshold)) + auto accumulative_alpha = (bitset.count() >= (cur_element_count * kHnswSearchKnnBFFilterThreshold)) ? std::numeric_limits::max() : 0.0f; if (metric_type_ == Metric::COSINE) { @@ -1344,7 +1351,7 @@ class HierarchicalNSW : public AlgorithmInterface { searchRangeBF(const void* query_data, float radius, const knowhere::BitsetView bitset) const { std::vector> result; for (labeltype id = 0; id < cur_element_count; ++id) { - if (!bitset.test(id)) { + if (bitset.empty() || !bitset.test(id)) { dist_t dist = calcDistance(query_data, id); if (dist < radius) { result.emplace_back(dist, id); @@ -1369,16 +1376,22 @@ class HierarchicalNSW : public AlgorithmInterface { query_data = query_data_norm.get(); } + // do bruteforce range search when ef is super large + size_t ef = param ? param->ef_ : this->ef_; + if (ef >= (cur_element_count * kHnswSearchBFTopkThreshold)) { + return searchRangeBF(query_data, radius, bitset); + } + // do bruteforce range search when delete rate high if (!bitset.empty()) { - if (bitset.count() >= (cur_element_count * kHnswSearchRangeBFThreshold)) { + const size_t filtered_out_num = bitset.count(); + if (filtered_out_num >= (cur_element_count * kHnswSearchRangeBFFilterThreshold) || ef >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) { return searchRangeBF(query_data, radius, bitset); } } auto [currObj, vec_hash] = searchTopLayers(query_data, param, feder_result); NeighborSet retset; - size_t ef = param ? param->ef_ : this->ef_; auto visited = visited_list_pool_->getFreeVisitedList(); if (!bitset.empty()) { retset = searchBaseLayerST(currObj, query_data, ef, visited, bitset, feder_result);