Skip to content

Commit

Permalink
Switch to brute force for hnsw when topk is super large
Browse files Browse the repository at this point in the history
Signed-off-by: chasingegg <[email protected]>
  • Loading branch information
chasingegg committed Nov 27, 2023
1 parent 2191316 commit 1251461
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 14 deletions.
4 changes: 2 additions & 2 deletions tests/ut/test_iterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search with Bitset using iterator") {
using std::make_tuple;
auto [name, gen, threshold] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>, float>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name, version);
auto cfg_json = gen().dump();
Expand Down Expand Up @@ -142,7 +142,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search with Bitset using iterator insufficient results") {
using std::make_tuple;
auto [name, gen, threshold] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>, float>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name, version);
auto cfg_json = gen().dump();
Expand Down
2 changes: 1 addition & 1 deletion tests/ut/test_mmap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ TEST_CASE("Search mmap", "[float metrics]") {
SECTION("Test Search with Bitset") {
using std::make_tuple;
auto [name, gen, threshold] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>, float>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name, version);
auto cfg_json = gen().dump();
Expand Down
32 changes: 29 additions & 3 deletions tests/ut/test_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,35 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {
}
}

// Builds an HNSW index and checks that an unfiltered search with a large topk reaches the
// brute-force recall threshold when compared against knowhere::BruteForce::Search.
SECTION("Test Search with super large topk") {
using std::make_tuple;
// Local config generator: starts from base_gen() and overrides M, efConstruction and topk.
auto hnsw_gen_ = [base_gen]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::HNSW_M] = 12;
json[knowhere::indexparam::EFCONSTRUCTION] = 30;
// topk is fixed at 600; presumably large enough relative to the dataset size to hit
// hnswlib's topk-based brute-force switch -- TODO confirm against nb.
json[knowhere::meta::TOPK] = GENERATE(as<int64_t>{}, 600);
return json;
};
// Single-row table: only the HNSW index type is exercised by this section.
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen_),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name, version);
auto cfg_json = gen().dump();
// Record the index name and serialized config so they are reported with any failure.
CAPTURE(name, cfg_json);
// Re-parse the dumped string so build and search use exactly the captured config.
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success);

// Search without a bitset (nullptr) and compute the brute-force ground truth for the same config.
// NOTE(review): value() is read on both results below without first checking that the
// calls succeeded -- consider asserting success before dereferencing.
auto results = idx.Search(*query_ds, json, nullptr);
auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, nullptr);
float recall = GetKNNRecall(*gt.value(), *results.value());
// Recall is held to the brute-force threshold rather than the KNN threshold.
REQUIRE(recall > kBruteForceRecallThreshold);
}

SECTION("Test Search with Bitset") {
using std::make_tuple;
auto [name, gen, threshold] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>, float>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold),
}));
auto idx = knowhere::IndexFactory::Instance().Create(name, version);
auto cfg_json = gen().dump();
Expand All @@ -201,15 +226,16 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") {

