Skip to content

Commit

Permalink
Refactor cached_beam_search to reduce duplicate code
Browse files Browse the repository at this point in the history
Signed-off-by: Buqian Zheng <[email protected]>
  • Loading branch information
zhengbuqian committed Nov 20, 2023
1 parent 66b1656 commit 3686c4d
Showing 1 changed file with 27 additions and 81 deletions.
108 changes: 27 additions & 81 deletions thirdparty/DiskANN/src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1211,16 +1211,14 @@ namespace diskann {
}
}

// process cached nhoods
for (auto &cached_nhood : cached_nhoods) {
auto global_cache_iter = coord_cache.find(cached_nhood.first);
T *node_fp_coords_copy = global_cache_iter->second;
if (bitset_view.empty() || !bitset_view.test(cached_nhood.first)) {
auto process_node = [&](T *node_fp_coords_copy, auto node_id, auto n_nbr,
auto *nbrs) {
if (bitset_view.empty() || !bitset_view.test(node_id)) {
float cur_expanded_dist;
if (!use_disk_index_pq) {
cur_expanded_dist =
dist_cmp_wrap(query, node_fp_coords_copy, (size_t) aligned_dim,
cached_nhood.first);
node_id);
} else {
if (metric == diskann::Metric::INNER_PRODUCT ||
metric == diskann::Metric::COSINE)
Expand All @@ -1231,17 +1229,16 @@ namespace diskann {
query_float, (_u8 *) node_fp_coords_copy);
}
full_retset.push_back(
Neighbor((unsigned) cached_nhood.first, cur_expanded_dist, true));
Neighbor((unsigned) node_id, cur_expanded_dist, true));

// add top candidate info into feder result
if (feder != nullptr) {
feder->visit_info_.AddTopCandidateInfo(cached_nhood.first,
feder->visit_info_.AddTopCandidateInfo(node_id,
cur_expanded_dist);
feder->id_set_.insert(cached_nhood.first);
feder->id_set_.insert(node_id);
}
}
auto [nnbrs, node_nbrs] =
filter_nbrs(cached_nhood.second.first, cached_nhood.second.second);
auto [nnbrs, node_nbrs] = filter_nbrs(n_nbr, nbrs);

// compute node_nbrs <-> query dists in PQ space
cpu_timer.reset();
Expand All @@ -1251,18 +1248,22 @@ namespace diskann {
stats->cpu_us += (double) cpu_timer.elapsed();
}

cpu_timer.reset();
// process prefetched nhood
for (_u64 m = 0; m < nnbrs; ++m) {
unsigned id = node_nbrs[m];

// add neighbor info into feder result
if (feder != nullptr) {
feder->visit_info_.AddTopCandidateNeighbor(cached_nhood.first, id,
feder->visit_info_.AddTopCandidateNeighbor(node_id, id,
dist_scratch[m]);
feder->id_set_.insert(id);
}

float dist = dist_scratch[m];
if (stats != nullptr) {
stats->n_cmps++;
}
if (cur_list_size > 0 &&
dist >= retset[cur_list_size - 1].distance &&
(cur_list_size == l_search))
Expand All @@ -1277,7 +1278,19 @@ namespace diskann {
// updated due to neighbors of n.
nk = r;
}
if (stats != nullptr) {
stats->cpu_us += (double) cpu_timer.elapsed();
}
};

// process cached nhoods
for (auto &cached_nhood : cached_nhoods) {
auto global_cache_iter = coord_cache.find(cached_nhood.first);
T *node_fp_coords_copy = global_cache_iter->second;
process_node(node_fp_coords_copy, cached_nhood.first,
cached_nhood.second.first, cached_nhood.second.second);
}

#ifdef USE_BING_INFRA
// process each frontier nhood - compute distances to unvisited nodes
int completedIndex = -1;
Expand All @@ -1298,77 +1311,10 @@ namespace diskann {
get_offset_to_node(frontier_nhood.second, frontier_nhood.first);
unsigned *node_buf = OFFSET_TO_NODE_NHOOD(node_disk_buf);
T *node_fp_coords = OFFSET_TO_NODE_COORDS(node_disk_buf);

T *node_fp_coords_copy = data_buf;
memcpy(node_fp_coords_copy, node_fp_coords, disk_bytes_per_point);
if (bitset_view.empty() || !bitset_view.test(frontier_nhood.first)) {
float cur_expanded_dist;
if (!use_disk_index_pq) {
cur_expanded_dist =
dist_cmp_wrap(query, node_fp_coords_copy, (size_t) aligned_dim,
frontier_nhood.first);
} else {
if (metric == diskann::Metric::INNER_PRODUCT ||
metric == diskann::Metric::COSINE)
cur_expanded_dist = disk_pq_table.inner_product(
query_float, (_u8 *) node_fp_coords_copy);
else
cur_expanded_dist = disk_pq_table.l2_distance(
query_float, (_u8 *) node_fp_coords_copy);
}
full_retset.push_back(
Neighbor(frontier_nhood.first, cur_expanded_dist, true));

// add top candidate info into feder result
if (feder != nullptr) {
feder->visit_info_.AddTopCandidateInfo(frontier_nhood.first,
cur_expanded_dist);
feder->id_set_.insert(frontier_nhood.first);
}
}
auto [nnbrs, node_nbrs] = filter_nbrs(*node_buf, (node_buf + 1));
// compute node_nbrs <-> query dist in PQ space
cpu_timer.reset();
compute_dists(node_nbrs, nnbrs, dist_scratch);
if (stats != nullptr) {
stats->n_cmps += (double) nnbrs;
stats->cpu_us += (double) cpu_timer.elapsed();
}

cpu_timer.reset();
// process prefetch-ed nhood
for (_u64 m = 0; m < nnbrs; ++m) {
unsigned id = node_nbrs[m];

// add neighbor info into feder result
if (feder != nullptr) {
feder->visit_info_.AddTopCandidateNeighbor(frontier_nhood.first, id,
dist_scratch[m]);
feder->id_set_.insert(frontier_nhood.first);
}

float dist = dist_scratch[m];
if (stats != nullptr) {
stats->n_cmps++;
}
if (cur_list_size > 0 &&
dist >= retset[cur_list_size - 1].distance &&
(cur_list_size == l_search))
continue;
Neighbor nn(id, dist, true);
auto r = InsertIntoPool(
retset.data(), cur_list_size,
nn); // Return position in sorted list where nn inserted.
if (cur_list_size < l_search)
++cur_list_size;
if (r < nk)
nk = r; // nk logs the best position in the retset that was
// updated due to neighbors of n.
}

if (stats != nullptr) {
stats->cpu_us += (double) cpu_timer.elapsed();
}
process_node(node_fp_coords_copy, frontier_nhood.first, *node_buf,
node_buf + 1);
}

// update best inserted position
Expand Down

0 comments on commit 3686c4d

Please sign in to comment.