Skip to content

Commit

Permalink
sparse: readd GetVectorByIds support for growing segments with IP met…
Browse files Browse the repository at this point in the history
…ric (#1022)

Signed-off-by: Shawn Wang <[email protected]>
  • Loading branch information
sparknack authored Jan 10, 2025
1 parent dadbcfc commit a01d8ba
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 40 deletions.
65 changes: 57 additions & 8 deletions src/index/sparse/sparse_index_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@

#include <sys/mman.h>

#include <exception>

#include "index/sparse/sparse_inverted_index.h"
#include "index/sparse/sparse_inverted_index_config.h"
#include "io/file_io.h"
#include "io/memory_io.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/comp/thread_pool.h"
#include "knowhere/config.h"
#include "knowhere/dataset.h"
Expand Down Expand Up @@ -389,15 +392,23 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode<T, use_wand> {
}

Status
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
Add(const DataSetPtr dataset, std::shared_ptr<Config> config) override {
std::unique_lock<std::mutex> lock(mutex_);
uint64_t task_id = next_task_id_++;
add_tasks_.push(task_id);

// add task is allowed to run only after all search tasks that come before it have finished.
cv_.wait(lock, [this, task_id]() { return current_task_id_ == task_id && active_readers_ == 0; });

auto res = SparseInvertedIndexNode<T, use_wand>::Add(dataset, cfg);
auto res = SparseInvertedIndexNode<T, use_wand>::Add(dataset, config);

auto cfg = static_cast<const SparseInvertedIndexConfig&>(*config);
if (IsMetricType(cfg.metric_type.value(), metric::IP)) {
// insert dataset to raw data if metric type is IP
auto data = static_cast<const sparse::SparseRow<T>*>(dataset->GetTensor());
auto rows = dataset->GetRows();
raw_data_.insert(raw_data_.end(), data, data + rows);
}

add_tasks_.pop();
current_task_id_++;
Expand Down Expand Up @@ -431,12 +442,6 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode<T, use_wand> {
return SparseInvertedIndexNode<T, use_wand>::RangeSearch(dataset, std::move(cfg), bitset);
}

expected<DataSetPtr>
GetVectorByIds(const DataSetPtr dataset) const override {
ReadPermission permission(*this);
return SparseInvertedIndexNode<T, use_wand>::GetVectorByIds(dataset);
}

int64_t
Dim() const override {
ReadPermission permission(*this);
Expand All @@ -461,6 +466,49 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode<T, use_wand> {
: knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC;
}

expected<DataSetPtr>
GetVectorByIds(const DataSetPtr dataset) const override {
ReadPermission permission(*this);

if (raw_data_.empty()) {
return expected<DataSetPtr>::Err(Status::invalid_args, "GetVectorByIds failed: raw data is empty");
}

auto rows = dataset->GetRows();
auto ids = dataset->GetIds();
auto data = std::make_unique<sparse::SparseRow<T>[]>(rows);
int64_t dim = 0;

try {
for (int64_t i = 0; i < rows; ++i) {
data[i] = raw_data_[ids[i]];
dim = std::max(dim, data[i].dim());
}
} catch (std::exception& e) {
return expected<DataSetPtr>::Err(Status::invalid_args, "GetVectorByIds failed: " + std::string(e.what()));
}

auto res = GenResultDataSet(rows, dim, data.release());
res->SetIsSparse(true);

return res;
}

[[nodiscard]] bool
HasRawData(const std::string& metric_type) const override {
return IsMetricType(metric_type, metric::IP);
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
return Status::not_implemented;
}

Status
DeserializeFromFile(const std::string& filename, std::shared_ptr<Config> config) override {
return Status::not_implemented;
}

private:
struct ReadPermission {
ReadPermission(const SparseInvertedIndexNodeCC& node) : node_(node) {
Expand Down Expand Up @@ -490,6 +538,7 @@ class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode<T, use_wand> {
mutable std::queue<uint64_t> add_tasks_;
mutable uint64_t next_task_id_ = 0;
mutable uint64_t current_task_id_ = 0;
mutable std::vector<sparse::SparseRow<T>> raw_data_ = {};
}; // class SparseInvertedIndexNodeCC

KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_INVERTED_INDEX, SparseInvertedIndexNode, knowhere::feature::MMAP,
Expand Down
82 changes: 50 additions & 32 deletions tests/ut/test_sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -504,41 +504,42 @@ TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") {

auto test_time = 10;

SECTION("Test Search") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC, sparse_inverted_index_gen),
make_tuple(knowhere::IndexEnum::INDEX_SPARSE_WAND_CC, sparse_inverted_index_gen),
}));
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC, sparse_inverted_index_gen),
make_tuple(knowhere::IndexEnum::INDEX_SPARSE_WAND_CC, sparse_inverted_index_gen),
}));

auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
// build the index with some initial data
REQUIRE(idx.Build(doc_vector_gen(nb, dim), json) == knowhere::Status::success);

auto add_task = [&]() {
auto start = std::chrono::steady_clock::now();
while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start).count() <
test_time) {
auto doc_ds = doc_vector_gen(nb, dim);
auto res = idx.Add(doc_ds, json);
REQUIRE(res == knowhere::Status::success);
}
};
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
// build the index with some initial data
auto train_ds = doc_vector_gen(nb, dim);
REQUIRE(idx.Build(train_ds, json) == knowhere::Status::success);

auto search_task = [&]() {
auto start = std::chrono::steady_clock::now();
while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start).count() <
test_time) {
auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
check_result(*results.value());
}
};
auto add_task = [&]() {
auto start = std::chrono::steady_clock::now();
while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start).count() <
test_time) {
auto doc_ds = doc_vector_gen(nb, dim);
auto res = idx.Add(doc_ds, json);
REQUIRE(res == knowhere::Status::success);
}
};

auto search_task = [&]() {
auto start = std::chrono::steady_clock::now();
while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start).count() <
test_time) {
auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
check_result(*results.value());
}
};

SECTION("Test Search") {
std::vector<std::future<void>> task_list;
for (int thread = 0; thread < 5; thread++) {
task_list.push_back(std::async(std::launch::async, search_task));
Expand All @@ -548,4 +549,21 @@ TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") {
task.wait();
}
}

SECTION("Test GetVectorByIds") {
std::vector<int64_t> ids = {0, 1, 2};
REQUIRE(idx.HasRawData(metric));
auto results = idx.GetVectorByIds(GenIdsDataSet(3, ids));
REQUIRE(results.has_value());
auto xb = (knowhere::sparse::SparseRow<float>*)train_ds->GetTensor();
auto res_data = (knowhere::sparse::SparseRow<float>*)results.value()->GetTensor();
for (int i = 0; i < 3; ++i) {
const auto& truth_row = xb[i];
const auto& res_row = res_data[i];
REQUIRE(truth_row.size() == res_row.size());
for (size_t j = 0; j < truth_row.size(); ++j) {
REQUIRE(truth_row[j] == res_row[j]);
}
}
}
}

0 comments on commit a01d8ba

Please sign in to comment.