Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Dec 12, 2023
1 parent 6084153 commit 927d315
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 39 deletions.
5 changes: 0 additions & 5 deletions include/knowhere/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ class DataSet : public std::enable_shared_from_this<const DataSet> {
}
}

std::shared_ptr<const DataSet>
Get() const {
return shared_from_this();
}

void
SetDistance(const float* dis) {
std::unique_lock lock(mutex_);
Expand Down
4 changes: 2 additions & 2 deletions include/knowhere/factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class IndexFactory {
}
std::function<T1(const int32_t&, const Object&)> fun_value;
};
typedef std::map<std::string, FunMapValueBase*> FuncMap;
typedef std::map<std::string, std::unique_ptr<FunMapValueBase>> FuncMap;
IndexFactory();
static FuncMap&
MapInstance();
Expand All @@ -58,7 +58,7 @@ class IndexFactory {
[](const int32_t& version, const Object& object) { \
return (Index<index_node<data_type, ##__VA_ARGS__>>::Create(version, object)); \
}, \
data_type)
data_type)
#define KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, data_type, ...) \
KNOWHERE_REGISTER_GLOBAL( \
name, \
Expand Down
5 changes: 2 additions & 3 deletions src/common/factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ IndexFactory::Create(const std::string& name, const int32_t& version, const Obje
auto& func_mapping_ = MapInstance();
auto key = GetMapKey<DataType>(name);
assert(func_mapping_.find(key) != func_mapping_.end());
auto fun_map_v = (FunMapValue<Index<IndexNode>>*)(func_mapping_[key]);
auto fun_map_v = (FunMapValue<Index<IndexNode>>*)(func_mapping_[key].get());
return fun_map_v->fun_value(version, object);
}

Expand All @@ -29,8 +29,7 @@ IndexFactory::Register(const std::string& name, std::function<Index<IndexNode>(c
auto& func_mapping_ = MapInstance();
auto key = GetMapKey<DataType>(name);
assert(func_mapping_.find(key) == func_mapping_.end());
auto value = new FunMapValue<Index<IndexNode>>(func);
func_mapping_[key] = value;
func_mapping_[key] = std::make_unique<FunMapValue<Index<IndexNode>>>(func);
return *this;
}

Expand Down
21 changes: 8 additions & 13 deletions src/common/index_node_data_mock_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,17 @@ namespace knowhere {
template <typename DataType>
Status
IndexNodeDataMockWrapper<DataType>::Build(const DataSet& dataset, const Config& cfg) {
std::shared_ptr<const DataSet> ds_ptr = nullptr;
if (this->Type() != knowhere::IndexEnum::INDEX_DISKANN) {
ds_ptr = dataset.Get();
if constexpr (!std::is_same_v<DataType, typename MockData<DataType>::type>) {
ds_ptr = data_type_conversion<DataType, typename MockData<DataType>::type>(dataset);
}
std::shared_ptr<const DataSet> ds_ptr = dataset.shared_from_this();
if constexpr (!std::is_same_v<DataType, typename MockData<DataType>::type>) {
ds_ptr = data_type_conversion<DataType, typename MockData<DataType>::type>(dataset);
}
return index_node_->Build(*ds_ptr, cfg);
}

template <typename DataType>
Status
IndexNodeDataMockWrapper<DataType>::Train(const DataSet& dataset, const Config& cfg) {
std::shared_ptr<const DataSet> ds_ptr = nullptr;
ds_ptr = dataset.Get();
std::shared_ptr<const DataSet> ds_ptr = dataset.shared_from_this();
if constexpr (!std::is_same_v<DataType, typename MockData<DataType>::type>) {
ds_ptr = data_type_conversion<DataType, typename MockData<DataType>::type>(dataset);
}
Expand All @@ -47,8 +43,7 @@ IndexNodeDataMockWrapper<DataType>::Train(const DataSet& dataset, const Config&
template <typename DataType>
Status
IndexNodeDataMockWrapper<DataType>::Add(const DataSet& dataset, const Config& cfg) {
std::shared_ptr<const DataSet> ds_ptr = nullptr;
ds_ptr = dataset.Get();
std::shared_ptr<const DataSet> ds_ptr = dataset.shared_from_this();
if constexpr (!std::is_same_v<DataType, typename MockData<DataType>::type>) {
ds_ptr = data_type_conversion<DataType, typename MockData<DataType>::type>(dataset);
}
Expand All @@ -58,7 +53,7 @@ IndexNodeDataMockWrapper<DataType>::Add(const DataSet& dataset, const Config& cf
template <typename DataType>
expected<DataSetPtr>
IndexNodeDataMockWrapper<DataType>::Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const {
auto ds_ptr = dataset.Get();
auto ds_ptr = dataset.shared_from_this();
if constexpr (!std::is_same_v<DataType, typename MockData<DataType>::type>) {
ds_ptr = data_type_conversion<DataType, typename MockData<DataType>::type>(dataset);
}
Expand All @@ -69,7 +64,7 @@ template <typename DataType>
expected<DataSetPtr>
IndexNodeDataMockWrapper<DataType>::RangeSearch(const DataSet& dataset, const Config& cfg,
const BitsetView& bitset) const {
auto ds_ptr = dataset.Get();
auto ds_ptr = dataset.shared_from_this();
if constexpr (!std::is_same_v<DataType, typename MockData<DataType>::type>) {
ds_ptr = data_type_conversion<DataType, typename MockData<DataType>::type>(dataset);
}
Expand All @@ -80,7 +75,7 @@ template <typename DataType>
expected<std::vector<std::shared_ptr<typename IndexNode::iterator>>>
IndexNodeDataMockWrapper<DataType>::AnnIterator(const DataSet& dataset, const Config& cfg,
const BitsetView& bitset) const {
auto ds_ptr = dataset.Get();
auto ds_ptr = dataset.shared_from_this();
if constexpr (!std::is_same_v<DataType, typename MockData<DataType>::type>) {
ds_ptr = data_type_conversion<DataType, typename MockData<DataType>::type>(dataset);
}
Expand Down
8 changes: 4 additions & 4 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class FlatIndexNode : public IndexNode {
faiss::SearchParameters search_params;
search_params.sel = id_selector;

index_->search(1, (const DataType*)x + index * dim / 8, k, cur_i_dis, cur_ids, &search_params);
index_->search(1, (const uint8_t*)x + index * dim / 8, k, cur_i_dis, cur_ids, &search_params);

if (index_->metric_type == faiss::METRIC_Hamming) {
for (int64_t j = 0; j < k; j++) {
Expand Down Expand Up @@ -197,7 +197,7 @@ class FlatIndexNode : public IndexNode {
faiss::SearchParameters search_params;
search_params.sel = id_selector;

index_->range_search(1, (const DataType*)xq + index * dim / 8, radius, &res, &search_params);
index_->range_search(1, (const uint8_t*)xq + index * dim / 8, radius, &res, &search_params);
}
auto elem_cnt = res.lims[1];
result_dist_array[index].resize(elem_cnt);
Expand Down Expand Up @@ -254,13 +254,13 @@ class FlatIndexNode : public IndexNode {
if constexpr (std::is_same<IndexType, faiss::IndexBinaryFlat>::value) {
DataType* data = nullptr;
try {
data = new DataType[rows * dim / 8];
data = new uint8_t[rows * dim / 8];
for (int64_t i = 0; i < rows; i++) {
index_->reconstruct(ids[i], data + i * dim / 8);
}
return GenResultDataSet(rows, dim, data);
} catch (const std::exception& e) {
std::unique_ptr<DataType[]> auto_del(data);
std::unique_ptr<uint8_t[]> auto_del(data);
LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what();
return expected<DataSetPtr>::Err(Status::faiss_inner_error, e.what());
}
Expand Down
11 changes: 7 additions & 4 deletions src/index/gpu_raft/gpu_raft_cagra.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
namespace knowhere {
template struct GpuRaftIndexNode<raft_proto::raft_index_kind::cagra>;

KNOWHERE_REGISTER_GLOBAL(GPU_RAFT_CAGRA, [](const int32_t& version, const Object& object) {
return Index<IndexNodeThreadPoolWrapper>::Create(std::make_unique<GpuRaftCagraIndexNode>(version, object),
cuda_concurrent_size);
}, fp32);
KNOWHERE_REGISTER_GLOBAL(
GPU_RAFT_CAGRA,
[](const int32_t& version, const Object& object) {
return Index<IndexNodeThreadPoolWrapper>::Create(std::make_unique<GpuRaftCagraIndexNode>(version, object),
cuda_concurrent_size);
},
fp32);

} // namespace knowhere
11 changes: 7 additions & 4 deletions src/index/gpu_raft/gpu_raft_ivf_flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
namespace knowhere {
template struct GpuRaftIndexNode<raft_proto::raft_index_kind::ivf_flat>;

KNOWHERE_REGISTER_GLOBAL(GPU_RAFT_IVF_FLAT, [](const int32_t& version, const Object& object) {
return Index<IndexNodeThreadPoolWrapper>::Create(std::make_unique<GpuRaftIvfFlatIndexNode>(version, object),
cuda_concurrent_size);
}, fp32);
KNOWHERE_REGISTER_GLOBAL(
GPU_RAFT_IVF_FLAT,
[](const int32_t& version, const Object& object) {
return Index<IndexNodeThreadPoolWrapper>::Create(std::make_unique<GpuRaftIvfFlatIndexNode>(version, object),
cuda_concurrent_size);
},
fp32);

} // namespace knowhere
11 changes: 7 additions & 4 deletions src/index/gpu_raft/gpu_raft_ivf_pq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
namespace knowhere {
template struct GpuRaftIndexNode<raft_proto::raft_index_kind::ivf_pq>;

KNOWHERE_REGISTER_GLOBAL(GPU_RAFT_IVF_PQ, [](const int32_t& version, const Object& object) {
return Index<IndexNodeThreadPoolWrapper>::Create(std::make_unique<GpuRaftIvfPqIndexNode>(version, object),
cuda_concurrent_size);
}, fp32);
KNOWHERE_REGISTER_GLOBAL(
GPU_RAFT_IVF_PQ,
[](const int32_t& version, const Object& object) {
return Index<IndexNodeThreadPoolWrapper>::Create(std::make_unique<GpuRaftIvfPqIndexNode>(version, object),
cuda_concurrent_size);
},
fp32);

} // namespace knowhere

0 comments on commit 927d315

Please sign in to comment.