Skip to content

Commit

Permalink
Signed-off-by: jinjiabao.jjb <[email protected]>
Browse files Browse the repository at this point in the history
  • Loading branch information
jinjiabao.jjb committed Jan 20, 2025
1 parent 04283a8 commit f2ea40e
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 49 deletions.
180 changes: 140 additions & 40 deletions src/index/pyramid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,97 @@

namespace vsag {

// Flatten a BinarySet into a single contiguous Binary.
//
// Layout, repeated once per key in the order returned by GetKeys():
//   [size_t key length][key bytes][size_t payload length][payload bytes]
// Length fields are raw native size_t values copied with memcpy.
Binary
binaryset_to_binary(const BinarySet binary_set) {
    const auto names = binary_set.GetKeys();

    // Pass 1: add up the exact number of bytes the flattened form needs.
    size_t bytes_needed = 0;
    for (const auto& name : names) {
        bytes_needed += sizeof(size_t) + name.size();
        bytes_needed += sizeof(size_t);
        bytes_needed += binary_set.Get(name).size;
    }

    Binary packed;
    packed.data = std::shared_ptr<int8_t[]>(new int8_t[bytes_needed]);
    packed.size = bytes_needed;

    // Pass 2: append every record through a moving write cursor.
    int8_t* cursor = packed.data.get();
    const auto append = [&cursor](const void* src, size_t len) {
        memcpy(cursor, src, len);
        cursor += len;
    };

    for (const auto& name : names) {
        const size_t name_len = name.size();
        append(&name_len, sizeof(size_t));
        append(name.data(), name_len);

        const Binary payload = binary_set.Get(name);
        append(&payload.size, sizeof(size_t));
        append(payload.data.get(), payload.size);
    }

    return packed;
}

// Rebuild a BinarySet from a flat buffer laid out as repeated records of
//   [size_t key length][key bytes][size_t payload length][payload bytes]
// (the layout written by binaryset_to_binary).
//
// Every length field is validated against the bytes that remain in the buffer
// before it is used, so a truncated or corrupted buffer is rejected instead of
// being read past its end.
//
// Throws std::out_of_range when a length field or the bytes it describes do
// not fit in the remaining buffer.
BinarySet
binary_to_binaryset(const Binary binary) {
    BinarySet binary_set;
    const int8_t* base = binary.data.get();
    const size_t total = binary.size;
    size_t offset = 0;

    // Invariant: offset <= total, so `total - offset` never underflows.
    const auto read_length = [&]() {
        if (total - offset < sizeof(size_t)) {
            throw std::out_of_range("binary_to_binaryset: truncated length field");
        }
        size_t value = 0;
        memcpy(&value, base + offset, sizeof(size_t));
        offset += sizeof(size_t);
        return value;
    };

    while (offset < total) {
        const size_t key_size = read_length();
        if (key_size > total - offset) {
            throw std::out_of_range("binary_to_binaryset: key exceeds buffer");
        }
        std::string key(reinterpret_cast<const char*>(base + offset), key_size);
        offset += key_size;

        const size_t binary_size = read_length();
        if (binary_size > total - offset) {
            throw std::out_of_range("binary_to_binaryset: payload exceeds buffer");
        }

        // Each entry gets its own copy of the payload bytes.
        Binary new_binary;
        new_binary.size = binary_size;
        new_binary.data = std::shared_ptr<int8_t[]>(new int8_t[binary_size]);
        memcpy(new_binary.data.get(), base + offset, binary_size);
        offset += binary_size;

        binary_set.Set(key, new_binary);
    }

    return binary_set;
}

Check warning on line 82 in src/index/pyramid.cpp

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.cpp#L82

Added line #L82 was not covered by tests

// Split one Reader, laid out as repeated records of
//   [size_t key length][key bytes][size_t payload length][payload bytes],
// into a ReaderSet. Payload bytes are not copied: each entry is a SubReader
// that refers to the payload's range inside `reader`.
//
// Every length field is validated against the bytes remaining in `reader`
// before it is used, so a truncated or corrupted stream is rejected instead of
// triggering out-of-range reads or an arbitrarily large key allocation.
//
// Throws std::out_of_range when a length field or the bytes it describes do
// not fit in the remaining data.
ReaderSet
reader_to_readerset(std::shared_ptr<Reader> reader) {
    ReaderSet reader_set;
    const uint64_t total = reader->Size();
    uint64_t offset = 0;

    // Invariant: offset <= total, so `total - offset` never underflows.
    const auto read_length = [&]() {
        if (total - offset < sizeof(size_t)) {
            throw std::out_of_range("reader_to_readerset: truncated length field");
        }
        size_t value = 0;
        reader->Read(offset, sizeof(size_t), &value);
        offset += sizeof(size_t);
        return value;
    };

    while (offset < total) {
        const size_t key_size = read_length();
        if (key_size > total - offset) {
            throw std::out_of_range("reader_to_readerset: key exceeds reader size");
        }
        // Read the key straight into the string's own storage.
        std::string key(key_size, '\0');
        reader->Read(offset, key_size, key.data());
        offset += key_size;

        const size_t binary_size = read_length();
        if (binary_size > total - offset) {
            throw std::out_of_range("reader_to_readerset: payload exceeds reader size");
        }

        auto sub_reader = std::make_shared<SubReader>(reader, offset, binary_size);
        offset += binary_size;

        reader_set.Set(key, sub_reader);
    }

    return reader_set;
}

Check warning on line 109 in src/index/pyramid.cpp

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.cpp#L109

Added line #L109 was not covered by tests

template <typename T>
using Deque = std::deque<T, vsag::AllocatorWrapper<T>>;

Expand Down Expand Up @@ -85,10 +176,10 @@ Pyramid::Add(const DatasetPtr& base) {
}

tl::expected<DatasetPtr, Error>
Pyramid::KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
BitsetPtr invalid) const {
Pyramid::knn_search(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const SearchFunc& search_func) const {
auto path = query->GetPaths(); // TODO(inabao): provide different search modes.

std::string current_path = path[0];
Expand Down Expand Up @@ -118,7 +209,7 @@ Pyramid::KnnSearch(const DatasetPtr& query,
auto node = candidate_indexes.front();
candidate_indexes.pop_front();
if (node->index) {
auto result = node->index->KnnSearch(query, k, parameters, invalid);
auto result = search_func(node->index);
if (result.has_value()) {
DatasetPtr r = result.value();
for (int i = 0; i < r->GetDim(); ++i) {
Expand Down Expand Up @@ -160,52 +251,61 @@ Pyramid::KnnSearch(const DatasetPtr& query,
return result;
}

// Placeholder KnnSearch overload taking a filter callback: every argument is
// ignored and a default-constructed result is returned (`return {};`).
// NOTE(review): pyramid.h in this same commit defines this overload inline, so
// this out-of-line stub looks like the removed side of the diff — confirm it
// is absent from the final file.
tl::expected<DatasetPtr, Error>
Pyramid::KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const std::function<bool(int64_t)>& filter) const {
return {};
}

// Placeholder RangeSearch overload (no filter): every argument is ignored and
// a default-constructed result is returned (`return {};`).
// NOTE(review): pyramid.h in this same commit defines this overload inline, so
// this out-of-line stub looks like the removed side of the diff — confirm it
// is absent from the final file.
tl::expected<DatasetPtr, Error>
Pyramid::RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
int64_t limited_size) const {
return {};
}

// Placeholder RangeSearch overload taking an `invalid` bitset: every argument
// is ignored and a default-constructed result is returned (`return {};`).
// NOTE(review): pyramid.h in this same commit defines this overload inline, so
// this out-of-line stub looks like the removed side of the diff — confirm it
// is absent from the final file.
tl::expected<DatasetPtr, Error>
Pyramid::RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
BitsetPtr invalid,
int64_t limited_size) const {
return {};
}

// Placeholder RangeSearch overload taking a filter callback: every argument is
// ignored and a default-constructed result is returned (`return {};`).
// NOTE(review): pyramid.h in this same commit defines this overload inline, so
// this out-of-line stub looks like the removed side of the diff — confirm it
// is absent from the final file.
tl::expected<DatasetPtr, Error>
Pyramid::RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
const std::function<bool(int64_t)>& filter,
int64_t limited_size) const {
return {};
}

// Serialize every index held by the pyramid into one BinarySet.
//
// Each node that owns an index contributes one entry:
//   key   = the node's path (root key, then child keys, joined with
//           PART_OCTOTHORPE)
//   value = that index's own BinarySet, flattened by binaryset_to_binary
// Nodes without an index contribute nothing but their children are still
// visited.
//
// Returns the first error reported by any child index's Serialize().
tl::expected<BinarySet, Error>
Pyramid::Serialize() const {
    BinarySet binary_set;
    for (const auto& root_index : indexes_) {
        // Explicit stack of (path, node) pairs still to be visited.
        std::vector<std::pair<std::string, std::shared_ptr<IndexNode>>> need_serialize_indexes;
        need_serialize_indexes.emplace_back(root_index.first, root_index.second);
        while (not need_serialize_indexes.empty()) {
            // Copy the pair out before pop_back() destroys the stack slot.
            auto [current_path, index_node] = need_serialize_indexes.back();
            need_serialize_indexes.pop_back();
            if (index_node->index) {
                auto serialize_result = index_node->index->Serialize();
                if (not serialize_result.has_value()) {
                    return tl::unexpected(serialize_result.error());
                }
                binary_set.Set(current_path, binaryset_to_binary(serialize_result.value()));
            }
            for (const auto& sub_index_node : index_node->children) {
                need_serialize_indexes.emplace_back(
                    current_path + PART_OCTOTHORPE + sub_index_node.first, sub_index_node.second);
            }
        }
    }
    return binary_set;
}

// Restore the pyramid from a BinarySet whose keys are node paths (segments
// separated by PART_OCTOTHORPE) and whose values are flattened per-index
// BinarySets.
//
// For each key, the node at that path is looked up (created on demand by
// try_get_node_with_init), an index is created on it with the configured
// builder, and the unpacked BinarySet is handed to that index.
//
// Returns the first error reported by a child index's Deserialize(); the
// failure is propagated rather than silently dropped.
tl::expected<void, Error>
Pyramid::Deserialize(const BinarySet& binary_set) {
    auto keys = binary_set.GetKeys();
    for (const auto& path : keys) {
        const auto& binary = binary_set.Get(path);
        auto path_slices = split(path, PART_OCTOTHORPE);
        // NOTE(review): assumes split() yields at least one segment — confirm.
        std::shared_ptr<IndexNode> node = try_get_node_with_init(indexes_, path_slices[0]);
        for (size_t j = 1; j < path_slices.size(); ++j) {
            node = try_get_node_with_init(node->children, path_slices[j]);
        }
        node->CreateIndex(pyramid_param_.index_builder);
        auto deserialize_result = node->index->Deserialize(binary_to_binaryset(binary));
        if (not deserialize_result.has_value()) {
            return tl::unexpected(deserialize_result.error());
        }
    }
    return {};
}

// Restore the pyramid from a ReaderSet whose keys are node paths (segments
// separated by PART_OCTOTHORPE) and whose readers each hold one flattened
// per-index set.
//
// For each key, the node at that path is looked up (created on demand by
// try_get_node_with_init), an index is created on it with the configured
// builder, and the reader is split by reader_to_readerset and handed to that
// index.
//
// Returns the first error reported by a child index's Deserialize(); the
// failure is propagated rather than silently dropped.
tl::expected<void, Error>
Pyramid::Deserialize(const ReaderSet& reader_set) {
    auto keys = reader_set.GetKeys();
    for (const auto& path : keys) {
        const auto& reader = reader_set.Get(path);
        auto path_slices = split(path, PART_OCTOTHORPE);
        // NOTE(review): assumes split() yields at least one segment — confirm.
        std::shared_ptr<IndexNode> node = try_get_node_with_init(indexes_, path_slices[0]);
        for (size_t j = 1; j < path_slices.size(); ++j) {
            node = try_get_node_with_init(node->children, path_slices[j]);
        }
        node->CreateIndex(pyramid_param_.index_builder);
        auto deserialize_result = node->index->Deserialize(reader_to_readerset(reader));
        if (not deserialize_result.has_value()) {
            return tl::unexpected(deserialize_result.error());
        }
    }
    return {};
}

Expand Down
86 changes: 81 additions & 5 deletions src/index/pyramid.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,48 @@

#include <utility>

#include "base_filter_functor.h"
#include "logger.h"
#include "pyramid_zparameters.h"
#include "safe_allocator.h"

namespace vsag {

class SubReader : public Reader {
public:
SubReader(std::shared_ptr<Reader> parrent_reader, uint64_t start_pos, uint64_t size)
: parrent_reader_(std::move(parrent_reader)), size_(size), start_pos_(start_pos) {
}

void
Read(uint64_t offset, uint64_t len, void* dest) override {
if (offset + len > size_)
throw std::out_of_range("Read out of range.");

Check warning on line 36 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L36

Added line #L36 was not covered by tests
parrent_reader_->Read(offset + start_pos_, len, dest);
}

void
AsyncRead(uint64_t offset, uint64_t len, void* dest, CallBack callback) override {
}

Check warning on line 42 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L41-L42

Added lines #L41 - L42 were not covered by tests

uint64_t
Size() const override {
return size_;
}

private:
std::shared_ptr<Reader> parrent_reader_;
uint64_t size_;
uint64_t start_pos_;
};

Binary
binaryset_to_binary(const BinarySet binary_set);
BinarySet
binary_to_binaryset(const Binary binary);
ReaderSet
reader_to_readerset(std::shared_ptr<Reader> reader);

struct IndexNode {
std::shared_ptr<Index> index{nullptr};
UnorderedMap<std::string, std::shared_ptr<IndexNode>> children;
Expand All @@ -35,6 +72,8 @@ struct IndexNode {
}
};

using SearchFunc = std::function<tl::expected<DatasetPtr, Error>(IndexPtr)>;

class Pyramid : public Index {
public:
Pyramid(PyramidParameters pyramid_param, const IndexCommonParam& commom_param)
Expand All @@ -55,33 +94,64 @@ class Pyramid : public Index {
KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
BitsetPtr invalid = nullptr) const override;
BitsetPtr invalid = nullptr) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->KnnSearch(query, k, parameters, invalid);
};
SAFE_CALL(return this->knn_search(query, k, parameters, search_func);)
}

tl::expected<DatasetPtr, Error>
KnnSearch(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const std::function<bool(int64_t)>& filter) const override;
const std::function<bool(int64_t)>& filter) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->KnnSearch(query, k, parameters, filter);
};
SAFE_CALL(return this->knn_search(query, k, parameters, search_func);)
}

tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
int64_t limited_size = -1) const override;
int64_t limited_size = -1) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->RangeSearch(query, radius, parameters, limited_size);
};
int64_t final_limit =
limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
}

tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
BitsetPtr invalid,
int64_t limited_size = -1) const override;
int64_t limited_size = -1) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->RangeSearch(query, radius, parameters, invalid, limited_size);
};

Check warning on line 136 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L134-L136

Added lines #L134 - L136 were not covered by tests
int64_t final_limit =
limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
}

Check warning on line 140 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L138-L140

Added lines #L138 - L140 were not covered by tests

tl::expected<DatasetPtr, Error>
RangeSearch(const DatasetPtr& query,
float radius,
const std::string& parameters,
const std::function<bool(int64_t)>& filter,
int64_t limited_size = -1) const override;
int64_t limited_size = -1) const override {
SearchFunc search_func = [&](IndexPtr index) {
return index->RangeSearch(query, radius, parameters, filter, limited_size);
};

Check warning on line 150 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L148-L150

Added lines #L148 - L150 were not covered by tests
int64_t final_limit =
limited_size == -1 ? std::numeric_limits<int64_t>::max() : limited_size;
SAFE_CALL(return this->knn_search(query, final_limit, parameters, search_func);)
}

Check warning on line 154 in src/index/pyramid.h

