sparse: refactor approx dimension max score ratio
1. Move the dimension max score ratio from build params to search params,
and rename it from `wand_bm25_max_score_ratio` to `dim_max_score_ratio`.

2. Remove the `bm25` template param and add a new `SparseMetricType` enum.

3. Wrap several params of `Search()` into a new `InvertedIndexSearchParams` struct (see the sketch below).
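
A minimal sketch of what the wrapped `InvertedIndexSearchParams` bundle might look like, inferred from the designated initializers in the sparse_index_node.cc diff below; the field types, defaults, and the computer's element type are assumptions, not copied from this commit:

    // Hypothetical sketch only: types and defaults are inferred from usage, not from the header.
    #include "knowhere/sparse_utils.h"  // assumed location of DocValueComputer

    namespace knowhere::sparse {

    struct InvertedIndexSearchParams {
        int refine_factor = 1;             // gather refine_factor * k candidates before refining
        float drop_ratio_search = 0.0f;    // fraction of small query values dropped at search time
        float dim_max_score_ratio = 1.0f;  // scaling of per-dimension max scores, now a search param
        DocValueComputer<float> computer;  // scorer obtained from GetDocValueComputer()
    };

    }  // namespace knowhere::sparse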

Signed-off-by: Shawn Wang <[email protected]>
sparknack committed Jan 15, 2025
1 parent 45d757c commit 7747dd3
Showing 5 changed files with 128 additions and 96 deletions.
include/knowhere/comp/index_param.h: 2 changes (1 addition, 1 deletion)
@@ -103,7 +103,7 @@ constexpr const char* BM25_K1 = "bm25_k1";
 constexpr const char* BM25_B = "bm25_b";
 // average document length
 constexpr const char* BM25_AVGDL = "bm25_avgdl";
-constexpr const char* WAND_BM25_MAX_SCORE_RATIO = "wand_bm25_max_score_ratio";
+constexpr const char* DIM_MAX_SCORE_RATIO = "dim_max_score_ratio";
 }; // namespace meta

 namespace indexparam {
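
The caller-facing effect of this rename is that the ratio now rides along with each search request instead of the build parameters. A hedged illustration follows; apart from the renamed key, the other keys and all values are placeholders rather than defaults taken from this commit:

    // Illustration only: values are placeholders; only the key rename is taken from this commit.
    #include <string>

    // Previously the ratio was a build-time knob ("wand_bm25_max_score_ratio");
    // now it is passed with the search config as "dim_max_score_ratio".
    const std::string search_cfg = R"({
        "k": 10,
        "drop_ratio_search": 0.2,
        "dim_max_score_ratio": 1.05
    })";
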
include/knowhere/sparse_utils.h: 5 changes (5 additions, 0 deletions)
@@ -29,6 +29,11 @@

 namespace knowhere::sparse {

+enum class SparseMetricType {
+    METRIC_IP = 1,
+    METRIC_BM25 = 2,
+};
+
 // integer type in SparseRow
 using table_t = uint32_t;
 // type used to represent the id of a vector in the index interface.
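
With the `bm25` bool template parameter removed, metric-specific behavior presumably becomes a runtime decision on the stored metric. A minimal sketch of that pattern, assuming only the enum added above (the helper name is hypothetical):

    #include "knowhere/sparse_utils.h"

    // Hypothetical helper: a runtime check on the stored metric replaces the
    // old compile-time bm25 template flag.
    inline bool
    UsesBM25(knowhere::sparse::SparseMetricType metric) {
        return metric == knowhere::sparse::SparseMetricType::METRIC_BM25;
    }
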
src/index/sparse/sparse_index_node.cc: 55 changes (35 additions, 20 deletions)
@@ -95,27 +95,40 @@ class SparseInvertedIndexNode : public IndexNode {
             LOG_KNOWHERE_ERROR_ << "Could not search empty " << Type();
             return expected<DataSetPtr>::Err(Status::empty_index, "index not loaded");
         }

         auto cfg = static_cast<const SparseInvertedIndexConfig&>(*config);

         auto computer_or = index_->GetDocValueComputer(cfg);
         if (!computer_or.has_value()) {
             return expected<DataSetPtr>::Err(computer_or.error(), computer_or.what());
         }
         auto computer = computer_or.value();
-        auto nq = dataset->GetRows();
-        auto queries = static_cast<const sparse::SparseRow<T>*>(dataset->GetTensor());
-        auto k = cfg.k.value();
-        auto refine_factor = cfg.refine_factor.value_or(10);
+        auto dim_max_score_ratio = cfg.dim_max_score_ratio.value();
         auto drop_ratio_search = cfg.drop_ratio_search.value_or(0.0f);
+        auto refine_factor = cfg.refine_factor.value_or(10);
         // if no data was dropped during search, no refinement is needed.
         if (drop_ratio_search == 0) {
             refine_factor = 1;
         }
+
+        sparse::InvertedIndexSearchParams params = {
+            .refine_factor = refine_factor,
+            .drop_ratio_search = drop_ratio_search,
+            .dim_max_score_ratio = dim_max_score_ratio,
+            .computer = computer,
+        };
+
+        auto queries = static_cast<const sparse::SparseRow<T>*>(dataset->GetTensor());
+        auto nq = dataset->GetRows();
+        auto k = cfg.k.value();
         auto p_id = std::make_unique<sparse::label_t[]>(nq * k);
         auto p_dist = std::make_unique<float[]>(nq * k);

         std::vector<folly::Future<folly::Unit>> futs;
         futs.reserve(nq);
         for (int64_t idx = 0; idx < nq; ++idx) {
             futs.emplace_back(search_pool_->push([&, idx = idx, p_id = p_id.get(), p_dist = p_dist.get()]() {
-                index_->Search(queries[idx], k, drop_ratio_search, p_dist + idx * k, p_id + idx * k, refine_factor,
-                               bitset, computer);
+                index_->Search(queries[idx], k, p_dist + idx * k, p_id + idx * k, bitset, params);
             }));
         }
         WaitAllSuccess(futs);
@@ -358,36 +371,38 @@ class SparseInvertedIndexNode : public IndexNode {
             auto k1 = cfg.bm25_k1.value();
             auto b = cfg.bm25_b.value();
             auto avgdl = cfg.bm25_avgdl.value();
-            auto max_score_ratio = cfg.wand_bm25_max_score_ratio.value();
+
             if (use_wand || cfg.inverted_index_algo.value() == "DAAT_WAND") {
-                auto index =
-                    new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::DAAT_WAND, true, mmapped>();
-                index->SetBM25Params(k1, b, avgdl, max_score_ratio);
+                auto index = new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::DAAT_WAND, mmapped>(
+                    sparse::SparseMetricType::METRIC_BM25);
+                index->SetBM25Params(k1, b, avgdl);
                 return index;
             } else if (cfg.inverted_index_algo.value() == "DAAT_MAXSCORE") {
-                auto index =
-                    new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::DAAT_MAXSCORE, true, mmapped>();
-                index->SetBM25Params(k1, b, avgdl, max_score_ratio);
+                auto index = new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::DAAT_MAXSCORE, mmapped>(
+                    sparse::SparseMetricType::METRIC_BM25);
+                index->SetBM25Params(k1, b, avgdl);
                 return index;
             } else if (cfg.inverted_index_algo.value() == "TAAT_NAIVE") {
-                auto index =
-                    new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::TAAT_NAIVE, true, mmapped>();
-                index->SetBM25Params(k1, b, avgdl, max_score_ratio);
+                auto index = new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::TAAT_NAIVE, mmapped>(
+                    sparse::SparseMetricType::METRIC_BM25);
+                index->SetBM25Params(k1, b, avgdl);
                 return index;
             } else {
                 return expected<sparse::BaseInvertedIndex<T>*>::Err(Status::invalid_args,
                                                                     "Invalid search algorithm for SparseInvertedIndex");
             }
         } else {
             if (use_wand || cfg.inverted_index_algo.value() == "DAAT_WAND") {
-                auto index = new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::DAAT_WAND, false, mmapped>();
+                auto index = new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::DAAT_WAND, mmapped>(
+                    sparse::SparseMetricType::METRIC_IP);
                 return index;
             } else if (cfg.inverted_index_algo.value() == "DAAT_MAXSCORE") {
-                auto index =
-                    new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::DAAT_MAXSCORE, false, mmapped>();
+                auto index = new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::DAAT_MAXSCORE, mmapped>(
+                    sparse::SparseMetricType::METRIC_IP);
                 return index;
             } else if (cfg.inverted_index_algo.value() == "TAAT_NAIVE") {
-                auto index = new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::TAAT_NAIVE, false, mmapped>();
+                auto index = new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::TAAT_NAIVE, mmapped>(
+                    sparse::SparseMetricType::METRIC_IP);
                 return index;
             } else {
                 return expected<sparse::BaseInvertedIndex<T>*>::Err(Status::invalid_args,
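
As the calls above suggest, `SetBM25Params` no longer takes the max score ratio at build time; that knob now arrives per search as `dim_max_score_ratio`. A sketch of the remaining build-time BM25 inputs, with float types assumed:

    // Hypothetical sketch: the build-time BM25 inputs after this change; the
    // max score ratio is no longer among them. Types are assumed.
    struct BM25BuildParams {
        float k1;     // BM25 term-frequency saturation parameter
        float b;      // BM25 length-normalization parameter
        float avgdl;  // average document length
    };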