Skip to content

Commit

Permalink
Not use template in PrecomputedDistanceIterator
Browse files Browse the repository at this point in the history
Signed-off-by: min.tian <[email protected]>
  • Loading branch information
alwayslove2013 committed Dec 31, 2024
1 parent 92f234b commit cf9d25d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 43 deletions.
34 changes: 6 additions & 28 deletions include/knowhere/index/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,10 +580,9 @@ class IndexIterator : public IndexNode::iterator {
// for all types of iterators in the `ANNIterator` interface, moving heavy computations to the first '->Next()' call.
// This way, the iterator initialization does not need to perform any concurrent acceleration, and the search pool
// only needs to handle the heavy work of `->Next()`
template <typename Compute_Dist_Func>
class PrecomputedDistanceIterator : public IndexNode::iterator {
public:
PrecomputedDistanceIterator(Compute_Dist_Func compute_dist_func, bool larger_is_closer,
PrecomputedDistanceIterator(std::function<std::vector<DistId>()> compute_dist_func, bool larger_is_closer,
bool use_knowhere_search_pool = true)
: compute_dist_func_(compute_dist_func),
larger_is_closer_(larger_is_closer),
Expand Down Expand Up @@ -631,42 +630,21 @@ class PrecomputedDistanceIterator : public IndexNode::iterator {
std::vector<folly::Future<folly::Unit>> futs;
futs.emplace_back(ThreadPool::GetGlobalSearchThreadPool()->push([&]() {
ThreadPool::ScopedSearchOmpSetter setter(1);
compute_all_dist_ids();
results_ = compute_dist_func_();
}));
WaitAllSuccess(futs);
#else
compute_all_dist_ids();
results_ = compute_dist_func_();
#endif
} else {
compute_all_dist_ids();
results_ = compute_dist_func_();
}
sort_size_ = get_sort_size(results_.size());
sort_next();
initialized_ = true;
}

private:
void
compute_all_dist_ids() {
using ReturnType = std::invoke_result_t<Compute_Dist_Func>;
if constexpr (std::is_same_v<ReturnType, std::vector<DistId>>) {
results_ = compute_dist_func_();
} else if constexpr (std::is_same_v<ReturnType, std::vector<float>>) {
// From a list of distances with index being id, filtering out zero distances.
std::vector<float> dists = compute_dist_func_();
// 30% is a ratio guesstimate of non-zero distances: probability of 2 random sparse splade vectors(100 non
// zero dims out of 30000 total dims) sharing at least 1 common non-zero dimension.
results_.reserve(dists.size() * 0.3);
for (size_t i = 0; i < dists.size(); i++) {
if (dists[i] != 0) {
results_.emplace_back((int64_t)i, dists[i]);
}
}
} else {
throw std::runtime_error("unknown compute_dist_func");
}
sort_size_ = get_sort_size(results_.size());
}

static inline size_t
get_sort_size(size_t rows) {
return std::max((size_t)50000, rows / 10);
Expand All @@ -690,7 +668,7 @@ class PrecomputedDistanceIterator : public IndexNode::iterator {
sorted_ = current_end;
}

Compute_Dist_Func compute_dist_func_;
std::function<std::vector<DistId>()> compute_dist_func_;
const bool larger_is_closer_;
bool use_knowhere_search_pool_ = true;
bool initialized_ = false;
Expand Down
7 changes: 3 additions & 4 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -855,8 +855,8 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da
}
return distances_ids;
};
vec[i] = std::make_shared<PrecomputedDistanceIterator<decltype(compute_dist_func)>>(
compute_dist_func, larger_is_closer, use_knowhere_search_pool);
vec[i] = std::make_shared<PrecomputedDistanceIterator>(compute_dist_func, larger_is_closer,
use_knowhere_search_pool);
}
} catch (const std::exception& e) {
return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::brute_force_inner_error, e.what());
Expand Down Expand Up @@ -949,8 +949,7 @@ BruteForce::AnnIterator<knowhere::sparse::SparseRow<float>>(const DataSetPtr bas
return distances_ids;
};

