Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement the interfaces for range search, filter in pyramid #310

Merged
merged 3 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 158 additions & 40 deletions src/index/pyramid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,115 @@

namespace vsag {

Binary
binaryset_to_binary(const BinarySet binary_set) {
    /*
     * Flattens every (key, Binary) entry of `binary_set` into one contiguous Binary.
     *
     * The serialized layout of the Binary data in memory will be as follows:
     * | key_size_0     | key_0 (L_0 bytes)         | binary_size_0     | binary_data_0 (S_0 bytes)         |
     * | key_size_1     | key_1 (L_1 bytes)         | binary_size_1     | binary_data_1 (S_1 bytes)         |
     * | ...            | ...                       | ...               | ...                               |
     * | key_size_(N-1) | key_(N-1) (L_(N-1) bytes) | binary_size_(N-1) | binary_data_(N-1) (S_(N-1) bytes) |
     * Where:
     * - `key_size_k`: size of the k-th key (in bytes)
     * - `key_k`: the actual k-th key data (length L_k)
     * - `binary_size_k`: size of the binary data associated with the k-th key (in bytes)
     * - `binary_data_k`: the actual binary data contents (length S_k)
     * - N: total number of keys in the BinarySet
     *
     * Both size fields are the raw bytes of a `size_t`, copied as-is. Entries are written in
     * the order returned by `BinarySet::GetKeys()`.
     */
    const auto keys = binary_set.GetKeys();

    // First pass: compute the exact size of the output buffer.
    size_t total_size = 0;
    for (const auto& key : keys) {
        total_size += sizeof(size_t) + key.size();
        total_size += sizeof(size_t) + binary_set.Get(key).size;
    }

    Binary result;
    result.data = std::shared_ptr<int8_t[]>(new int8_t[total_size]);
    result.size = total_size;

    size_t offset = 0;
    // Appends `len` bytes at the write cursor. Zero-length fields are skipped entirely because
    // their source pointer may be null (e.g. an empty Binary), and passing a null pointer to
    // memcpy is undefined behaviour even when the count is zero.
    auto append = [&](const void* src, size_t len) {
        if (len == 0) {
            return;
        }
        memcpy(result.data.get() + offset, src, len);
        offset += len;
    };

    // Second pass: emit each entry as | key_size | key | binary_size | binary_data |.
    for (const auto& key : keys) {
        const size_t key_size = key.size();
        append(&key_size, sizeof(size_t));
        append(key.data(), key_size);

        const Binary binary = binary_set.Get(key);
        append(&binary.size, sizeof(size_t));
        append(binary.data.get(), binary.size);
    }

    return result;
}

BinarySet
binary_to_binaryset(const Binary binary) {
    /*
     * Rebuilds a BinarySet from a flat Binary.
     *
     * The Binary structure is serialized in the following layout:
     * | key_size (sizeof(size_t)) | key (of length key_size) | binary_size (sizeof(size_t)) | binary data (of length binary_size) |
     * Each key and its associated binary data are sequentially stored in the Binary object's data array,
     * and this information guides the deserialization process here.
     *
     * Every field is validated against the number of bytes still available before it is read,
     * so a truncated or corrupted buffer raises std::out_of_range instead of reading past the
     * end of the buffer or attempting an arbitrarily large allocation.
     */
    BinarySet binary_set;
    const int8_t* const base = binary.data.get();
    size_t offset = 0;

    // `offset` never exceeds `binary.size`, so the subtraction below cannot wrap around.
    auto require = [&](size_t len) {
        if (len > binary.size - offset) {
            throw std::out_of_range("binary_to_binaryset: truncated or corrupted binary data");
        }
    };

    while (offset < binary.size) {
        // Key length, followed by the key bytes.
        size_t key_size = 0;
        require(sizeof(size_t));
        memcpy(&key_size, base + offset, sizeof(size_t));
        offset += sizeof(size_t);

        require(key_size);
        std::string key(reinterpret_cast<const char*>(base + offset), key_size);
        offset += key_size;

        // Payload length, followed by the payload bytes.
        size_t binary_size = 0;
        require(sizeof(size_t));
        memcpy(&binary_size, base + offset, sizeof(size_t));
        offset += sizeof(size_t);

        require(binary_size);
        Binary new_binary;
        new_binary.size = binary_size;
        // Each entry gets its own buffer so it does not depend on the lifetime of `binary`.
        new_binary.data = std::shared_ptr<int8_t[]>(new int8_t[binary_size]);
        memcpy(new_binary.data.get(), base + offset, binary_size);
        offset += binary_size;

        binary_set.Set(key, new_binary);
    }

    return binary_set;
}

Check warning on line 100 in src/index/pyramid.cpp

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.cpp#L100

Added line #L100 was not covered by tests

ReaderSet
reader_to_readerset(std::shared_ptr<Reader> reader) {
    /*
     * Splits one Reader into a ReaderSet without copying the payloads.
     *
     * The content exposed by `reader` is walked as a sequence of records:
     * | key_size (sizeof(size_t)) | key (key_size bytes) | binary_size (sizeof(size_t)) | payload (binary_size bytes) |
     * For each record a SubReader covering only the payload bytes is registered under the key.
     *
     * Every field is checked against the bytes remaining in `reader` before it is read, so a
     * truncated or corrupted stream raises std::out_of_range instead of allocating a bogus key
     * buffer or creating a SubReader that extends past the end of its parent.
     */
    ReaderSet reader_set;
    const uint64_t total_size = reader->Size();
    size_t offset = 0;

    // `offset` never exceeds `total_size`, so the subtraction below cannot wrap around.
    auto require = [&](uint64_t len) {
        if (len > total_size - offset) {
            throw std::out_of_range("reader_to_readerset: truncated or corrupted reader data");
        }
    };

    while (offset < total_size) {
        // Key length, followed by the key bytes (read straight into the string's storage).
        size_t key_size = 0;
        require(sizeof(size_t));
        reader->Read(offset, sizeof(size_t), &key_size);
        offset += sizeof(size_t);

        require(key_size);
        std::string key(key_size, '\0');
        reader->Read(offset, key_size, key.data());
        offset += key_size;

        // Payload length; the payload itself is not read here, only wrapped in a SubReader.
        size_t binary_size = 0;
        require(sizeof(size_t));
        reader->Read(offset, sizeof(size_t), &binary_size);
        offset += sizeof(size_t);

        require(binary_size);
        auto new_reader = std::make_shared<SubReader>(reader, offset, binary_size);
        offset += binary_size;

        reader_set.Set(key, new_reader);
    }

    return reader_set;
}

Check warning on line 127 in src/index/pyramid.cpp

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.cpp#L127

Added line #L127 was not covered by tests

// std::deque specialised to allocate its storage through vsag::AllocatorWrapper<T>.
template <typename T>
using Deque = std::deque<T, vsag::AllocatorWrapper<T>>;

Expand Down Expand Up @@ -85,10 +194,10 @@
}

