Skip to content

Commit

Permalink
Improve DiskANN Hybrid Search via Alpha-Strategy (#143)
Browse files Browse the repository at this point in the history
Signed-off-by: Patrick Weizhi Xu <[email protected]>
  • Loading branch information
PwzXxm authored Oct 12, 2023
1 parent 915ba68 commit 6ed164e
Showing 1 changed file with 60 additions and 52 deletions.
112 changes: 60 additions & 52 deletions thirdparty/DiskANN/src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,10 @@
((((_u64) (id)) % nvecs_per_sector) * data_dim * sizeof(float))

namespace {
constexpr size_t kReadBatchSize = 32;
constexpr _u64 kRefineBeamWidthFactor = 2;
constexpr _u64 kBruteForceTopkRefineExpansionFactor = 2;
auto calcFilterThreshold = [](const auto topk) -> const float {
return std::max(-0.04570166137874405f * log2(topk + 58.96422392240403) +
1.1982775974217197,
0.5);
};
constexpr float kFilterThreshold = 0.93f;
constexpr float kAlpha = 0.15f;
} // namespace

namespace diskann {
Expand Down Expand Up @@ -1016,7 +1012,7 @@ namespace diskann {

if (!bitset_view.empty()) {
const auto filter_threshold =
filter_ratio_in < 0 ? calcFilterThreshold(k_search) : filter_ratio_in;
filter_ratio_in < 0 ? kFilterThreshold : filter_ratio_in;
const auto bv_cnt = bitset_view.count();
if (bitset_view.size() == bv_cnt) {
for (_u64 i = 0; i < k_search; i++) {
Expand Down Expand Up @@ -1116,6 +1112,31 @@ namespace diskann {
unsigned num_ios = 0;
unsigned k = 0;

float accumulative_alpha = 0;
std::vector<unsigned> filtered_nbrs;
filtered_nbrs.reserve(this->max_degree);
auto filter_nbrs = [&](_u64 nnbrs,
unsigned *node_nbrs) -> std::pair<_u64, unsigned *> {
filtered_nbrs.clear();
for (_u64 m = 0; m < nnbrs; ++m) {
unsigned id = node_nbrs[m];
if (visited.find(id) != visited.end()) {
continue;
}
visited.insert(id);
if (!bitset_view.empty() && bitset_view.test(id)) {
accumulative_alpha += kAlpha;
if (accumulative_alpha < 1.0f) {
continue;
}
accumulative_alpha -= 1.0f;
}
cmps++;
filtered_nbrs.push_back(id);
}
return {filtered_nbrs.size(), filtered_nbrs.data()};
};

while (k < cur_list_size) {
auto nk = cur_list_size;
// clear iteration state
Expand Down Expand Up @@ -1219,8 +1240,8 @@ namespace diskann {
feder->id_set_.insert(cached_nhood.first);
}
}
_u64 nnbrs = cached_nhood.second.first;
unsigned *node_nbrs = cached_nhood.second.second;
auto [nnbrs, node_nbrs] =
filter_nbrs(cached_nhood.second.first, cached_nhood.second.second);

// compute node_nbrs <-> query dists in PQ space
cpu_timer.reset();
Expand All @@ -1241,26 +1262,20 @@ namespace diskann {
feder->id_set_.insert(id);
}

if (visited.find(id) != visited.end()) {
float dist = dist_scratch[m];
if (cur_list_size > 0 &&
dist >= retset[cur_list_size - 1].distance &&
(cur_list_size == l_search))
continue;
} else {
visited.insert(id);
cmps++;
float dist = dist_scratch[m];
if (cur_list_size > 0 &&
dist >= retset[cur_list_size - 1].distance &&
(cur_list_size == l_search))
continue;
Neighbor nn(id, dist, true);
// Return position in sorted list where nn inserted.
auto r = InsertIntoPool(retset.data(), cur_list_size, nn);
if (cur_list_size < l_search)
++cur_list_size;
if (r < nk)
// nk logs the best position in the retset that was
// updated due to neighbors of n.
nk = r;
}
Neighbor nn(id, dist, true);
// Return position in sorted list where nn inserted.
auto r = InsertIntoPool(retset.data(), cur_list_size, nn);
if (cur_list_size < l_search)
++cur_list_size;
if (r < nk)
// nk logs the best position in the retset that was
// updated due to neighbors of n.
nk = r;
}
}
#ifdef USE_BING_INFRA
Expand All @@ -1282,7 +1297,6 @@ namespace diskann {
char *node_disk_buf =
get_offset_to_node(frontier_nhood.second, frontier_nhood.first);
unsigned *node_buf = OFFSET_TO_NODE_NHOOD(node_disk_buf);
_u64 nnbrs = (_u64) (*node_buf);
T *node_fp_coords = OFFSET_TO_NODE_COORDS(node_disk_buf);

T *node_fp_coords_copy = data_buf;
Expand Down Expand Up @@ -1312,7 +1326,7 @@ namespace diskann {
feder->id_set_.insert(frontier_nhood.first);
}
}
unsigned *node_nbrs = (node_buf + 1);
auto [nnbrs, node_nbrs] = filter_nbrs(*node_buf, (node_buf + 1));
// compute node_nbrs <-> query dist in PQ space
cpu_timer.reset();
compute_dists(node_nbrs, nnbrs, dist_scratch);
Expand All @@ -1333,29 +1347,23 @@ namespace diskann {
feder->id_set_.insert(frontier_nhood.first);
}

if (visited.find(id) != visited.end()) {
continue;
} else {
visited.insert(id);
cmps++;
float dist = dist_scratch[m];
if (stats != nullptr) {
stats->n_cmps++;
}
if (cur_list_size > 0 &&
dist >= retset[cur_list_size - 1].distance &&
(cur_list_size == l_search))
continue;
Neighbor nn(id, dist, true);
auto r = InsertIntoPool(
retset.data(), cur_list_size,
nn); // Return position in sorted list where nn inserted.
if (cur_list_size < l_search)
++cur_list_size;
if (r < nk)
nk = r; // nk logs the best position in the retset that was
// updated due to neighbors of n.
float dist = dist_scratch[m];
if (stats != nullptr) {
stats->n_cmps++;
}
if (cur_list_size > 0 &&
dist >= retset[cur_list_size - 1].distance &&
(cur_list_size == l_search))
continue;
Neighbor nn(id, dist, true);
auto r = InsertIntoPool(
retset.data(), cur_list_size,
nn); // Return position in sorted list where nn inserted.
if (cur_list_size < l_search)
++cur_list_size;
if (r < nk)
nk = r; // nk logs the best position in the retset that was
// updated due to neighbors of n.
}

if (stats != nullptr) {
Expand Down

0 comments on commit 6ed164e

Please sign in to comment.