vec[i] = std::make_shared<PrecomputedDistanceIterator<decltype(compute_dist_func)>>(
compute_dist_func, true, use_knowhere_search_pool);
vec[i] = std::make_shared<PrecomputedDistanceIterator>(compute_dist_func, true, use_knowhere_search_pool);
}
} catch (const std::exception& e) {
return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::brute_force_inner_error, e.what());
Expand Down
32 changes: 21 additions & 11 deletions src/index/sparse/sparse_index_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,10 @@ class SparseInvertedIndexNode : public IndexNode {
}

private:
template <typename Compute_Dist_Func>
class RefineIterator : public IndexIterator {
public:
RefineIterator(const sparse::BaseInvertedIndex<T>* index, sparse::SparseRow<T>&& query,
std::shared_ptr<PrecomputedDistanceIterator<Compute_Dist_Func>> precomputed_it,
std::shared_ptr<PrecomputedDistanceIterator> precomputed_it,
const sparse::DocValueComputer<T>& computer, bool use_knowhere_search_pool = true,
const float refine_ratio = 0.5f)
: IndexIterator(true, use_knowhere_search_pool, refine_ratio),
Expand Down Expand Up @@ -160,7 +159,7 @@ class SparseInvertedIndexNode : public IndexNode {
const sparse::BaseInvertedIndex<T>* index_;
sparse::SparseRow<T> query_;
const sparse::DocValueComputer<T> computer_;
std::shared_ptr<PrecomputedDistanceIterator<Compute_Dist_Func>> precomputed_it_;
std::shared_ptr<PrecomputedDistanceIterator> precomputed_it_;
bool first_return_ = true;
};

Expand Down Expand Up @@ -193,20 +192,31 @@ class SparseInvertedIndexNode : public IndexNode {
for (int i = 0; i < nq; ++i) {
// Heavy computations with `compute_dist_func` will be deferred until the first call to
// 'Iterator->Next()'.
auto compute_dist_func = [=]() -> std::vector<float> {
auto compute_dist_func = [=]() -> std::vector<DistId> {
auto queries = static_cast<const sparse::SparseRow<T>*>(dataset->GetTensor());
return index_->GetAllDistances(queries[i], drop_ratio_search, bitset, computer);
std::vector<float> distances =
index_->GetAllDistances(queries[i], drop_ratio_search, bitset, computer);
std::vector<DistId> distances_ids;
// 30% is a ratio guesstimate of non-zero distances: probability of 2 random sparse splade
// vectors(100 non zero dims out of 30000 total dims) sharing at least 1 common non-zero
// dimension.
distances_ids.reserve(distances.size() * 0.3);
for (size_t i = 0; i < distances.size(); i++) {
if (distances[i] != 0) {
distances_ids.emplace_back((int64_t)i, distances[i]);
}
}
return distances_ids;
};
if (!approximated || queries[i].size() == 0) {
auto it = std::make_shared<PrecomputedDistanceIterator<decltype(compute_dist_func)>>(
compute_dist_func, true, use_knowhere_search_pool);
auto it = std::make_shared<PrecomputedDistanceIterator>(compute_dist_func, true,
use_knowhere_search_pool);
vec[i] = it;
} else {
sparse::SparseRow<T> query_copy(queries[i]);
auto it = std::make_shared<PrecomputedDistanceIterator<decltype(compute_dist_func)>>(
compute_dist_func, true, false);
vec[i] = std::make_shared<RefineIterator<decltype(compute_dist_func)>>(
index_, std::move(query_copy), it, computer, use_knowhere_search_pool);
auto it = std::make_shared<PrecomputedDistanceIterator>(compute_dist_func, true, false);
vec[i] = std::make_shared<RefineIterator>(index_, std::move(query_copy), it, computer,
use_knowhere_search_pool);
}
}
} catch (const std::exception& e) {
Expand Down

0 comments on commit cf9d25d

Please sign in to comment.