std::vector<std::function<std::vector<uint8_t>(size_t, size_t)>> gen_bitset_funcs = {
GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet};
const auto bitset_percentages = {0.4f, 0.98f};
const auto bitset_percentages = {0.4f, 0.8f, 0.98f};
for (const float percentage : bitset_percentages) {
for (const auto& gen_func : gen_bitset_funcs) {
auto bitset_data = gen_func(nb, percentage * nb);
knowhere::BitsetView bitset(bitset_data.data(), nb);
auto results = idx.Search(*query_ds, json, bitset);
auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset);
float recall = GetKNNRecall(*gt.value(), *results.value());
if (percentage > threshold) {
if (percentage > threshold ||
json[knowhere::meta::TOPK] > (1 - percentage) * nb * hnswlib::kHnswSearchBFTopkThreshold) {
REQUIRE(recall > kBruteForceRecallThreshold);
} else {
REQUIRE(recall > kKnnRecallThreshold);
Expand Down
29 changes: 21 additions & 8 deletions thirdparty/hnswlib/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
namespace hnswlib {
typedef unsigned int tableint;
typedef unsigned int linklistsizeint;
// Thresholds that decide when the HNSW searches in this file fall back to a brute-force scan.
// NOTE(review): this excerpt is a rendered diff, so the two *BFThreshold names directly below
// appear to be the pre-rename spellings of the *BFFilterThreshold constants -- confirm that
// only one pair exists in the real header.
constexpr float kHnswSearchKnnBFThreshold = 0.93f;
constexpr float kHnswSearchRangeBFThreshold = 0.97f;
// KNN search goes brute force once bitset.count() reaches this fraction of cur_element_count.
constexpr float kHnswSearchKnnBFFilterThreshold = 0.93f;
// Same rule as the KNN filter threshold, applied by range search.
constexpr float kHnswSearchRangeBFFilterThreshold = 0.97f;
// Brute force is also used when k (KNN search) or ef (range search) reaches this fraction of
// the element count, or of the elements left after filtering when a bitset is supplied.
constexpr float kHnswSearchBFTopkThreshold = 0.5f;
// Usage of this constant is not visible in this excerpt.
constexpr float kAlpha = 0.15f;

enum Metric {
Expand Down Expand Up @@ -1155,7 +1156,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
searchKnnBF(const void* query_data, size_t k, const knowhere::BitsetView bitset) const {
knowhere::ResultMaxHeap<dist_t, labeltype> max_heap(k);
for (labeltype id = 0; id < cur_element_count; ++id) {
if (!bitset.test(id)) {
if (bitset.empty() || !bitset.test(id)) {
dist_t dist = calcDistance(query_data, id);
max_heap.Push(dist, id);
}
Expand Down Expand Up @@ -1238,9 +1239,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
query_data = query_data_norm.get();
}

// do bruteforce search when topk is super large
if (k >= (cur_element_count * kHnswSearchBFTopkThreshold)) {
return searchKnnBF(query_data, k, bitset);
}

// do bruteforce search when delete rate high
if (!bitset.empty()) {
if (bitset.count() >= (cur_element_count * kHnswSearchKnnBFThreshold)) {
const size_t filtered_out_num = bitset.count();
if (filtered_out_num >= (cur_element_count * kHnswSearchKnnBFFilterThreshold) || k >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) {
return searchKnnBF(query_data, k, bitset);
}
}
Expand Down Expand Up @@ -1270,7 +1277,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::unique_ptr<IteratorWorkspace>
getIteratorWorkspace(const void* query_data, const size_t seed_ef, const bool for_tuning,
const knowhere::BitsetView& bitset) const {
auto accumulative_alpha = (bitset.count() >= (cur_element_count * kHnswSearchKnnBFThreshold))
auto accumulative_alpha = (bitset.count() >= (cur_element_count * kHnswSearchKnnBFFilterThreshold))
? std::numeric_limits<float>::max()
: 0.0f;
if (metric_type_ == Metric::COSINE) {
Expand Down Expand Up @@ -1344,7 +1351,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
searchRangeBF(const void* query_data, float radius, const knowhere::BitsetView bitset) const {
std::vector<std::pair<dist_t, labeltype>> result;
for (labeltype id = 0; id < cur_element_count; ++id) {
if (!bitset.test(id)) {
if (bitset.empty() || !bitset.test(id)) {
dist_t dist = calcDistance(query_data, id);
if (dist < radius) {
result.emplace_back(dist, id);
Expand All @@ -1369,16 +1376,22 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
query_data = query_data_norm.get();
}

// do bruteforce range search when ef is super large
size_t ef = param ? param->ef_ : this->ef_;
if (ef >= (cur_element_count * kHnswSearchBFTopkThreshold)) {
return searchRangeBF(query_data, radius, bitset);
}

// do bruteforce range search when delete rate high
if (!bitset.empty()) {
if (bitset.count() >= (cur_element_count * kHnswSearchRangeBFThreshold)) {
const size_t filtered_out_num = bitset.count();
if (filtered_out_num >= (cur_element_count * kHnswSearchRangeBFFilterThreshold) || ef >= (cur_element_count - filtered_out_num) * kHnswSearchBFTopkThreshold) {
return searchRangeBF(query_data, radius, bitset);
}
}

auto [currObj, vec_hash] = searchTopLayers(query_data, param, feder_result);
NeighborSet retset;
size_t ef = param ? param->ef_ : this->ef_;
auto visited = visited_list_pool_->getFreeVisitedList();
if (!bitset.empty()) {
retset = searchBaseLayerST<true, true>(currObj, query_data, ef, visited, bitset, feder_result);
Expand Down

0 comments on commit 1251461

Please sign in to comment.