Skip to content

Commit

Permalink
Use MaxMinHeap to track scores to avoid negation; Change Cursor.lut_ …
Browse files Browse the repository at this point in the history
…to const ref; update multiple InvertedIndex member functions to accept reference type for container types; addressed other review comments

Signed-off-by: Buqian Zheng <[email protected]>
  • Loading branch information
zhengbuqian committed Nov 21, 2023
1 parent 936826a commit 460800d
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 69 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ut.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
ut:
name: ut on ubuntu-20.04
runs-on: ubuntu-20.04
timeout-minutes: 90
timeout-minutes: 120
strategy:
fail-fast: false
steps:
Expand Down
22 changes: 12 additions & 10 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,25 +179,27 @@ struct Neighbor {
}
};

// when pushing new elements into a MinMaxHeap, only the `capacity` smallest elements are kept.
// pop()/top() returns the largest element out of those `capacity` smallest elements.
// When pushing new elements into a MaxMinHeap, only the `capacity` largest elements are kept.
// pop()/top() returns the smallest element out of those `capacity` largest elements.
template <typename T = float>
class MinMaxHeap {
class MaxMinHeap {
public:
explicit MinMaxHeap(int capacity) : capacity_(capacity), pool_(capacity) {
explicit MaxMinHeap(int capacity) : capacity_(capacity), pool_(capacity) {
}
void
push(table_t id, T dist) {
if (size_ < capacity_) {
pool_[size_] = {id, dist};
std::push_heap(pool_.begin(), pool_.begin() + ++size_);
} else if (dist < pool_[0].distance) {
size_ += 1;
std::push_heap(pool_.begin(), pool_.begin() + size_, std::greater<Neighbor<T>>());
} else if (dist > pool_[0].distance) {
sift_down(id, dist);
}
}
table_t
pop() {
std::pop_heap(pool_.begin(), pool_.begin() + size_--);
std::pop_heap(pool_.begin(), pool_.begin() + size_, std::greater<Neighbor<T>>());
size_ -= 1;
return pool_[size_].id;
}
[[nodiscard]] size_t
Expand All @@ -224,10 +226,10 @@ class MinMaxHeap {
for (; 2 * i + 1 < size_;) {
size_t j = i;
size_t l = 2 * i + 1, r = 2 * i + 2;
if (pool_[l].distance > dist) {
if (pool_[l].distance < dist) {
j = l;
}
if (r < size_ && pool_[r].distance > std::max(pool_[l].distance, dist)) {
if (r < size_ && pool_[r].distance < std::min(pool_[l].distance, dist)) {
j = r;
}
if (i == j) {
Expand All @@ -241,7 +243,7 @@ class MinMaxHeap {

size_t size_ = 0, capacity_;
std::vector<Neighbor<T>> pool_;
}; // class MinMaxHeap
}; // class MaxMinHeap

} // namespace sparse

Expand Down
7 changes: 3 additions & 4 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ BruteForce::SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_d
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i] {
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;
std::fill(cur_labels, cur_labels + topk, -1);
Expand All @@ -400,7 +399,7 @@ BruteForce::SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_d
for (size_t j = 0; j < len; ++j) {
query[cur_indices[j]] = cur_data[j];
}
sparse::MinMaxHeap<float> heap(topk);
sparse::MaxMinHeap<float> heap(topk);
for (size_t j = 0; j < rows; ++j) {
if (!bitset.empty() && bitset.test(j)) {
continue;
Expand All @@ -413,13 +412,13 @@ BruteForce::SearchSparse(const DataSetPtr base_dataset, const DataSetPtr query_d
}
}
if (dist > 0) {
heap.push(j, -dist);
heap.push(j, dist);
}
}
int result_size = heap.size();
for (int64_t j = result_size - 1; j >= 0; --j) {
cur_labels[j] = heap.top().id;
cur_distances[j] = -heap.top().distance;
cur_distances[j] = heap.top().distance;
heap.pop();
}
return Status::success;
Expand Down
89 changes: 35 additions & 54 deletions src/index/sparse/sparse_inverted_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,20 @@ class InvertedIndex {
}

Status
Add(const void* csr_matrix, bool sealed = true) {
Add(const void* csr_matrix) {
size_t rows, cols, nnz;
const IndPtrT* indptr;
const IndicesT* indices;
const T* data;
parse_csr_matrix(csr_matrix, rows, cols, nnz, indptr, indices, data);

const int32_t scale_factor = sealed ? 1 : 2;

for (size_t i = 0; i < rows + 1; ++i) {
indptr_.push_back(nnz_ + indptr[i]);
}

// TODO: benchmark performance: for growing segments with lots of small
// csr_matrix to add, it may be better to rely on the vector's internal
// memory management to avoid frequent reallocations caused by reserve.
indices_.reserve(nnz_ + nnz);
indices_.insert(indices_.end(), indices, indices + nnz);
data_.reserve(nnz_ + nnz);
Expand Down Expand Up @@ -174,11 +175,11 @@ class InvertedIndex {
}
std::sort(q_vec.begin(), q_vec.end(),
[](const auto& lhs, const auto& rhs) { return std::abs(lhs.second) > std::abs(rhs.second); });
while (q_vec.size() && q_vec[0].second * drop_ratio_search > q_vec.back().second) {
while (!q_vec.empty() && q_vec[0].second * drop_ratio_search > q_vec.back().second) {
q_vec.pop_back();
}

MinMaxHeap<T> heap(k * refine_factor);
MaxMinHeap<T> heap(k * refine_factor);
if (!use_wand_) {
search_brute_force(q_vec, heap, bitset);
} else {
Expand All @@ -189,7 +190,8 @@ class InvertedIndex {
if (refine_factor == 1) {
collect_result(heap, distances, labels);
} else {
std::unordered_map<IndicesT, T> q_map;
// TODO tweak the map buckets number for best performance
std::unordered_map<IndicesT, T> q_map(4 * len);
for (size_t i = 0; i < len; ++i) {
q_map[indices[i]] = data[i];
}
Expand Down Expand Up @@ -226,7 +228,7 @@ class InvertedIndex {

private:
[[nodiscard]] float
dot_product(std::unordered_map<IndicesT, T> q_map, table_t u) const {
dot_product(const std::unordered_map<IndicesT, T>& q_map, table_t u) const {
float res = 0.0f;
for (IndPtrT i = indptr_[u]; i < indptr_[u + 1]; ++i) {
auto idx = indices_[i];
Expand All @@ -241,7 +243,7 @@ class InvertedIndex {

// find the top-k candidates using brute force search, k as specified by the capacity of the heap.
void
search_brute_force(std::vector<std::pair<IndicesT, T>>& q_vec, MinMaxHeap<T>& heap,
search_brute_force(const std::vector<std::pair<IndicesT, T>>& q_vec, MaxMinHeap<T>& heap,
const BitsetView& bitset) const {
std::vector<float> scores(n_rows_, 0.0f);
for (auto [i, v] : q_vec) {
Expand All @@ -252,44 +254,33 @@ class InvertedIndex {
}
for (size_t i = 0; i < n_rows_; ++i) {
if ((bitset.empty() || !bitset.test(i)) && scores[i] != 0) {
heap.push(i, -scores[i]);
heap.push(i, scores[i]);
}
}
}

class Cursor {
public:
Cursor() = default;
Cursor(const std::vector<Neighbor<T>>* lut, size_t num_vec, float max_score, float q_value,
const BitsetView* bitset)
Cursor(const std::vector<Neighbor<T>>& lut, size_t num_vec, float max_score, float q_value,
const BitsetView bitset)
: lut_(lut), num_vec_(num_vec), max_score_(max_score), q_value_(q_value), bitset_(bitset) {
while (loc_ < lut_->size() && !bitset_->empty() && bitset_->test(cur_vec_id())) {
while (loc_ < lut_.size() && !bitset_.empty() && bitset_.test(cur_vec_id())) {
loc_++;
}
}
Cursor(const Cursor& rhs) = delete;
Cursor(Cursor&& rhs) noexcept {
swap(rhs, *this);
}
Cursor&
operator=(const Cursor& rhs) = delete;
Cursor&
operator=(Cursor&& rhs) noexcept {
Cursor tmp(std::move(rhs));
swap(*this, tmp);
return *this;
}

void
next() {
loc_++;
while (loc_ < lut_->size() && !bitset_->empty() && bitset_->test(cur_vec_id())) {
while (loc_ < lut_.size() && !bitset_.empty() && bitset_.test(cur_vec_id())) {
loc_++;
}
}
// advance loc until cur_vec_id() >= vec_id
void
seek(table_t vec_id) {
while (loc_ < lut_->size() && cur_vec_id() < vec_id) {
while (loc_ < lut_.size() && cur_vec_id() < vec_id) {
next();
}
}
Expand All @@ -298,11 +289,11 @@ class InvertedIndex {
if (is_end()) {
return num_vec_;
}
return (*lut_)[loc_].id;
return lut_[loc_].id;
}
T
cur_distance() const {
return (*lut_)[loc_].distance;
return lut_[loc_].distance;
}
[[nodiscard]] bool
is_end() const {
Expand All @@ -314,46 +305,36 @@ class InvertedIndex {
}
[[nodiscard]] size_t
size() const {
return lut_->size();
return lut_.size();
}
[[nodiscard]] float
max_score() const {
return max_score_;
}

private:
friend void
swap(Cursor& lhs, Cursor& rhs) {
using std::swap;
swap(lhs.lut_, rhs.lut_);
swap(lhs.loc_, rhs.loc_);
swap(lhs.num_vec_, rhs.num_vec_);
swap(lhs.max_score_, rhs.max_score_);
swap(lhs.q_value_, rhs.q_value_);
// all cursors share the same bitset
}
const std::vector<Neighbor<T>>* lut_;
const std::vector<Neighbor<T>>& lut_;
size_t loc_ = 0;
size_t num_vec_ = 0;
float max_score_ = 0.0f;
float q_value_ = 0.0f;
const BitsetView* bitset_;
const BitsetView bitset_;
}; // class Cursor

void
search_wand(std::vector<std::pair<IndicesT, T>>& q_vec, MinMaxHeap<T>& heap, const BitsetView& bitset) const {
search_wand(std::vector<std::pair<IndicesT, T>>& q_vec, MaxMinHeap<T>& heap, const BitsetView& bitset) const {
auto q_dim = q_vec.size();
std::vector<std::shared_ptr<Cursor>> cursors(q_dim);
for (size_t i = 0; i < q_dim; ++i) {
auto [idx, val] = q_vec[i];
cursors[i] = std::make_shared<Cursor>(&(inverted_lut_[idx]), n_rows_, max_in_dim_[idx] * val, val, &bitset);
cursors[i] = std::make_shared<Cursor>(inverted_lut_[idx], n_rows_, max_in_dim_[idx] * val, val, bitset);
}
auto sort_cursors = [&cursors] {
std::sort(cursors.begin(), cursors.end(),
[](auto& x, auto& y) { return x->cur_vec_id() < y->cur_vec_id(); });
};
sort_cursors();
auto score_above_threshold = [&heap](float x) { return !heap.full() || x > -heap.top().distance; };
auto score_above_threshold = [&heap](float x) { return !heap.full() || x > heap.top().distance; };
while (true) {
float upper_bound = 0;
size_t pivot;
Expand Down Expand Up @@ -381,7 +362,7 @@ class InvertedIndex {
score += cursor->cur_distance() * cursor->q_value();
cursor->next();
}
heap.push(pivot_id, -score);
heap.push(pivot_id, score);
sort_cursors();
} else {
uint64_t next_list = pivot;
Expand All @@ -399,32 +380,32 @@ class InvertedIndex {
}

void
refine_and_collect(std::unordered_map<IndicesT, T>& q_map, MinMaxHeap<T> inaccurate, size_t k, float* distances,
label_t* labels) const {
std::priority_queue<Neighbor<T>> heap;
refine_and_collect(const std::unordered_map<IndicesT, T>& q_map, MaxMinHeap<T>& inaccurate, size_t k,
float* distances, label_t* labels) const {
std::priority_queue<Neighbor<T>, std::vector<Neighbor<T>>, std::greater<Neighbor<T>>> heap;

while (inaccurate.size()) {
while (!inaccurate.empty()) {
auto [u, d] = inaccurate.top();
inaccurate.pop();

auto dist_acc = dot_product(q_map, u);
if (heap.size() < k) {
heap.emplace(u, -dist_acc);
} else if (heap.top().distance > -dist_acc) {
heap.emplace(u, dist_acc);
} else if (heap.top().distance < dist_acc) {
heap.pop();
heap.emplace(u, -dist_acc);
heap.emplace(u, dist_acc);
}
}
collect_result(heap, distances, labels);
}

template <typename HeapType>
void
collect_result(HeapType heap, float* distances, label_t* labels) const {
collect_result(HeapType& heap, float* distances, label_t* labels) const {
int cnt = heap.size();
for (auto i = cnt - 1; i >= 0; --i) {
labels[i] = heap.top().id;
distances[i] = -heap.top().distance;
distances[i] = heap.top().distance;
heap.pop();
}
}
Expand Down

0 comments on commit 460800d

Please sign in to comment.