Skip to content

Commit

Permalink
support update in conjugate graph (#317)
Browse files Browse the repository at this point in the history
Signed-off-by: zhongxiaoyao.zxy <[email protected]>
  • Loading branch information
ShawnShawnYou authored Jan 17, 2025
1 parent 0e7044d commit 34ecc8a
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 44 deletions.
34 changes: 34 additions & 0 deletions src/impl/conjugate_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,4 +217,38 @@ ConjugateGraph::is_empty() const {
return (this->memory_usage_ == sizeof(this->memory_usage_) + FOOTER_SIZE);
}

tl::expected<bool, Error>
ConjugateGraph::UpdateId(int64_t old_tag_id, int64_t new_tag_id) {
if (old_tag_id == new_tag_id) {
return true;
}

// 1. update key
bool updated = false;
auto it_old_key = conjugate_graph_.find(old_tag_id);
if (it_old_key != conjugate_graph_.end()) {
auto it_new_key = conjugate_graph_.find(new_tag_id);
if (it_new_key != conjugate_graph_.end()) {
// both two id exists in graph, note that this situation should be filtered out before use this function.
return false;
} else {
conjugate_graph_[new_tag_id] = std::move(it_old_key->second);
}
conjugate_graph_.erase(it_old_key);
updated = true;
}

// 2. update neighbors
for (auto& [key, neighbors] : conjugate_graph_) {
auto it_old_neighbor = neighbors.find(old_tag_id);
if (it_old_neighbor != neighbors.end()) {
neighbors.erase(it_old_neighbor);
neighbors.insert(new_tag_id);
updated = true;
}
}

return updated;
}

} // namespace vsag
3 changes: 3 additions & 0 deletions src/impl/conjugate_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class ConjugateGraph {
EnhanceResult(std::priority_queue<std::pair<float, LabelType>>& results,
const std::function<float(int64_t)>& distance_of_tag) const;

tl::expected<bool, Error>
UpdateId(int64_t old_id, int64_t new_id);

public:
tl::expected<Binary, Error>
Serialize() const;
Expand Down
30 changes: 29 additions & 1 deletion src/impl/conjugate_graph_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,4 +279,32 @@ TEST_CASE("serialize and deserialize with stream", "[ut][conjugate_graph]") {
REQUIRE(conjugate_graph->GetMemoryUsage() == 4 + vsag::FOOTER_SIZE);
re_in_stream.close();
}
}
}

TEST_CASE("update id", "[ut][conjugate_graph]") {
std::shared_ptr<vsag::ConjugateGraph> conjugate_graph =
std::make_shared<vsag::ConjugateGraph>();

REQUIRE(conjugate_graph->AddNeighbor(0, 1) == true);
REQUIRE(conjugate_graph->AddNeighbor(0, 2) == true);
REQUIRE(conjugate_graph->AddNeighbor(1, 0) == true);
REQUIRE(conjugate_graph->AddNeighbor(4, 0) == true);

// update key
REQUIRE(conjugate_graph->UpdateId(1, 1) == true); // succ case: 1 -> 1
REQUIRE(conjugate_graph->UpdateId(5, 4) == false); // old id don't exist
REQUIRE(conjugate_graph->UpdateId(0, 4) == false); // old id and new id both exists
REQUIRE(conjugate_graph->UpdateId(4, 5) == true); // succ case: 4 -> 5
REQUIRE(conjugate_graph->AddNeighbor(5, 0) == false); // valid of succ case

// update value
REQUIRE(conjugate_graph->UpdateId(2, 3) == true); // succ case: 2 -> 3
REQUIRE(conjugate_graph->AddNeighbor(0, 3) == false); // neighbor exists

// update both key and value
REQUIRE(conjugate_graph->UpdateId(0, -1) == true); // succ case: 0 -> -1
REQUIRE(conjugate_graph->AddNeighbor(-1, 1) == false);
REQUIRE(conjugate_graph->AddNeighbor(-1, 3) == false);
REQUIRE(conjugate_graph->AddNeighbor(1, -1) == false);
REQUIRE(conjugate_graph->AddNeighbor(5, -1) == false);
}
8 changes: 8 additions & 0 deletions src/index/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,17 @@ HNSW::update_id(int64_t old_id, int64_t new_id) {
}

