From d4cfb63f7a36e14e836bae39928bfd16514e7d1a Mon Sep 17 00:00:00 2001 From: Shawn Wang Date: Sun, 29 Dec 2024 21:37:38 +0800 Subject: [PATCH] sparse: add daat maxscore algorithm support Signed-off-by: Shawn Wang --- include/knowhere/comp/index_param.h | 3 +- src/index/sparse/sparse_index_node.cc | 36 +++- src/index/sparse/sparse_inverted_index.h | 160 +++++++++++++++--- .../sparse/sparse_inverted_index_config.h | 7 + tests/ut/test_sparse.cc | 12 +- 5 files changed, 188 insertions(+), 30 deletions(-) diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index d389522fd..364b66c11 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -173,7 +173,8 @@ constexpr const char* HNSW_REFINE_TYPE = "refine_type"; constexpr const char* SQ_TYPE = "sq_type"; // for IVF_SQ and HNSW_SQ constexpr const char* PRQ_NUM = "nrq"; // for PRQ, number of redisual quantizers -// Sparse Params +// Sparse Inverted Index Params +constexpr const char* INVERTED_INDEX_ALGO = "inverted_index_algo"; constexpr const char* DROP_RATIO_BUILD = "drop_ratio_build"; constexpr const char* DROP_RATIO_SEARCH = "drop_ratio_search"; } // namespace indexparam diff --git a/src/index/sparse/sparse_index_node.cc b/src/index/sparse/sparse_index_node.cc index b3bccf3a7..9ce1c6290 100644 --- a/src/index/sparse/sparse_index_node.cc +++ b/src/index/sparse/sparse_index_node.cc @@ -27,7 +27,7 @@ namespace knowhere { -// Inverted Index impl for sparse vectors. May optionally use WAND algorithm to speed up search. +// Inverted Index impl for sparse vectors. // // Not overriding RangeSearch, will use the default implementation in IndexNode. // @@ -348,8 +348,6 @@ class SparseInvertedIndexNode : public IndexNode { expected*> CreateIndex(const SparseInvertedIndexConfig& cfg) const { if (IsMetricType(cfg.metric_type.value(), metric::BM25)) { - // quantize float to uint16_t when BM25 metric type is used. - auto idx = new sparse::InvertedIndex(); if (!cfg.bm25_k1.has_value() || !cfg.bm25_b.has_value() || !cfg.bm25_avgdl.has_value()) { return expected*>::Err( Status::invalid_args, "BM25 parameters k1, b, and avgdl must be set when building/loading"); @@ -358,10 +356,36 @@ class SparseInvertedIndexNode : public IndexNode { auto b = cfg.bm25_b.value(); auto avgdl = cfg.bm25_avgdl.value(); auto max_score_ratio = cfg.wand_bm25_max_score_ratio.value(); - idx->SetBM25Params(k1, b, avgdl, max_score_ratio); - return idx; + if (use_wand || cfg.inverted_index_algo.value() == "DAAT_WAND") { + auto index = new sparse::InvertedIndex(); + index->SetBM25Params(k1, b, avgdl, max_score_ratio); + return index; + } else if (cfg.inverted_index_algo.value() == "DAAT_MAXSCORE") { + auto index = new sparse::InvertedIndex(); + index->SetBM25Params(k1, b, avgdl, max_score_ratio); + return index; + } else if (cfg.inverted_index_algo.value() == "TAAT_NAIVE") { + auto index = new sparse::InvertedIndex(); + index->SetBM25Params(k1, b, avgdl, max_score_ratio); + return index; + } else { + return expected*>::Err(Status::invalid_args, + "Invalid search algorithm for SparseInvertedIndex"); + } } else { - return new sparse::InvertedIndex(); + if (use_wand || cfg.inverted_index_algo.value() == "DAAT_WAND") { + auto index = new sparse::InvertedIndex(); + return index; + } else if (cfg.inverted_index_algo.value() == "DAAT_MAXSCORE") { + auto index = new sparse::InvertedIndex(); + return index; + } else if (cfg.inverted_index_algo.value() == "TAAT_NAIVE") { + auto index = new sparse::InvertedIndex(); + return index; + } else { + return expected*>::Err(Status::invalid_args, + "Invalid search algorithm for SparseInvertedIndex"); + } } } diff --git a/src/index/sparse/sparse_inverted_index.h b/src/index/sparse/sparse_inverted_index.h index 611b0dd36..68534f6a7 100644 --- a/src/index/sparse/sparse_inverted_index.h +++ b/src/index/sparse/sparse_inverted_index.h @@ -34,6 +34,13 @@ #include "knowhere/utils.h" namespace knowhere::sparse { + +enum InvertedIndexAlgo { + TAAT_NAIVE, + DAAT_WAND, + DAAT_MAXSCORE, +}; + template class BaseInvertedIndex { public: @@ -77,7 +84,7 @@ class BaseInvertedIndex { n_cols() const = 0; }; -template +template class InvertedIndex : public BaseInvertedIndex { public: explicit InvertedIndex() { @@ -132,12 +139,13 @@ class InvertedIndex : public BaseInvertedIndex { "avgdl must be supplied during searching"); } auto avgdl = cfg.bm25_avgdl.value(); - if constexpr (use_wand) { - // wand: search time k1/b must equal load time config. + if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) { + // daat_wand and daat_maxscore: search time k1/b must equal load time config. if ((cfg.bm25_k1.has_value() && cfg.bm25_k1.value() != bm25_params_->k1) || ((cfg.bm25_b.has_value() && cfg.bm25_b.value() != bm25_params_->b))) { return expected>::Err( - Status::invalid_args, "search time k1/b must equal load time config for WAND index."); + Status::invalid_args, + "search time k1/b must equal load time config for DAAT_WAND or DAAT_MAXSCORE algorithm."); } return GetDocValueBM25Computer(bm25_params_->k1, bm25_params_->b, avgdl); } else { @@ -281,7 +289,7 @@ class InvertedIndex : public BaseInvertedIndex { map_byte_size_ = inverted_index_ids_byte_size + inverted_index_vals_byte_size + plists_ids_byte_size + plists_vals_byte_size; - if constexpr (use_wand) { + if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) { map_byte_size_ += max_score_in_dim_byte_size; } if constexpr (bm25) { @@ -330,7 +338,7 @@ class InvertedIndex : public BaseInvertedIndex { inverted_index_vals_.initialize(ptr, inverted_index_vals_byte_size); ptr += inverted_index_vals_byte_size; - if constexpr (use_wand) { + if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) { max_score_in_dim_.initialize(ptr, max_score_in_dim_byte_size); ptr += max_score_in_dim_byte_size; } @@ -355,7 +363,7 @@ class InvertedIndex : public BaseInvertedIndex { size_t dim_id = 0; for (const auto& [idx, count] : idx_counts) { dim_map_[idx] = dim_id; - if constexpr (use_wand) { + if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) { max_score_in_dim_.emplace_back(0.0f); } ++dim_id; @@ -420,10 +428,13 @@ class InvertedIndex : public BaseInvertedIndex { refine_factor = 1; } MaxMinHeap heap(k * refine_factor); - if constexpr (!use_wand) { - search_brute_force(query, q_threshold, heap, bitset, computer); + // DAAT_WAND and DAAT_MAXSCORE are based on the implementation in PISA. + if constexpr (algo == InvertedIndexAlgo::DAAT_WAND) { + search_daat_wand(query, q_threshold, heap, bitset, computer); + } else if constexpr (algo == InvertedIndexAlgo::DAAT_MAXSCORE) { + search_daat_maxscore(query, q_threshold, heap, bitset, computer); } else { - search_wand(query, q_threshold, heap, bitset, computer); + search_taat_naive(query, q_threshold, heap, bitset, computer); } if (refine_factor == 1) { @@ -498,7 +509,7 @@ class InvertedIndex : public BaseInvertedIndex { res += sizeof(typename decltype(inverted_index_vals_)::value_type::value_type) * inverted_index_vals_[i].capacity(); } - if constexpr (use_wand) { + if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) { res += sizeof(typename decltype(max_score_in_dim_)::value_type) * max_score_in_dim_.capacity(); } return res; @@ -626,8 +637,8 @@ class InvertedIndex : public BaseInvertedIndex { // TODO: may switch to row-wise brute force if filter rate is high. Benchmark needed. template void - search_brute_force(const SparseRow& q_vec, DType q_threshold, MaxMinHeap& heap, DocIdFilter& filter, - const DocValueComputer& computer) const { + search_taat_naive(const SparseRow& q_vec, DType q_threshold, MaxMinHeap& heap, DocIdFilter& filter, + const DocValueComputer& computer) const { auto scores = compute_all_distances(q_vec, q_threshold, computer); for (size_t i = 0; i < n_rows_internal_; ++i) { if ((filter.empty() || !filter.test(i)) && scores[i] != 0) { @@ -639,8 +650,8 @@ class InvertedIndex : public BaseInvertedIndex { // any value in q_vec that is smaller than q_threshold will be ignored. template void - search_wand(const SparseRow& q_vec, DType q_threshold, MaxMinHeap& heap, DocIdFilter& filter, - const DocValueComputer& computer) const { + search_daat_wand(const SparseRow& q_vec, DType q_threshold, MaxMinHeap& heap, DocIdFilter& filter, + const DocValueComputer& computer) const { auto q_dim = q_vec.size(); std::vector>> cursors(q_dim); size_t valid_q_dim = 0; @@ -709,6 +720,111 @@ class InvertedIndex : public BaseInvertedIndex { } } + template + void + search_daat_maxscore(const SparseRow& q_vec, DType q_threshold, MaxMinHeap& heap, + const DocIdFilter& filter, const DocValueComputer& computer) const { + auto q_dim = q_vec.size(); + std::vector>> cursors(q_dim); + size_t valid_q_dim = 0; + for (size_t i = 0; i < q_dim; ++i) { + auto [idx, val] = q_vec[i]; + auto dim_id = dim_map_.find(idx); + if (dim_id == dim_map_.end() || std::abs(val) < q_threshold) { + continue; + } + auto& plist_ids = inverted_index_ids_[dim_id->second]; + auto& plist_vals = inverted_index_vals_[dim_id->second]; + cursors[valid_q_dim++] = std::make_shared>( + plist_ids, plist_vals, n_rows_internal_, max_score_in_dim_[dim_id->second] * val, val, filter); + } + if (valid_q_dim == 0) { + return; + } + cursors.resize(valid_q_dim); + + std::sort(cursors.begin(), cursors.end(), [](auto& x, auto& y) { return x->max_score_ > y->max_score_; }); + + float threshold = heap.full() ? heap.top().val : 0; + + std::vector upper_bounds(cursors.size()); + float bound_sum = 0.0; + for (size_t i = cursors.size() - 1; i + 1 > 0; --i) { + bound_sum += cursors[i]->max_score_; + upper_bounds[i] = bound_sum; + } + + uint32_t next_cand_vec_id = n_rows_internal_; + for (size_t i = 0; i < cursors.size(); ++i) { + if (cursors[i]->cur_vec_id_ < next_cand_vec_id) { + next_cand_vec_id = cursors[i]->cur_vec_id_; + } + } + + // first_ne_idx is the index of the first non-essential cursor + size_t first_ne_idx = cursors.size(); + + while (first_ne_idx != 0 && upper_bounds[first_ne_idx - 1] <= threshold) { + --first_ne_idx; + if (first_ne_idx == 0) { + return; + } + } + + float curr_cand_score = 0.0f; + uint32_t curr_cand_vec_id = 0; + + while (curr_cand_vec_id < n_rows_internal_) { + auto found_cand = false; + while (found_cand == false) { + // start find from next_vec_id + if (next_cand_vec_id >= n_rows_internal_) { + return; + } + // get current candidate vector + curr_cand_vec_id = next_cand_vec_id; + curr_cand_score = 0.0f; + // update next_cand_vec_id + next_cand_vec_id = n_rows_internal_; + + for (size_t i = 0; i < first_ne_idx; ++i) { + if (cursors[i]->cur_vec_id_ == curr_cand_vec_id) { + float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursors[i]->cur_vec_id_) : 0; + curr_cand_score += cursors[i]->q_value_ * computer(cursors[i]->cur_vec_val(), cur_vec_sum); + cursors[i]->next(); + } + if (cursors[i]->cur_vec_id_ < next_cand_vec_id) { + next_cand_vec_id = cursors[i]->cur_vec_id_; + } + } + + found_cand = true; + for (size_t i = first_ne_idx; i < cursors.size(); ++i) { + if (curr_cand_score + upper_bounds[i] <= threshold) { + found_cand = false; + break; + } + cursors[i]->seek(curr_cand_vec_id); + if (cursors[i]->cur_vec_id_ == curr_cand_vec_id) { + float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursors[i]->cur_vec_id_) : 0; + curr_cand_score += cursors[i]->q_value_ * computer(cursors[i]->cur_vec_val(), cur_vec_sum); + } + } + } + + if (curr_cand_score > threshold) { + heap.push(curr_cand_vec_id, curr_cand_score); + threshold = heap.full() ? heap.top().val : 0; + while (first_ne_idx != 0 && upper_bounds[first_ne_idx - 1] <= threshold) { + --first_ne_idx; + if (first_ne_idx == 0) { + return; + } + } + } + } + } + void refine_and_collect(const SparseRow& q_vec, MaxMinHeap& inacc_heap, size_t k, float* distances, label_t* labels, const DocValueComputer& computer) const { @@ -722,10 +838,12 @@ class InvertedIndex : public BaseInvertedIndex { } DocIdFilterByVector filter(std::move(docids)); - if (use_wand) { - search_wand(q_vec, 0, heap, filter, computer); + if constexpr (algo == InvertedIndexAlgo::DAAT_WAND) { + search_daat_wand(q_vec, 0, heap, filter, computer); + } else if constexpr (algo == InvertedIndexAlgo::DAAT_MAXSCORE) { + search_daat_maxscore(q_vec, 0, heap, filter, computer); } else { - search_brute_force(q_vec, 0, heap, filter, computer); + search_taat_naive(q_vec, 0, heap, filter, computer); } collect_result(heap, distances, labels); } @@ -762,13 +880,13 @@ class InvertedIndex : public BaseInvertedIndex { dim_it = dim_map_.insert({idx, next_dim_id_++}).first; inverted_index_ids_.emplace_back(); inverted_index_vals_.emplace_back(); - if constexpr (use_wand) { + if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) { max_score_in_dim_.emplace_back(0.0f); } } inverted_index_ids_[dim_it->second].emplace_back(vec_id); inverted_index_vals_[dim_it->second].emplace_back(get_quant_val(val)); - if constexpr (use_wand) { + if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) { auto score = static_cast(val); if constexpr (bm25) { score = bm25_params_->max_score_ratio * bm25_params_->wand_max_score_computer(val, row_sum); @@ -820,7 +938,7 @@ class InvertedIndex : public BaseInvertedIndex { // corresponds to the document length of each doc in the BM25 formula. Vector row_sums; - // below are used only for WAND index. + // below are used only for DAAT_WAND and DAAT_MAXSCORE algorithms. float max_score_ratio; DocValueComputer wand_max_score_computer; diff --git a/src/index/sparse/sparse_inverted_index_config.h b/src/index/sparse/sparse_inverted_index_config.h index 7c56494eb..3a7340667 100644 --- a/src/index/sparse/sparse_inverted_index_config.h +++ b/src/index/sparse/sparse_inverted_index_config.h @@ -23,6 +23,7 @@ class SparseInvertedIndexConfig : public BaseConfig { CFG_FLOAT drop_ratio_search; CFG_INT refine_factor; CFG_FLOAT wand_bm25_max_score_ratio; + CFG_STRING inverted_index_algo; KNOHWERE_DECLARE_CONFIG(SparseInvertedIndexConfig) { // NOTE: drop_ratio_build has been deprecated, it won't change anything KNOWHERE_CONFIG_DECLARE_FIELD(drop_ratio_build) @@ -61,6 +62,12 @@ class SparseInvertedIndexConfig : public BaseConfig { .for_train() .for_deserialize() .for_deserialize_from_file(); + KNOWHERE_CONFIG_DECLARE_FIELD(inverted_index_algo) + .description("inverted index algorithm") + .set_default("DAAT_MAXSCORE") + .for_train_and_search() + .for_deserialize() + .for_deserialize_from_file(); } }; // class SparseInvertedIndexConfig diff --git a/tests/ut/test_sparse.cc b/tests/ut/test_sparse.cc index c84886b32..80cd7bc4e 100644 --- a/tests/ut/test_sparse.cc +++ b/tests/ut/test_sparse.cc @@ -47,6 +47,8 @@ TEST_CASE("Test Mem Sparse Index With Float Vector", "[float metrics]") { auto metric = GENERATE(knowhere::metric::IP, knowhere::metric::BM25); + auto inverted_index_algo = GENERATE("TAAT_NAIVE", "DAAT_WAND", "DAAT_MAXSCORE"); + auto drop_ratio_search = metric == knowhere::metric::BM25 ? GENERATE(0.0, 0.1) : GENERATE(0.0, 0.3); auto version = GenTestVersionList(); @@ -62,9 +64,11 @@ TEST_CASE("Test Mem Sparse Index With Float Vector", "[float metrics]") { return json; }; - auto sparse_inverted_index_gen = [base_gen, drop_ratio_search = drop_ratio_search]() { + auto sparse_inverted_index_gen = [base_gen, drop_ratio_search = drop_ratio_search, + inverted_index_algo = inverted_index_algo]() { knowhere::Json json = base_gen(); json[knowhere::indexparam::DROP_RATIO_SEARCH] = drop_ratio_search; + json[knowhere::indexparam::INVERTED_INDEX_ALGO] = inverted_index_algo; return json; }; @@ -460,6 +464,8 @@ TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") { auto query_ds = doc_vector_gen(nq, dim); + auto inverted_index_algo = GENERATE("TAAT_NAIVE", "DAAT_WAND", "DAAT_MAXSCORE"); + auto drop_ratio_search = GENERATE(0.0, 0.3); auto metric = GENERATE(knowhere::metric::IP); @@ -476,9 +482,11 @@ TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") { return json; }; - auto sparse_inverted_index_gen = [base_gen, drop_ratio_search = drop_ratio_search]() { + auto sparse_inverted_index_gen = [base_gen, drop_ratio_search = drop_ratio_search, + inverted_index_algo = inverted_index_algo]() { knowhere::Json json = base_gen(); json[knowhere::indexparam::DROP_RATIO_SEARCH] = drop_ratio_search; + json[knowhere::indexparam::INVERTED_INDEX_ALGO] = inverted_index_algo; return json; };