Skip to content

Commit

Permalink
knowhere support multi data type
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Nov 23, 2023
1 parent 2191316 commit 982fb96
Show file tree
Hide file tree
Showing 37 changed files with 1,291 additions and 390 deletions.
4 changes: 2 additions & 2 deletions benchmark/benchmark_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ class Benchmark_base {
void
free_all() {
if (xb_ != nullptr) {
delete[](float*) xb_;
delete[] (float*)xb_;
}
if (xq_ != nullptr) {
delete[](float*) xq_;
delete[] (float*)xq_;
}
if (gt_radius_ != nullptr) {
delete[] gt_radius_;
Expand Down
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_float.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ TEST_F(Benchmark_float, TEST_DISKANN) {
std::shared_ptr<knowhere::FileManager> file_manager = std::make_shared<knowhere::LocalFileManager>();
auto diskann_index_pack = knowhere::Pack(file_manager);

index_ = knowhere::IndexFactory::Instance().Create(
index_ = knowhere::IndexFactory::Instance().Create<float>(
index_type_, knowhere::Version::GetCurrentVersion().VersionNumber(), diskann_index_pack);
printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_);
knowhere::DataSetPtr ds_ptr = nullptr;
Expand Down
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_float_bitset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ TEST_F(Benchmark_float_bitset, TEST_DISKANN) {
auto diskann_index_pack = knowhere::Pack(file_manager);

auto version = knowhere::Version::GetCurrentVersion().VersionNumber();
index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack);
index_ = knowhere::IndexFactory::Instance().Create<float>(index_type_, version, diskann_index_pack);
printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_);
knowhere::DataSetPtr ds_ptr = nullptr;
index_.Build(*ds_ptr, conf);
Expand Down
2 changes: 1 addition & 1 deletion benchmark/hdf5/benchmark_float_range_bitset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ TEST_F(Benchmark_float_range_bitset, TEST_DISKANN) {
auto diskann_index_pack = knowhere::Pack(file_manager);

auto version = knowhere::Version::GetCurrentVersion().VersionNumber();
index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack);
index_ = knowhere::IndexFactory::Instance().Create<float>(index_type_, version, diskann_index_pack);
printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_);
knowhere::DataSetPtr ds_ptr = nullptr;
index_.Build(*ds_ptr, conf);
Expand Down
4 changes: 2 additions & 2 deletions benchmark/hdf5/benchmark_knowhere.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 {
create_index(const std::string& index_file_name, const knowhere::Json& conf) {
auto version = knowhere::Version::GetCurrentVersion().VersionNumber();
printf("[%.3f s] Creating index \"%s\"\n", get_time_diff(), index_type_.c_str());
index_ = knowhere::IndexFactory::Instance().Create(index_type_, version);
index_ = knowhere::IndexFactory::Instance().Create<float>(index_type_, version);

try {
printf("[%.3f s] Reading index file: %s\n", get_time_diff(), index_file_name.c_str());
Expand All @@ -121,7 +121,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 {

std::string golden_index_file_name = ann_test_name_ + "_" + golden_index_type_ + "_GOLDEN" + ".index";
printf("[%.3f s] Creating golden index \"%s\"\n", get_time_diff(), golden_index_type_.c_str());
golden_index_ = knowhere::IndexFactory::Instance().Create(golden_index_type_, version);
golden_index_ = knowhere::IndexFactory::Instance().Create<float>(golden_index_type_, version);

try {
printf("[%.3f s] Reading golden index file: %s\n", get_time_diff(), golden_index_file_name.c_str());
Expand Down
4 changes: 4 additions & 0 deletions include/knowhere/comp/brute_force.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,22 @@
#include "knowhere/bitsetview.h"
#include "knowhere/dataset.h"
#include "knowhere/factory.h"
#include "knowhere/operands.h"

namespace knowhere {

class BruteForce {
public:
template <typename DataType>
static expected<DataSetPtr>
Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset);

template <typename DataType>
static Status
SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis,
const Json& config, const BitsetView& bitset);

template <typename DataType>
static expected<DataSetPtr>
RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config,
const BitsetView& bitset);
Expand Down
16 changes: 10 additions & 6 deletions include/knowhere/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

namespace knowhere {

class DataSet {
class DataSet : public std::enable_shared_from_this<const DataSet> {
public:
typedef std::variant<const float*, const size_t*, const int64_t*, const void*, int64_t, std::string, std::any> Var;
DataSet() = default;
Expand All @@ -36,30 +36,35 @@ class DataSet {
{
auto ptr = std::get_if<0>(&x.second);
if (ptr != nullptr) {
delete[] * ptr;
delete[] *ptr;
}
}
{
auto ptr = std::get_if<1>(&x.second);
if (ptr != nullptr) {
delete[] * ptr;
delete[] *ptr;
}
}
{
auto ptr = std::get_if<2>(&x.second);
if (ptr != nullptr) {
delete[] * ptr;
delete[] *ptr;
}
}
{
auto ptr = std::get_if<3>(&x.second);
if (ptr != nullptr) {
delete[](char*)(*ptr);
delete[] (char*)(*ptr);
}
}
}
}

std::shared_ptr<const DataSet>
Get() const {
return shared_from_this();
}

void
SetDistance(const float* dis) {
std::unique_lock lock(mutex_);
Expand Down Expand Up @@ -227,7 +232,6 @@ class DataSet {
bool is_owner = true;
};
using DataSetPtr = std::shared_ptr<DataSet>;

inline DataSetPtr
GenDataSet(const int64_t nb, const int64_t dim, const void* xb) {
auto ret_ds = std::make_shared<DataSet>();
Expand Down
37 changes: 33 additions & 4 deletions include/knowhere/factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,52 @@
namespace knowhere {
class IndexFactory {
public:
template <typename DataType>
Index<IndexNode>
Create(const std::string& name, const int32_t& version, const Object& object = nullptr);
template <typename DataType>
const IndexFactory&
Register(const std::string& name, std::function<Index<IndexNode>(const int32_t& version, const Object&)> func);
static IndexFactory&
Instance();

private:
typedef std::map<std::string, std::function<Index<IndexNode>(const int32_t&, const Object&)>> FuncMap;
struct FunMapValueBase {};
template <typename T1>
struct FunMapValue : FunMapValueBase {
public:
FunMapValue(std::function<T1(const int32_t&, const Object&)>& input) : fun_value(input) {
}
std::function<T1(const int32_t&, const Object&)> fun_value;
};
typedef std::map<std::string, FunMapValueBase*> FuncMap;
IndexFactory();
static FuncMap&
MapInstance();
template <typename DataType>
std::string
GetMapKey(const std::string& name);
};

#define KNOWHERE_CONCAT(x, y) x##y
#define KNOWHERE_REGISTER_GLOBAL(name, func) \
const IndexFactory& KNOWHERE_CONCAT(index_factory_ref_, name) = IndexFactory::Instance().Register(#name, func)
#define KNOWHERE_CONCAT(x, y) index_factory_ref_##x##y
#define KNOWHERE_CONCAT_STR(x, y) #x "_" #y
#define KNOWHERE_REGISTER_GLOBAL(name, func, data_type) \
const IndexFactory& KNOWHERE_CONCAT(name, data_type) = IndexFactory::Instance().Register<data_type>(#name, func)
#define KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, data_type, ...) \
KNOWHERE_REGISTER_GLOBAL( \
name, \
[](const int32_t& version, const Object& object) { \
return (Index<index_node<data_type, ##__VA_ARGS__>>::Create(version, object)); \
}, \
data_type)
#define KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, data_type, mock_data_type, ...) \
KNOWHERE_REGISTER_GLOBAL( \
name, \
[](const int32_t& version, const Object& object) { \
return (Index<IndexNodeDataMockWrapper<data_type>>::Create( \
std::make_unique<index_node<mock_data_type, ##__VA_ARGS__>>(version, object))); \
}, \
data_type)
} // namespace knowhere

#endif /* INDEX_FACTORY_H */
1 change: 0 additions & 1 deletion include/knowhere/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "knowhere/index_node.h"

namespace knowhere {

template <typename T1>
class Index {
public:
Expand Down
2 changes: 1 addition & 1 deletion include/knowhere/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
#include "knowhere/dataset.h"
#include "knowhere/expected.h"
#include "knowhere/object.h"
#include "knowhere/operands.h"
#include "knowhere/version.h"

namespace knowhere {

class IndexNode : public Object {
public:
IndexNode(const int32_t ver) : version_(ver) {
Expand Down
100 changes: 100 additions & 0 deletions include/knowhere/index_node_data_mock_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (C) 2019-2023 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#ifndef INDEX_NODE_DATA_MOCK_WRAPPER_H
#define INDEX_NODE_DATA_MOCK_WRAPPER_H

#include "knowhere/index_node.h"
namespace knowhere {

template <typename DataType>
class IndexNodeDataMockWrapper : public IndexNode {
public:
IndexNodeDataMockWrapper(std::unique_ptr<IndexNode> index_node) : index_node_(std::move(index_node)) {
}

Status
Build(const DataSet& dataset, const Config& cfg) override;

Status
Train(const DataSet& dataset, const Config& cfg) override;

Status
Add(const DataSet& dataset, const Config& cfg) override;
expected<DataSetPtr>
Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override;

expected<DataSetPtr>
RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override;

expected<std::vector<std::shared_ptr<typename IndexNode::iterator>>>
AnnIterator(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const;

expected<DataSetPtr>
GetVectorByIds(const DataSet& dataset) const override;

bool
HasRawData(const std::string& metric_type) const override {
return index_node_->HasRawData(metric_type);
}

expected<DataSetPtr>
GetIndexMeta(const Config& cfg) const override {
return index_node_->GetIndexMeta(cfg);
}

Status
Serialize(BinarySet& binset) const override {
return index_node_->Serialize(binset);
}

Status
Deserialize(const BinarySet& binset, const Config& config) override {
return index_node_->Deserialize(binset, config);
}

Status
DeserializeFromFile(const std::string& filename, const Config& config) override {
return index_node_->DeserializeFromFile(filename, config);
}

std::unique_ptr<BaseConfig>
CreateConfig() const override {
return index_node_->CreateConfig();
}

int64_t
Dim() const override {
return index_node_->Dim();
}

int64_t
Size() const override {
return index_node_->Size();
}

int64_t
Count() const override {
return index_node_->Count();
}

std::string
Type() const override {
return index_node_->Type();
}

private:
std::unique_ptr<IndexNode> index_node_;
};

} // namespace knowhere

#endif /* INDEX_NODE_DATA_MOCK_WRAPPER_H */
Loading

0 comments on commit 982fb96

Please sign in to comment.