try {
if (old_id == new_id) {
return true;
}

// note that the validation of old_id is handled within updateLabel.
std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_)->updateLabel(old_id,
new_id);
if (use_conjugate_graph_) {
std::unique_lock lock(rw_mutex_);
conjugate_graph_->UpdateId(old_id, new_id);
}
} catch (const std::runtime_error& e) {
#ifndef ENABLE_TESTS
logger::warn(
Expand Down
18 changes: 14 additions & 4 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,26 @@ TestIndex::TestUpdateId(const IndexPtr& index,
REQUIRE(failed_old_res.has_value());
REQUIRE(not failed_old_res.value());

// new id is used
auto failed_new_res = index->UpdateId(update_id_map[ids[i]], update_id_map[ids[i]]);
REQUIRE(failed_new_res.has_value());
REQUIRE(not failed_new_res.value());
// same id
auto succ_same_res = index->UpdateId(update_id_map[ids[i]], update_id_map[ids[i]]);
REQUIRE(succ_same_res.has_value());
REQUIRE(succ_same_res.value());
} else {
if (result.value()->GetIds()[0] == update_id_map[ids[i]]) {
correct_num[round] += 1;
}
}
}

for (int i = 0; i < num_vectors; i++) {
if (round == 0) {
// new id is used
auto failed_new_res =
index->UpdateId(update_id_map[ids[i]], update_id_map[ids[num_vectors - i - 1]]);
REQUIRE(failed_new_res.has_value());
REQUIRE(not failed_new_res.value());
}
}
}

REQUIRE(correct_num[0] == correct_num[1]);
Expand Down
19 changes: 14 additions & 5 deletions tests/test_index_old.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,7 @@ TEST_CASE("build index with generated_build_parameters", "[ft][index]") {
REQUIRE(recall > 0.95);
}

TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") {
TEST_CASE("int8 + freshhnsw + feedback + update", "[ft][index][hnsw]") {
auto logger = vsag::Options::Instance().logger();
logger->SetLevel(vsag::Logger::Level::kDEBUG);

Expand All @@ -1068,7 +1068,7 @@ TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") {
"dim": {},
"hnsw": {{
"max_degree": 16,
"ef_construction": 200,
"ef_construction": 100,
"use_conjugate_graph": true
}}
}}
Expand All @@ -1082,8 +1082,10 @@ TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") {

// generate dataset
std::vector<int64_t> base_ids(num_base);
std::vector<int64_t> update_ids(num_base);
for (int64_t i = 0; i < num_base; ++i) {
base_ids[i] = i;
update_ids[i] = i + 2 * num_base;
}
auto base_vectors = fixtures::generate_int8_codes(num_base, dim);
auto base = vsag::Dataset::Make();
Expand Down Expand Up @@ -1121,7 +1123,7 @@ TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") {
constexpr auto search_parameters_json = R"(
{{
"hnsw": {{
"ef_search": 50,
"ef_search": 10,
"use_conjugate_graph_search": {}
}}
}}
Expand All @@ -1146,17 +1148,24 @@ TEST_CASE("int8 + freshhnsw + feedback", "[ft][index][hnsw]") {
REQUIRE(*index->Feedback(query, k, search_parameters) == 0);
}

if (local_optimum == global_optimum) {
if (local_optimum == global_optimum or local_optimum == update_ids[global_optimum]) {
correct++;
}
}

if (round == 0) {
for (int i = 0; i < num_base; i++) {
REQUIRE(*index->UpdateId(base_ids[i], update_ids[i]) == true);
}
}
recall[round] = correct / (1.0 * num_query);
logger->Debug(fmt::format(R"(Recall: {:.4f})", recall[round]));
}

logger->Debug("====summary====");
logger->Debug(fmt::format(R"(Error fix: {})", error_fix));

REQUIRE(recall[0] < recall[1]);
REQUIRE(fixtures::time_t(recall[1]) == fixtures::time_t(1.0f));
}

Expand Down Expand Up @@ -1230,7 +1239,7 @@ TEST_CASE("hnsw + feedback with global optimum id", "[ft][index][hnsw]") {
constexpr auto search_parameters_json = R"(
{{
"hnsw": {{
"ef_search": 100,
"ef_search": 10,
"use_conjugate_graph_search": {}
}}
}}
Expand Down
80 changes: 46 additions & 34 deletions tests/test_multi_thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,17 +245,10 @@ TEST_CASE("multi-threading read-write test", "[ft][hnsw]") {
}));

// update id
update_id_results.push_back(
pool.enqueue([&ids, &data, &index, dim, i, max_elements]() -> bool {
auto dataset = vsag::Dataset::Make();
dataset->Dim(dim)
->NumElements(1)
->Ids(ids.get() + i)
->Float32Vectors(data.get() + i * dim)
->Owner(false);
auto res = index->UpdateId(ids[i], ids[i] + 2 * max_elements);
return res.has_value();
}));
update_id_results.push_back(pool.enqueue([&ids, &index, i, max_elements]() -> bool {
auto res = index->UpdateId(ids[i], ids[i] + 2 * max_elements);
return res.has_value();
}));