tl::expected<DatasetPtr, Error>
Pyramid::KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
BitsetPtr invalid) const {
Pyramid::knn_search(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const SearchFunc& search_func) const {
auto path = query->GetPaths(); // TODO(inabao): provide different search modes.

std::string current_path = path[0];
Expand Down Expand Up @@ -118,7 +227,7 @@
auto node = candidate_indexes.front();
candidate_indexes.pop_front();
if (node->index) {
auto result = node->index->KnnSearch(query, k, parameters, invalid);
auto result = search_func(node->index);
if (result.has_value()) {
DatasetPtr r = result.value();
for (int i = 0; i < r->GetDim(); ++i) {
Expand Down Expand Up @@ -162,52 +271,61 @@
return result;
}

tl::expected<DatasetPtr, Error>
Pyramid::KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const std::function<bool(int64_t)>& filter) const {
return {};
}

tl::expected<DatasetPtr, Error>
Pyramid::RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
int64_t limited_size) const {
return {};
}

tl::expected<DatasetPtr, Error>
Pyramid::RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
BitsetPtr invalid,
int64_t limited_size) const {
return {};
}

tl::expected<DatasetPtr, Error>
Pyramid::RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
const std::function<bool(int64_t)>& filter,
int64_t limited_size) const {
return {};
}

/*
 * Serializes every index stored in the pyramid into a single BinarySet.
 *
 * Each root in `indexes_` is traversed with an explicit stack. A node that owns an index
 * contributes one entry whose key is the node's full path (path components joined with
 * PART_OCTOTHORPE) and whose value is the node's own serialized BinarySet flattened by
 * binaryset_to_binary(). Nodes without an index contribute nothing, but their children are
 * still visited.
 *
 * Returns the assembled BinarySet, or the first error reported by a child index's Serialize().
 */
