Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add "concurrent add test" for hnsw/hgraph #243

Merged
merged 1 commit into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/vsag/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class Index {
throw std::runtime_error("Index doesn't support get distance by id");
};

virtual tl::expected<bool, Error>
[[nodiscard]] virtual bool
CheckFeature(IndexFeature feature) const {
throw std::runtime_error("Index doesn't support check feature");
}
Expand Down
9 changes: 2 additions & 7 deletions src/algorithm/hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -863,14 +863,9 @@ HGraph::init_features() {
}
}

tl::expected<bool, Error>
bool
HGraph::CheckFeature(IndexFeature feature) const {
try {
return this->feature_list_.CheckFeature(feature);
} catch (const std::invalid_argument& e) {
LOG_ERROR_AND_RETURNS(
ErrorType::INVALID_ARGUMENT, "[HGraph] failed to CheckFeature: ", e.what());
}
return this->feature_list_.CheckFeature(feature);
}

} // namespace vsag
2 changes: 1 addition & 1 deletion src/algorithm/hgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class HGraph {
tl::expected<float, Error>
CalculateDistanceById(const float* vector, int64_t id) const;

tl::expected<bool, Error>
bool
CheckFeature(IndexFeature feature) const;

inline void
Expand Down
4 changes: 2 additions & 2 deletions src/index/hgraph_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ class HGraphIndex : public Index {
return this->hgraph_->GetMemoryUsage();
}

tl::expected<bool, Error>
bool
CheckFeature(IndexFeature feature) const override {
SAFE_CALL(return this->hgraph_->CheckFeature(feature));
return this->hgraph_->CheckFeature(feature);
}

private:
Expand Down
2 changes: 1 addition & 1 deletion src/index_feature_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ get_pos(const uint32_t val) {
IndexFeatureList::IndexFeatureList()
: feature_count_(static_cast<uint32_t>(IndexFeature::INDEX_FEATURE_COUNT)) {
uint32_t size = (feature_count_ + 7) / 8;
data_.resize(size);
data_.resize(size, 0);
}

bool
Expand Down
36 changes: 36 additions & 0 deletions tests/test_hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,42 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Add", "[ft][hgra
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Concurrent Add", "[ft][hgraph]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
std::vector<std::pair<std::string, float>> test_cases = {
{"sq8", 0.97}, {"fp32", 0.99}, {"sq8_uniform", 0.95}};
const std::string name = "hgraph";
auto search_param = fmt::format(search_param_tmp, 200);
for (auto& dim : dims) {
for (auto& [base_quantization_str, recall] : test_cases) {
vsag::Options::Instance().set_block_size_limit(size);
auto param =
GenerateHGraphBuildParametersString(metric_type, dim, base_quantization_str);
auto index = TestFactory(name, param, true);
if (index->CheckFeature(vsag::SUPPORT_ADD_CONCURRENT)) {
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
TestConcurrentAdd(index, dataset, true);
if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH)) {
TestKnnSearch(index, dataset, search_param, recall, true);
if (index->CheckFeature(vsag::SUPPORT_SEARCH_CONCURRENT)) {
TestConcurrentKnnSearch(index, dataset, search_param, recall, true);
}
}
if (index->CheckFeature(vsag::SUPPORT_RANGE_SEARCH)) {
TestRangeSearch(index, dataset, search_param, recall, 10, true);
TestRangeSearch(index, dataset, search_param, recall / 2.0, 5, true);
}
if (index->CheckFeature(vsag::SUPPORT_KNN_SEARCH_WITH_ID_FILTER)) {
TestFilterSearch(index, dataset, search_param, recall, true);
}
}
vsag::Options::Instance().set_block_size_limit(origin_size);
}
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HgraphTestIndex, "HGraph Serialize File", "[ft][hgraph]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
Expand Down
23 changes: 23 additions & 0 deletions tests/test_hnsw_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,29 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Add", "[ft][hnsw]")
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Concurrent Add", "[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
const std::string name = "hnsw";
auto search_param = fmt::format(search_param_tmp, 100);
for (auto& dim : dims) {
vsag::Options::Instance().set_block_size_limit(size);
auto param = GenerateHNSWBuildParametersString(metric_type, dim);
auto index = TestFactory(name, param, true);

auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
TestConcurrentAdd(index, dataset, true);
TestKnnSearch(index, dataset, search_param, 0.99, true);
TestConcurrentKnnSearch(index, dataset, search_param, 0.99, true);
TestRangeSearch(index, dataset, search_param, 0.99, 10, true);
TestRangeSearch(index, dataset, search_param, 0.49, 5, true);
TestFilterSearch(index, dataset, search_param, 0.99, true);

vsag::Options::Instance().set_block_size_limit(origin_size);
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Serialize File", "[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
Expand Down
58 changes: 50 additions & 8 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void
TestIndex::TestKnnSearch(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
float recall,
float expected_recall,
bool expected_success) {
auto queries = dataset->query_;
auto query_count = queries->GetNumElements();
Expand All @@ -123,14 +123,14 @@ TestIndex::TestKnnSearch(const IndexPtr& index,
auto val = Intersection(gt, gt_topK, result, topk);
cur_recall += static_cast<float>(val) / static_cast<float>(gt_topK);
}
REQUIRE(cur_recall > recall * query_count);
REQUIRE(cur_recall > expected_recall * query_count);
}

void
TestIndex::TestRangeSearch(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
float recall,
float expected_recall,
int64_t limited_size,
bool expected_success) {
auto queries = dataset->range_query_;
Expand Down Expand Up @@ -159,13 +159,13 @@ TestIndex::TestRangeSearch(const IndexPtr& index,
auto val = Intersection(gt, gt_topK, result, res.value()->GetDim());
cur_recall += static_cast<float>(val) / static_cast<float>(gt_topK);
}
REQUIRE(cur_recall > recall * query_count);
REQUIRE(cur_recall > expected_recall * query_count);
}
void
TestIndex::TestFilterSearch(const TestIndex::IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
float recall,
float expected_recall,
bool expected_success) {
auto queries = dataset->filter_query_;
auto query_count = queries->GetNumElements();
Expand All @@ -191,7 +191,7 @@ TestIndex::TestFilterSearch(const TestIndex::IndexPtr& index,
auto val = Intersection(gt, gt_topK, result, topk);
cur_recall += static_cast<float>(val) / static_cast<float>(gt_topK);
}
REQUIRE(cur_recall > recall * query_count);
REQUIRE(cur_recall > expected_recall * query_count);
}

void
Expand Down Expand Up @@ -254,11 +254,53 @@ TestIndex::TestSerializeFile(const IndexPtr& index_from,
}
}
}

void
TestIndex::TestConcurrentAdd(const TestIndex::IndexPtr& index,
const TestDatasetPtr& dataset,
bool expected_success) {
auto base_count = dataset->base_->GetNumElements();
int64_t temp_count = base_count / 2;
auto dim = dataset->base_->GetDim();
auto temp_dataset = vsag::Dataset::Make();
temp_dataset->Dim(dim)
->Ids(dataset->base_->GetIds())
->NumElements(temp_count)
->Float32Vectors(dataset->base_->GetFloat32Vectors())
->Owner(false);
index->Build(temp_dataset);
auto rest_count = base_count - temp_count;
fixtures::ThreadPool pool(5);
using RetType = tl::expected<std::vector<int64_t>, vsag::Error>;
std::vector<std::future<RetType>> futures;

auto func = [&](uint64_t i) -> RetType {
auto data_one = vsag::Dataset::Make();
data_one->Dim(dim)
->Ids(dataset->base_->GetIds() + i)
->NumElements(1)
->Float32Vectors(dataset->base_->GetFloat32Vectors() + i * dim)
->Owner(false);
auto add_index = index->Add(data_one);
return add_index;
};

for (uint64_t j = rest_count; j < base_count; ++j) {
futures.emplace_back(pool.enqueue(func, j));
jiaweizone marked this conversation as resolved.
Show resolved Hide resolved
}

for (auto& res : futures) {
auto val = res.get();
REQUIRE(val.has_value() == expected_success);
}
REQUIRE(index->GetNumElements() == base_count);
}

void
TestIndex::TestConcurrentKnnSearch(const TestIndex::IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
float recall,
float expected_recall,
bool expected_success) {
auto queries = dataset->query_;
auto query_count = queries->GetNumElements();
Expand Down Expand Up @@ -299,7 +341,7 @@ TestIndex::TestConcurrentKnnSearch(const TestIndex::IndexPtr& index,
}

auto cur_recall = std::accumulate(search_results.begin(), search_results.end(), 0.0f);
REQUIRE(cur_recall > recall * query_count);
REQUIRE(cur_recall > expected_recall * query_count);
}

} // namespace fixtures
13 changes: 9 additions & 4 deletions tests/test_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,22 +73,22 @@ class TestIndex {
TestKnnSearch(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
float recall = 0.99,
float expected_recall = 0.99,
bool expected_success = true);

static void
TestRangeSearch(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
float recall = 0.99,
float expected_recall = 0.99,
int64_t limited_size = -1,
bool expected_success = true);

static void
TestFilterSearch(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
float recall = 0.99,
float expected_recall = 0.99,
bool expected_success = true);

static void
Expand All @@ -111,8 +111,13 @@ class TestIndex {
TestConcurrentKnnSearch(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
float recall = 0.99,
float expected_recall = 0.99,
bool expected_success = true);

static void
TestConcurrentAdd(const IndexPtr& index,
const TestDatasetPtr& dataset,
bool expected_success = true);
};

} // namespace fixtures
Loading