View check run for this annotation

Codecov / codecov/patch

src/index/pyramid.h#L152-L154

Added lines #L152 - L154 were not covered by tests

tl::expected<BinarySet, Error>
Serialize() const override;
Expand All @@ -99,6 +169,12 @@ class Pyramid : public Index {
GetMemoryUsage() const override;

private:
tl::expected<DatasetPtr, Error>
knn_search(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
const SearchFunc& search_func) const;

inline std::shared_ptr<IndexNode>
try_get_node_with_init(UnorderedMap<std::string, std::shared_ptr<IndexNode>>& index_map,
const std::string& key) {
Expand Down
3 changes: 2 additions & 1 deletion tests/fixtures/test_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ TestDataset::TestDataset(uint64_t dim, uint64_t count, std::string metric_str, b
this->range_radius_.resize(query_count);
for (uint64_t i = 0; i < query_count; ++i) {
this->range_radius_[i] =
0.5f * (result.first[i * count + top_k] + result.first[i * count + top_k - 1]);
0.5f * (this->ground_truth_->GetDistances()[i * top_k + top_k - 1] +
this->ground_truth_->GetDistances()[i * top_k + top_k - 2]);
}
delete[] result.first;
delete[] result.second;
Expand Down
Loading

0 comments on commit f2ea40e

Please sign in to comment.