tl::expected<BinarySet, Error>
Pyramid::Serialize() const {
    BinarySet binary_set;
    for (const auto& [root_path, root_node] : indexes_) {
        // Nodes still to be visited, each paired with its full path from the root.
        std::vector<std::pair<std::string, std::shared_ptr<IndexNode>>> pending;
        pending.emplace_back(root_path, root_node);
        while (not pending.empty()) {
            // Take ownership of the top entry before popping it off the stack.
            auto [current_path, index_node] = std::move(pending.back());
            pending.pop_back();
            if (index_node->index) {
                auto serialize_result = index_node->index->Serialize();
                if (not serialize_result.has_value()) {
                    return tl::unexpected(serialize_result.error());
                }
                binary_set.Set(current_path, binaryset_to_binary(serialize_result.value()));
            }
            for (const auto& [child_name, child_node] : index_node->children) {
                pending.emplace_back(current_path + PART_OCTOTHORPE + child_name, child_node);
            }
        }
    }
    return binary_set;
}

/*
 * Restores the pyramid's indexes from a BinarySet.
 *
 * Each key is a node path whose components are separated by PART_OCTOTHORPE; each value is a
 * flattened BinarySet that binary_to_binaryset() expands again. For every key the node at that
 * path is looked up via try_get_node_with_init() (presumably creating missing nodes -- confirm
 * against its definition), an index is created on it with `pyramid_param_.index_builder`, and
 * that index deserializes the expanded BinarySet.
 *
 * Returns an empty expected on success, or the first error reported by a child index's
 * Deserialize(); entries after a failing one are not processed.
 */
tl::expected<void, Error>
Pyramid::Deserialize(const BinarySet& binary_set) {
    auto keys = binary_set.GetKeys();
    for (const auto& path : keys) {
        const auto& binary = binary_set.Get(path);
        auto path_slices = split(path, PART_OCTOTHORPE);

        // Walk from the root map down through the children, one path component at a time.
        std::shared_ptr<IndexNode> node = try_get_node_with_init(indexes_, path_slices[0]);
        for (size_t j = 1; j < path_slices.size(); ++j) {
            node = try_get_node_with_init(node->children, path_slices[j]);
        }

        node->CreateIndex(pyramid_param_.index_builder);
        // Propagate a failure instead of silently reporting success with a broken sub-index.
        auto deserialize_result = node->index->Deserialize(binary_to_binaryset(binary));
        if (not deserialize_result.has_value()) {
            return tl::unexpected(deserialize_result.error());
        }
    }
    return {};
}

/*
 * Restores the pyramid's indexes from a ReaderSet.
 *
 * Each key is a node path whose components are separated by PART_OCTOTHORPE; each value is a
 * Reader that reader_to_readerset() splits into per-key sub-readers. For every key the node at
 * that path is looked up via try_get_node_with_init() (presumably creating missing nodes --
 * confirm against its definition), an index is created on it with
 * `pyramid_param_.index_builder`, and that index deserializes from the resulting ReaderSet.
 *
 * Returns an empty expected on success, or the first error reported by a child index's
 * Deserialize(); entries after a failing one are not processed.
 */
tl::expected<void, Error>
Pyramid::Deserialize(const ReaderSet& reader_set) {
    auto keys = reader_set.GetKeys();
    for (const auto& path : keys) {
        const auto& reader = reader_set.Get(path);
        auto path_slices = split(path, PART_OCTOTHORPE);

        // Walk from the root map down through the children, one path component at a time.
        std::shared_ptr<IndexNode> node = try_get_node_with_init(indexes_, path_slices[0]);
        for (size_t j = 1; j < path_slices.size(); ++j) {
            node = try_get_node_with_init(node->children, path_slices[j]);
        }

        node->CreateIndex(pyramid_param_.index_builder);
        // Propagate a failure instead of silently reporting success with a broken sub-index.
        auto deserialize_result = node->index->Deserialize(reader_to_readerset(reader));
        if (not deserialize_result.has_value()) {
            return tl::unexpected(deserialize_result.error());
        }
    }
    return {};
}

Expand Down
87 changes: 82 additions & 5 deletions src/index/pyramid.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,49 @@

#include <utility>

#include "base_filter_functor.h"
#include "logger.h"
#include "pyramid_zparameters.h"
#include "safe_allocator.h"