// update vector
update_vec_results.push_back(pool.enqueue([&ids, &data, &index, dim, i]() -> bool {
Expand All @@ -270,13 +263,12 @@ TEST_CASE("multi-threading read-write test", "[ft][hnsw]") {
}));

// search
search_results.push_back(
pool.enqueue([&index, &ids, dim, &data, i, &str_parameters]() -> bool {
auto query = vsag::Dataset::Make();
query->NumElements(1)->Dim(dim)->Float32Vectors(data.get() + i * dim)->Owner(false);
auto result = index->KnnSearch(query, 2, str_parameters);
return result.has_value();
}));
search_results.push_back(pool.enqueue([&index, dim, &data, i, &str_parameters]() -> bool {
auto query = vsag::Dataset::Make();
query->NumElements(1)->Dim(dim)->Float32Vectors(data.get() + i * dim)->Owner(false);
auto result = index->KnnSearch(query, 2, str_parameters);
return result.has_value();
}));
}

for (int i = 0; i < max_elements; ++i) {
Expand All @@ -293,12 +285,12 @@ TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hn
// avoid too much slow task logs
vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kWARN);

int thread_num = 32;
int dim = 256;
int max_elements = 10000;
int max_degree = 32;
int ef_construction = 200;
int ef_search = 100;
int thread_num = 16;
int dim = 32;
int max_elements = 1000;
int max_degree = 16;
int ef_construction = 50;
int ef_search = 10;
int k = 10;
nlohmann::json hnsw_parameters{{"max_degree", max_degree},
{"ef_construction", ef_construction},
Expand All @@ -324,8 +316,9 @@ TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hn
std::string str_parameters = parameters.dump();

std::vector<std::future<int64_t>> insert_results;
std::vector<std::future<uint64_t>> feedback_results;
std::vector<std::future<uint32_t>> pretrain_results;
std::vector<std::future<bool>> feedback_results;
std::vector<std::future<bool>> pretrain_results;
std::vector<std::future<bool>> update_id_results;
std::vector<std::future<bool>> search_results;

for (int64_t i = 0; i < max_elements / 2; ++i) {
Expand Down Expand Up @@ -356,20 +349,30 @@ TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hn
return add_res.value().size();
}));

// update id
update_id_results.push_back(pool.enqueue([&ids, &index, i, max_elements]() -> bool {
auto res = index->UpdateId(ids[i], ids[i] + 2 * max_elements);
return res.has_value();
}));

// feedback
feedback_results.push_back(
pool.enqueue([&index, &data, i, dim, k, str_parameters]() -> uint64_t {
pool.enqueue([&index, &data, i, dim, k, str_parameters]() -> bool {
auto query = vsag::Dataset::Make();
query->Dim(dim)->NumElements(1)->Int8Vectors(data.get() + i * dim)->Owner(false);
auto feedback_res = index->Feedback(query, k, str_parameters);
return feedback_res.value();
return feedback_res.has_value();
}));

// pretrain
pretrain_results.push_back(pool.enqueue([&index, &ids, i, k, str_parameters]() -> uint32_t {
auto pretrain_res = index->Pretrain({ids[i]}, k, str_parameters);
return pretrain_res.value();
}));
pretrain_results.push_back(
pool.enqueue([&index, &ids, i, k, str_parameters, max_elements]() -> bool {
auto pretrain_res = index->Pretrain({ids[i]}, k, str_parameters);
if (not pretrain_res.has_value()) {
pretrain_res = index->Pretrain({ids[i] + 2 * max_elements}, k, str_parameters);
}
return pretrain_res.has_value();
}));

// search
search_results.push_back(pool.enqueue([&index, &data, i, dim, k, str_parameters]() -> bool {
Expand All @@ -380,12 +383,21 @@ TEST_CASE("multi-threading read-write with feedback and pretrain test", "[ft][hn
}));
}

uint32_t succ_feedback = 0, succ_pretrain = 0;
for (int64_t i = 0; i < max_elements; ++i) {
REQUIRE(insert_results[i].get() == 0);
if (i < max_elements / 2) {
REQUIRE(pretrain_results[i].get() >= 0);
REQUIRE(feedback_results[i].get() >= 0);
if (feedback_results[i].get()) {
succ_feedback++;
}
if (pretrain_results[i].get()) {
succ_pretrain++;
}
REQUIRE(update_id_results[i].get() == true);
REQUIRE(search_results[i].get() >= 0);
}
}

REQUIRE(succ_feedback > 0);
REQUIRE(succ_pretrain > 0);
}

0 comments on commit 34ecc8a

Please sign in to comment.