Skip to content

Commit

Permalink
Switch to brute force for hnsw when topk is super large
Browse the repository at this point in the history
Signed-off-by: chasingegg <[email protected]>
  • Loading branch information
chasingegg committed Nov 27, 2023
1 parent 2191316 commit e82d6a5
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 12 deletions.
4 changes: 2 additions & 2 deletions tests/ut/test_iterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search with Bitset using iterator") {
using std::make_tuple;
auto [name, gen, threshold] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>, float>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name, version);
auto cfg_json = gen().dump();
Expand Down Expand Up @@ -142,7 +142,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search with Bitset using iterator insufficient results") {
using std::make_tuple;
auto [name, gen, threshold] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>, float>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name, version);
auto cfg_json = gen().dump();
Expand Down
2 changes: 1 addition & 1 deletion tests/ut/test_mmap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ TEST_CASE("Search mmap", "[float metrics]") {
SECTION("Test Search with Bitset") {
using std::make_tuple;
auto [name, gen, threshold] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>, float>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name, version);
auto cfg_json = gen().dump();
Expand Down
27 changes: 26 additions & 1 deletion tests/ut/test_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,35 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
}
}

// Checks that an HNSW search requesting a very large top-k still reaches the recall expected of a
// brute-force search, by comparing its results against knowhere::BruteForce::Search on the same data.
SECTION("Test Search with super large topk") {
    using std::make_tuple;
    // HNSW build/search configuration with a deliberately large top-k.
    // NOTE(review): presumably 600 is at least kHnswSearchBFTopkThreshold * (number of base vectors), so
    // the brute-force fallback in searchKnn is exercised — confirm against the dataset size used here.
    auto hnsw_gen_ = [base_gen]() {
        knowhere::Json json = base_gen();
        json[knowhere::indexparam::HNSW_M] = 12;
        json[knowhere::indexparam::EFCONSTRUCTION] = 30;
        json[knowhere::meta::TOPK] = GENERATE(as<int64_t>{}, 600);
        return json;
    };
    auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
        make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen_),
    }));
    auto idx = knowhere::IndexFactory::Instance().Create(name, version);
    auto cfg_json = gen().dump();
    // Record the index name and serialized config so a failing run reports them.
    CAPTURE(name, cfg_json);
    knowhere::Json json = knowhere::Json::parse(cfg_json);
    REQUIRE(idx.Type() == name);
    REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);

    auto results = idx.Search(*query_ds, json, nullptr);
    // Fail with a clear assertion instead of dereferencing an errored result below.
    REQUIRE(results.has_value());
    auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, nullptr);
    REQUIRE(gt.has_value());
    float recall = GetKNNRecall(*gt.value(), *results.value());
    REQUIRE(recall > kBruteForceRecallThreshold);
}

SECTION("Test Search with Bitset") {
using std::make_tuple;
auto [name, gen, threshold] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>, float>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name, version);
auto cfg_json = gen().dump();
Expand Down
27 changes: 19 additions & 8 deletions thirdparty/hnswlib/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
namespace hnswlib {
typedef unsigned int tableint;
typedef unsigned int linklistsizeint;
constexpr float kHnswSearchKnnBFThreshold = 0.93f;
constexpr float kHnswSearchRangeBFThreshold = 0.97f;
constexpr float kHnswSearchKnnBFFilterThreshold = 0.93f;
constexpr float kHnswSearchRangeBFFilterThreshold = 0.97f;
constexpr float kHnswSearchBFTopkThreshold = 0.5f;
constexpr float kAlpha = 0.15f;

enum Metric {
Expand Down Expand Up @@ -1155,7 +1156,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
searchKnnBF(const void* query_data, size_t k, const knowhere::BitsetView bitset) const {
knowhere::ResultMaxHeap<dist_t, labeltype> max_heap(k);
for (labeltype id = 0; id < cur_element_count; ++id) {
if (!bitset.test(id)) {
if (bitset.empty() || !bitset.test(id)) {
dist_t dist = calcDistance(query_data, id);
max_heap.Push(dist, id);
}
Expand Down Expand Up @@ -1240,11 +1241,16 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

// do bruteforce search when delete rate high
if (!bitset.empty()) {
if (bitset.count() >= (cur_element_count * kHnswSearchKnnBFThreshold)) {
if (bitset.count() >= (cur_element_count * kHnswSearchKnnBFFilterThreshold)) {
return searchKnnBF(query_data, k, bitset);
}
}

// do bruteforce search when topk is super large
if (k >= (cur_element_count * kHnswSearchBFTopkThreshold)) {
return searchKnnBF(query_data, k, bitset);
}

auto [currObj, vec_hash] = searchTopLayers(query_data, param, feder_result);
NeighborSet retset;
size_t ef = param ? param->ef_ : this->ef_;
Expand All @@ -1270,7 +1276,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::unique_ptr<IteratorWorkspace>
getIteratorWorkspace(const void* query_data, const size_t seed_ef, const bool for_tuning,
const knowhere::BitsetView& bitset) const {
auto accumulative_alpha = (bitset.count() >= (cur_element_count * kHnswSearchKnnBFThreshold))
auto accumulative_alpha = (bitset.count() >= (cur_element_count * kHnswSearchKnnBFFilterThreshold))
? std::numeric_limits<float>::max()
: 0.0f;
if (metric_type_ == Metric::COSINE) {
Expand Down Expand Up @@ -1344,7 +1350,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
searchRangeBF(const void* query_data, float radius, const knowhere::BitsetView bitset) const {
std::vector<std::pair<dist_t, labeltype>> result;
for (labeltype id = 0; id < cur_element_count; ++id) {
if (!bitset.test(id)) {
if (bitset.empty() || !bitset.test(id)) {
dist_t dist = calcDistance(query_data, id);
if (dist < radius) {
result.emplace_back(dist, id);
Expand All @@ -1371,14 +1377,19 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

// do bruteforce range search when delete rate high
if (!bitset.empty()) {
if (bitset.count() >= (cur_element_count * kHnswSearchRangeBFThreshold)) {
if (bitset.count() >= (cur_element_count * kHnswSearchRangeBFFilterThreshold)) {
return searchRangeBF(query_data, radius, bitset);
}
}

// do bruteforce range search when ef is super large
size_t ef = param ? param->ef_ : this->ef_;
if (ef >= (cur_element_count * kHnswSearchBFTopkThreshold)) {
return searchRangeBF(query_data, radius, bitset);
}

auto [currObj, vec_hash] = searchTopLayers(query_data, param, feder_result);
NeighborSet retset;
size_t ef = param ? param->ef_ : this->ef_;
auto visited = visited_list_pool_->getFreeVisitedList();
if (!bitset.empty()) {
retset = searchBaseLayerST<true, true>(currObj, query_data, ef, visited, bitset, feder_result);
Expand Down

0 comments on commit e82d6a5

Please sign in to comment.