namespace vsag {

class SubReader : public Reader {
public:
SubReader(std::shared_ptr<Reader> parent_reader, uint64_t start_pos, uint64_t size)
: parent_reader_(std::move(parent_reader)), size_(size), start_pos_(start_pos) {
}

void
Read(uint64_t offset, uint64_t len, void* dest) override {
if (offset + len > size_)
throw std::out_of_range("Read out of range.");

Check warning on line 36 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L36

Added line #L36 was not covered by tests
parent_reader_->Read(offset + start_pos_, len, dest);
}

void
AsyncRead(uint64_t offset, uint64_t len, void* dest, CallBack callback) override {
throw std::runtime_error("No support for SubReader AsyncRead");

Check warning on line 42 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L41-L42

Added lines #L41 - L42 were not covered by tests
}

uint64_t
Size() const override {
return size_;
}

private:
std::shared_ptr<Reader> parent_reader_;
uint64_t size_;
uint64_t start_pos_;
};

// Flattens all (key, Binary) entries of a BinarySet into one Binary laid out as repeated
// | key_size | key | binary_size | binary_data | records.
Binary
binaryset_to_binary(const BinarySet binary_set);
// Inverse of binaryset_to_binary(): rebuilds a BinarySet from a flattened Binary.
BinarySet
binary_to_binaryset(const Binary binary);
// Splits a Reader holding the same record layout into a ReaderSet with one SubReader per key.
ReaderSet
reader_to_readerset(std::shared_ptr<Reader> reader);

struct IndexNode {
std::shared_ptr<Index> index{nullptr};
UnorderedMap<std::string, std::shared_ptr<IndexNode>> children;
Expand All @@ -35,6 +73,8 @@
}
};

// Callable that is handed one index and returns either a result dataset or an Error.
using SearchFunc = std::function<tl::expected<DatasetPtr, Error>(IndexPtr)>;

class Pyramid : public Index {
public:
Pyramid(PyramidParameters pyramid_param, const IndexCommonParam& commom_param)
Expand All @@ -55,33 +95,64 @@
KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
BitsetPtr invalid = nullptr) const override;
BitsetPtr invalid = nullptr) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->KnnSearch(query, k, parameters, invalid);
};
SAFE_CALL(return this->knn_search(query, k, parameters, search_func);)
}

tl::expected<DatasetPtr, Error>
KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const std::function<bool(int64_t)>& filter) const override;
const std::function<bool(int64_t)>& filter) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->KnnSearch(query, k, parameters, filter);
};
SAFE_CALL(return this->knn_search(query, k, parameters, search_func);)
}

tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
int64_t limited_size = -1) const override;
int64_t limited_size = -1) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->RangeSearch(query, radius, parameters, limited_size);
};
int64_t final_limit =
limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
}

tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
BitsetPtr invalid,
int64_t limited_size = -1) const override;
int64_t limited_size = -1) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->RangeSearch(query, radius, parameters, invalid, limited_size);
};

Check warning on line 137 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L135-L137

Added lines #L135 - L137 were not covered by tests
int64_t final_limit =
limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
}

Check warning on line 141 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L139-L141

Added lines #L139 - L141 were not covered by tests

tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
const std::function<bool(int64_t)>& filter,
int64_t limited_size = -1) const override;
int64_t limited_size = -1) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->RangeSearch(query, radius, parameters, filter, limited_size);
};

Check warning on line 151 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L149-L151

Added lines #L149 - L151 were not covered by tests
int64_t final_limit =
limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
}

Check warning on line 155 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L153-L155

Added lines #L153 - L155 were not covered by tests

tl::expected<BinarySet, Error>
Serialize() const override;
Expand All @@ -99,6 +170,12 @@
GetMemoryUsage() const override;

private:
tl::expected<DatasetPtr, Error>
knn_search(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const SearchFunc& search_func) const;

inline std::shared_ptr<IndexNode>
try_get_node_with_init(UnorderedMap<std::string, std::shared_ptr<IndexNode>>& index_map,
const std::string& key) {
Expand Down
7 changes: 5 additions & 2 deletions tests/fixtures/test_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,11 @@ TestDataset::CreateTestDataset(uint64_t dim,
dataset->range_ground_truth_ = dataset->ground_truth_;
dataset->range_radius_.resize(query_count);
for (uint64_t i = 0; i < query_count; ++i) {
dataset->range_radius_[i] = 0.5f * (result.first[i * count + dataset->top_k] +
result.first[i * count + dataset->top_k - 1]);
dataset->range_radius_[i] =
0.5f * (dataset->range_ground_truth_
->GetDistances()[i * dataset->top_k + dataset->top_k - 1] +
dataset->range_ground_truth_
->GetDistances()[i * dataset->top_k + dataset->top_k - 2]);
}
delete[] result.first;
delete[] result.second;
Expand Down
Loading
Loading