From 0b2ecbeb632f40e34d61f2ef6935f0811c44e1d1 Mon Sep 17 00:00:00 2001 From: LHT129 Date: Mon, 23 Dec 2024 08:33:26 +0000 Subject: [PATCH] add "concurrent add test" for hnsw/hgraph - rename recall to expected_recall - hgraph not enable SUPPORT_ADD_CONCURRENT feature, so the concurrent add test is skip actually Signed-off-by: LHT129 --- include/vsag/index.h | 2 +- src/algorithm/hgraph.cpp | 9 ++---- src/algorithm/hgraph.h | 2 +- src/index/hgraph_index.h | 4 +-- src/index_feature_list.cpp | 2 +- tests/test_hgraph.cpp | 36 +++++++++++++++++++++++ tests/test_hnsw_new.cpp | 23 +++++++++++++++ tests/test_index.cpp | 58 ++++++++++++++++++++++++++++++++------ tests/test_index.h | 13 ++++++--- 9 files changed, 125 insertions(+), 24 deletions(-) diff --git a/include/vsag/index.h b/include/vsag/index.h index 762ccc8a..92e087b4 100644 --- a/include/vsag/index.h +++ b/include/vsag/index.h @@ -224,7 +224,7 @@ class Index { throw std::runtime_error("Index doesn't support get distance by id"); }; - virtual tl::expected + [[nodiscard]] virtual bool CheckFeature(IndexFeature feature) const { throw std::runtime_error("Index doesn't support check feature"); } diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index 6c728dc9..356c1b05 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -863,14 +863,9 @@ HGraph::init_features() { } } -tl::expected +bool HGraph::CheckFeature(IndexFeature feature) const { - try { - return this->feature_list_.CheckFeature(feature); - } catch (const std::invalid_argument& e) { - LOG_ERROR_AND_RETURNS( - ErrorType::INVALID_ARGUMENT, "[HGraph] failed to CheckFeature: ", e.what()); - } + return this->feature_list_.CheckFeature(feature); } } // namespace vsag diff --git a/src/algorithm/hgraph.h b/src/algorithm/hgraph.h index 4a9c112c..aa900d2f 100644 --- a/src/algorithm/hgraph.h +++ b/src/algorithm/hgraph.h @@ -106,7 +106,7 @@ class HGraph { tl::expected CalculateDistanceById(const float* vector, int64_t id) const; - tl::expected + bool CheckFeature(IndexFeature feature) const; inline void diff --git a/src/index/hgraph_index.h b/src/index/hgraph_index.h index 210bb1d1..5766b3b8 100644 --- a/src/index/hgraph_index.h +++ b/src/index/hgraph_index.h @@ -136,9 +136,9 @@ class HGraphIndex : public Index { return this->hgraph_->GetMemoryUsage(); } - tl::expected + bool CheckFeature(IndexFeature feature) const override { - SAFE_CALL(return this->hgraph_->CheckFeature(feature)); + return this->hgraph_->CheckFeature(feature); } private: diff --git a/src/index_feature_list.cpp b/src/index_feature_list.cpp index b40366c6..8563cf50 100644 --- a/src/index_feature_list.cpp +++ b/src/index_feature_list.cpp @@ -29,7 +29,7 @@ get_pos(const uint32_t val) { IndexFeatureList::IndexFeatureList() : feature_count_(static_cast(IndexFeature::INDEX_FEATURE_COUNT)) { uint32_t size = (feature_count_ + 7) / 8; - data_.resize(size); + data_.resize(size, 0); } bool diff --git a/tests/test_hgraph.cpp b/tests/test_hgraph.cpp index 9c614065..b1174732 100644 --- a/tests/test_hgraph.cpp +++ b/tests/test_hgraph.cpp @@ -288,6 +288,42 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Add", "[ft][hgra } } +TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Concurrent Add", "[ft][hgraph]") { + auto origin_size = vsag::Options::Instance().block_size_limit(); + auto size = GENERATE(1024 * 1024 * 2); + auto metric_type = GENERATE("l2", "ip", "cosine"); + std::vector> test_cases = { + {"sq8", 0.97}, {"fp32", 0.99}, {"sq8_uniform", 0.95}}; + const std::string name = "hgraph"; + auto search_param = fmt::format(search_param_tmp, 200); + for (auto& dim : dims) { + for (auto& [base_quantization_str, recall] : test_cases) { + vsag::Options::Instance().set_block_size_limit(size); + auto param = + GenerateHGraphBuildParametersString(metric_type, dim, base_quantization_str); + auto index = TestFactory(name, param, true); + if (index->CheckFeature(vsag::SUPPORT_ADD_CONCURRENT)) { + auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); + TestConcurrentAdd(index, dataset, true); + if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH)) { + TestKnnSearch(index, dataset, search_param, recall, true); + if (index->CheckFeature(vsag::SUPPORT_SEARCH_CONCURRENT)) { + TestConcurrentKnnSearch(index, dataset, search_param, recall, true); + } + } + if (index->CheckFeature(vsag::SUPPORT_RANGE_SEARCH)) { + TestRangeSearch(index, dataset, search_param, recall, 10, true); + TestRangeSearch(index, dataset, search_param, recall / 2.0, 5, true); + } + if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) { + TestFilterSearch(index, dataset, search_param, recall, true); + } + } + vsag::Options::Instance().set_block_size_limit(origin_size); + } + } +} + TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Serialize File", "[ft][hgraph]") { auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); diff --git a/tests/test_hnsw_new.cpp b/tests/test_hnsw_new.cpp index e15aa2ca..e72e7f9d 100644 --- a/tests/test_hnsw_new.cpp +++ b/tests/test_hnsw_new.cpp @@ -264,6 +264,29 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Add", "[ft][hnsw]") } } +TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Concurrent Add", "[ft][hnsw]") { + auto origin_size = vsag::Options::Instance().block_size_limit(); + auto size = GENERATE(1024 * 1024 * 2); + auto metric_type = GENERATE("l2", "ip", "cosine"); + const std::string name = "hnsw"; + auto search_param = fmt::format(search_param_tmp, 100); + for (auto& dim : dims) { + vsag::Options::Instance().set_block_size_limit(size); + auto param = GenerateHNSWBuildParametersString(metric_type, dim); + auto index = TestFactory(name, param, true); + + auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type); + TestConcurrentAdd(index, dataset, true); + TestKnnSearch(index, dataset, search_param, 0.99, true); + TestConcurrentKnnSearch(index, dataset, search_param, 0.99, true); + TestRangeSearch(index, dataset, search_param, 0.99, 10, true); + TestRangeSearch(index, dataset, search_param, 0.49, 5, true); + TestFilterSearch(index, dataset, search_param, 0.99, true); + + vsag::Options::Instance().set_block_size_limit(origin_size); + } +} + TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Serialize File", "[ft][hnsw]") { auto origin_size = vsag::Options::Instance().block_size_limit(); auto size = GENERATE(1024 * 1024 * 2); diff --git a/tests/test_index.cpp b/tests/test_index.cpp index 96f7d60e..1f1ec37e 100644 --- a/tests/test_index.cpp +++ b/tests/test_index.cpp @@ -97,7 +97,7 @@ void TestIndex::TestKnnSearch(const IndexPtr& index, const TestDatasetPtr& dataset, const std::string& search_param, - float recall, + float expected_recall, bool expected_success) { auto queries = dataset->query_; auto query_count = queries->GetNumElements(); @@ -123,14 +123,14 @@ TestIndex::TestKnnSearch(const IndexPtr& index, auto val = Intersection(gt, gt_topK, result, topk); cur_recall += static_cast(val) / static_cast(gt_topK); } - REQUIRE(cur_recall > recall * query_count); + REQUIRE(cur_recall > expected_recall * query_count); } void TestIndex::TestRangeSearch(const IndexPtr& index, const TestDatasetPtr& dataset, const std::string& search_param, - float recall, + float expected_recall, int64_t limited_size, bool expected_success) { auto queries = dataset->range_query_; @@ -159,13 +159,13 @@ TestIndex::TestRangeSearch(const IndexPtr& index, auto val = Intersection(gt, gt_topK, result, res.value()->GetDim()); cur_recall += static_cast(val) / static_cast(gt_topK); } - REQUIRE(cur_recall > recall * query_count); + REQUIRE(cur_recall > expected_recall * query_count); } void TestIndex::TestFilterSearch(const TestIndex::IndexPtr& index, const TestDatasetPtr& dataset, const std::string& search_param, - float recall, + float expected_recall, bool expected_success) { auto queries = dataset->filter_query_; auto query_count = queries->GetNumElements(); @@ -191,7 +191,7 @@ TestIndex::TestFilterSearch(const TestIndex::IndexPtr& index, auto val = Intersection(gt, gt_topK, result, topk); cur_recall += static_cast(val) / static_cast(gt_topK); } - REQUIRE(cur_recall > recall * query_count); + REQUIRE(cur_recall > expected_recall * query_count); } void @@ -254,11 +254,53 @@ TestIndex::TestSerializeFile(const IndexPtr& index_from, } } } + +void +TestIndex::TestConcurrentAdd(const TestIndex::IndexPtr& index, + const TestDatasetPtr& dataset, + bool expected_success) { + auto base_count = dataset->base_->GetNumElements(); + int64_t temp_count = base_count / 2; + auto dim = dataset->base_->GetDim(); + auto temp_dataset = vsag::Dataset::Make(); + temp_dataset->Dim(dim) + ->Ids(dataset->base_->GetIds()) + ->NumElements(temp_count) + ->Float32Vectors(dataset->base_->GetFloat32Vectors()) + ->Owner(false); + index->Build(temp_dataset); + auto rest_count = base_count - temp_count; + fixtures::ThreadPool pool(5); + using RetType = tl::expected, vsag::Error>; + std::vector> futures; + + auto func = [&](uint64_t i) -> RetType { + auto data_one = vsag::Dataset::Make(); + data_one->Dim(dim) + ->Ids(dataset->base_->GetIds() + i) + ->NumElements(1) + ->Float32Vectors(dataset->base_->GetFloat32Vectors() + i * dim) + ->Owner(false); + auto add_index = index->Add(data_one); + return add_index; + }; + + for (uint64_t j = rest_count; j < base_count; ++j) { + futures.emplace_back(pool.enqueue(func, j)); + } + + for (auto& res : futures) { + auto val = res.get(); + REQUIRE(val.has_value() == expected_success); + } + REQUIRE(index->GetNumElements() == base_count); +} + void TestIndex::TestConcurrentKnnSearch(const TestIndex::IndexPtr& index, const TestDatasetPtr& dataset, const std::string& search_param, - float recall, + float expected_recall, bool expected_success) { auto queries = dataset->query_; auto query_count = queries->GetNumElements(); @@ -299,7 +341,7 @@ TestIndex::TestConcurrentKnnSearch(const TestIndex::IndexPtr& index, } auto cur_recall = std::accumulate(search_results.begin(), search_results.end(), 0.0f); - REQUIRE(cur_recall > recall * query_count); + REQUIRE(cur_recall > expected_recall * query_count); } } // namespace fixtures \ No newline at end of file diff --git a/tests/test_index.h b/tests/test_index.h index 7f19c9cc..d8e3e85d 100644 --- a/tests/test_index.h +++ b/tests/test_index.h @@ -73,14 +73,14 @@ class TestIndex { TestKnnSearch(const IndexPtr& index, const TestDatasetPtr& dataset, const std::string& search_param, - float recall = 0.99, + float expected_recall = 0.99, bool expected_success = true); static void TestRangeSearch(const IndexPtr& index, const TestDatasetPtr& dataset, const std::string& search_param, - float recall = 0.99, + float expected_recall = 0.99, int64_t limited_size = -1, bool expected_success = true); @@ -88,7 +88,7 @@ class TestIndex { TestFilterSearch(const IndexPtr& index, const TestDatasetPtr& dataset, const std::string& search_param, - float recall = 0.99, + float expected_recall = 0.99, bool expected_success = true); static void @@ -111,8 +111,13 @@ class TestIndex { TestConcurrentKnnSearch(const IndexPtr& index, const TestDatasetPtr& dataset, const std::string& search_param, - float recall = 0.99, + float expected_recall = 0.99, bool expected_success = true); + + static void + TestConcurrentAdd(const IndexPtr& index, + const TestDatasetPtr& dataset, + bool expected_success = true); }; } // namespace fixtures