From 982fb96533b98c11816b2e0248bbfe5e40071129 Mon Sep 17 00:00:00 2001 From: cqy123456 Date: Wed, 22 Nov 2023 22:53:13 -0500 Subject: [PATCH] knowhere support multi data type Signed-off-by: cqy123456 --- benchmark/benchmark_base.h | 4 +- benchmark/hdf5/benchmark_float.cpp | 2 +- benchmark/hdf5/benchmark_float_bitset.cpp | 2 +- .../hdf5/benchmark_float_range_bitset.cpp | 2 +- benchmark/hdf5/benchmark_knowhere.h | 4 +- include/knowhere/comp/brute_force.h | 4 + include/knowhere/dataset.h | 16 +- include/knowhere/factory.h | 37 ++- include/knowhere/index.h | 1 - include/knowhere/index_node.h | 2 +- .../knowhere/index_node_data_mock_wrapper.h | 100 +++++++ include/knowhere/operands.h | 161 ++++++++++ include/knowhere/utils.h | 20 ++ src/common/comp/brute_force.cc | 91 ++++-- src/common/factory.cc | 62 +++- src/common/index_node_data_mock_wrapper.cc | 108 +++++++ src/index/diskann/diskann.cc | 88 +++--- src/index/flat/flat.cc | 105 +++---- src/index/gpu/flat_gpu/flat_gpu.cc | 10 +- src/index/gpu/ivf_gpu/ivf_gpu.cc | 27 +- src/index/hnsw/hnsw.cc | 54 ++-- src/index/ivf/ivf.cc | 275 ++++++++++-------- src/index/ivf_raft/ivf_raft.cu | 24 +- src/index/ivf_raft/ivf_raft.cuh | 2 +- tests/ut/test_bruteforce.cc | 12 +- tests/ut/test_diskann.cc | 56 ++-- tests/ut/test_feder.cc | 6 +- tests/ut/test_get_vector.cc | 8 +- tests/ut/test_gpu_search.cc | 18 +- tests/ut/test_half_presicion.cc | 258 ++++++++++++++++ tests/ut/test_iterator.cc | 16 +- tests/ut/test_ivfflat_cc.cc | 9 +- tests/ut/test_mmap.cc | 20 +- tests/ut/test_search.cc | 30 +- tests/ut/test_simd.cc | 6 +- thirdparty/DiskANN/include/diskann/utils.h | 6 +- thirdparty/DiskANN/src/utils.cpp | 35 +++ 37 files changed, 1291 insertions(+), 390 deletions(-) create mode 100644 include/knowhere/index_node_data_mock_wrapper.h create mode 100644 include/knowhere/operands.h create mode 100644 src/common/index_node_data_mock_wrapper.cc create mode 100644 tests/ut/test_half_presicion.cc diff --git a/benchmark/benchmark_base.h b/benchmark/benchmark_base.h index 81eb4c315..96d2f7216 100644 --- a/benchmark/benchmark_base.h +++ b/benchmark/benchmark_base.h @@ -188,10 +188,10 @@ class Benchmark_base { void free_all() { if (xb_ != nullptr) { - delete[](float*) xb_; + delete[] (float*)xb_; } if (xq_ != nullptr) { - delete[](float*) xq_; + delete[] (float*)xq_; } if (gt_radius_ != nullptr) { delete[] gt_radius_; diff --git a/benchmark/hdf5/benchmark_float.cpp b/benchmark/hdf5/benchmark_float.cpp index 5f42d7e24..a735a42f8 100644 --- a/benchmark/hdf5/benchmark_float.cpp +++ b/benchmark/hdf5/benchmark_float.cpp @@ -273,7 +273,7 @@ TEST_F(Benchmark_float, TEST_DISKANN) { std::shared_ptr file_manager = std::make_shared(); auto diskann_index_pack = knowhere::Pack(file_manager); - index_ = knowhere::IndexFactory::Instance().Create( + index_ = knowhere::IndexFactory::Instance().Create( index_type_, knowhere::Version::GetCurrentVersion().VersionNumber(), diskann_index_pack); printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_); knowhere::DataSetPtr ds_ptr = nullptr; diff --git a/benchmark/hdf5/benchmark_float_bitset.cpp b/benchmark/hdf5/benchmark_float_bitset.cpp index beb34eff3..70e083043 100644 --- a/benchmark/hdf5/benchmark_float_bitset.cpp +++ b/benchmark/hdf5/benchmark_float_bitset.cpp @@ -234,7 +234,7 @@ TEST_F(Benchmark_float_bitset, TEST_DISKANN) { auto diskann_index_pack = knowhere::Pack(file_manager); auto version = knowhere::Version::GetCurrentVersion().VersionNumber(); - index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack); + index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack); printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_); knowhere::DataSetPtr ds_ptr = nullptr; index_.Build(*ds_ptr, conf); diff --git a/benchmark/hdf5/benchmark_float_range_bitset.cpp b/benchmark/hdf5/benchmark_float_range_bitset.cpp index f5c406ce7..763f0dd31 100644 --- a/benchmark/hdf5/benchmark_float_range_bitset.cpp +++ b/benchmark/hdf5/benchmark_float_range_bitset.cpp @@ -235,7 +235,7 @@ TEST_F(Benchmark_float_range_bitset, TEST_DISKANN) { auto diskann_index_pack = knowhere::Pack(file_manager); auto version = knowhere::Version::GetCurrentVersion().VersionNumber(); - index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack); + index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack); printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_); knowhere::DataSetPtr ds_ptr = nullptr; index_.Build(*ds_ptr, conf); diff --git a/benchmark/hdf5/benchmark_knowhere.h b/benchmark/hdf5/benchmark_knowhere.h index 1bd1ae1db..1ab93dc1f 100644 --- a/benchmark/hdf5/benchmark_knowhere.h +++ b/benchmark/hdf5/benchmark_knowhere.h @@ -98,7 +98,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 { create_index(const std::string& index_file_name, const knowhere::Json& conf) { auto version = knowhere::Version::GetCurrentVersion().VersionNumber(); printf("[%.3f s] Creating index \"%s\"\n", get_time_diff(), index_type_.c_str()); - index_ = knowhere::IndexFactory::Instance().Create(index_type_, version); + index_ = knowhere::IndexFactory::Instance().Create(index_type_, version); try { printf("[%.3f s] Reading index file: %s\n", get_time_diff(), index_file_name.c_str()); @@ -121,7 +121,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 { std::string golden_index_file_name = ann_test_name_ + "_" + golden_index_type_ + "_GOLDEN" + ".index"; printf("[%.3f s] Creating golden index \"%s\"\n", get_time_diff(), golden_index_type_.c_str()); - golden_index_ = knowhere::IndexFactory::Instance().Create(golden_index_type_, version); + golden_index_ = knowhere::IndexFactory::Instance().Create(golden_index_type_, version); try { printf("[%.3f s] Reading golden index file: %s\n", get_time_diff(), golden_index_file_name.c_str()); diff --git a/include/knowhere/comp/brute_force.h b/include/knowhere/comp/brute_force.h index 240aa0f4b..e66217d61 100644 --- a/include/knowhere/comp/brute_force.h +++ b/include/knowhere/comp/brute_force.h @@ -14,18 +14,22 @@ #include "knowhere/bitsetview.h" #include "knowhere/dataset.h" #include "knowhere/factory.h" +#include "knowhere/operands.h" namespace knowhere { class BruteForce { public: + template static expected Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset); + template static Status SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis, const Json& config, const BitsetView& bitset); + template static expected RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset); diff --git a/include/knowhere/dataset.h b/include/knowhere/dataset.h index c77decaf3..ebec862a9 100644 --- a/include/knowhere/dataset.h +++ b/include/knowhere/dataset.h @@ -24,7 +24,7 @@ namespace knowhere { -class DataSet { +class DataSet : public std::enable_shared_from_this { public: typedef std::variant Var; DataSet() = default; @@ -36,30 +36,35 @@ class DataSet { { auto ptr = std::get_if<0>(&x.second); if (ptr != nullptr) { - delete[] * ptr; + delete[] *ptr; } } { auto ptr = std::get_if<1>(&x.second); if (ptr != nullptr) { - delete[] * ptr; + delete[] *ptr; } } { auto ptr = std::get_if<2>(&x.second); if (ptr != nullptr) { - delete[] * ptr; + delete[] *ptr; } } { auto ptr = std::get_if<3>(&x.second); if (ptr != nullptr) { - delete[](char*)(*ptr); + delete[] (char*)(*ptr); } } } } + std::shared_ptr + Get() const { + return shared_from_this(); + } + void SetDistance(const float* dis) { std::unique_lock lock(mutex_); @@ -227,7 +232,6 @@ class DataSet { bool is_owner = true; }; using DataSetPtr = std::shared_ptr; - inline DataSetPtr GenDataSet(const int64_t nb, const int64_t dim, const void* xb) { auto ret_ds = std::make_shared(); diff --git a/include/knowhere/factory.h b/include/knowhere/factory.h index 3138f7494..c678443b3 100644 --- a/include/knowhere/factory.h +++ b/include/knowhere/factory.h @@ -21,23 +21,52 @@ namespace knowhere { class IndexFactory { public: + template Index Create(const std::string& name, const int32_t& version, const Object& object = nullptr); + template const IndexFactory& Register(const std::string& name, std::function(const int32_t& version, const Object&)> func); static IndexFactory& Instance(); private: - typedef std::map(const int32_t&, const Object&)>> FuncMap; + struct FunMapValueBase {}; + template + struct FunMapValue : FunMapValueBase { + public: + FunMapValue(std::function& input) : fun_value(input) { + } + std::function fun_value; + }; + typedef std::map FuncMap; IndexFactory(); static FuncMap& MapInstance(); + template + std::string + GetMapKey(const std::string& name); }; -#define KNOWHERE_CONCAT(x, y) x##y -#define KNOWHERE_REGISTER_GLOBAL(name, func) \ - const IndexFactory& KNOWHERE_CONCAT(index_factory_ref_, name) = IndexFactory::Instance().Register(#name, func) +#define KNOWHERE_CONCAT(x, y) index_factory_ref_##x##y +#define KNOWHERE_CONCAT_STR(x, y) #x "_" #y +#define KNOWHERE_REGISTER_GLOBAL(name, func, data_type) \ + const IndexFactory& KNOWHERE_CONCAT(name, data_type) = IndexFactory::Instance().Register(#name, func) +#define KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, data_type, ...) \ + KNOWHERE_REGISTER_GLOBAL( \ + name, \ + [](const int32_t& version, const Object& object) { \ + return (Index>::Create(version, object)); \ + }, \ + data_type) +#define KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, data_type, mock_data_type, ...) \ + KNOWHERE_REGISTER_GLOBAL( \ + name, \ + [](const int32_t& version, const Object& object) { \ + return (Index>::Create( \ + std::make_unique>(version, object))); \ + }, \ + data_type) } // namespace knowhere #endif /* INDEX_FACTORY_H */ diff --git a/include/knowhere/index.h b/include/knowhere/index.h index 164eea88c..c9062c1fe 100644 --- a/include/knowhere/index.h +++ b/include/knowhere/index.h @@ -19,7 +19,6 @@ #include "knowhere/index_node.h" namespace knowhere { - template class Index { public: diff --git a/include/knowhere/index_node.h b/include/knowhere/index_node.h index 4613d37ca..6af39b2b5 100644 --- a/include/knowhere/index_node.h +++ b/include/knowhere/index_node.h @@ -18,10 +18,10 @@ #include "knowhere/dataset.h" #include "knowhere/expected.h" #include "knowhere/object.h" +#include "knowhere/operands.h" #include "knowhere/version.h" namespace knowhere { - class IndexNode : public Object { public: IndexNode(const int32_t ver) : version_(ver) { diff --git a/include/knowhere/index_node_data_mock_wrapper.h b/include/knowhere/index_node_data_mock_wrapper.h new file mode 100644 index 000000000..8878a20b0 --- /dev/null +++ b/include/knowhere/index_node_data_mock_wrapper.h @@ -0,0 +1,100 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#ifndef INDEX_NODE_DATA_MOCK_WRAPPER_H +#define INDEX_NODE_DATA_MOCK_WRAPPER_H + +#include "knowhere/index_node.h" +namespace knowhere { + +template +class IndexNodeDataMockWrapper : public IndexNode { + public: + IndexNodeDataMockWrapper(std::unique_ptr index_node) : index_node_(std::move(index_node)) { + } + + Status + Build(const DataSet& dataset, const Config& cfg) override; + + Status + Train(const DataSet& dataset, const Config& cfg) override; + + Status + Add(const DataSet& dataset, const Config& cfg) override; + expected + Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override; + + expected + RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override; + + expected>> + AnnIterator(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const; + + expected + GetVectorByIds(const DataSet& dataset) const override; + + bool + HasRawData(const std::string& metric_type) const override { + return index_node_->HasRawData(metric_type); + } + + expected + GetIndexMeta(const Config& cfg) const override { + return index_node_->GetIndexMeta(cfg); + } + + Status + Serialize(BinarySet& binset) const override { + return index_node_->Serialize(binset); + } + + Status + Deserialize(const BinarySet& binset, const Config& config) override { + return index_node_->Deserialize(binset, config); + } + + Status + DeserializeFromFile(const std::string& filename, const Config& config) override { + return index_node_->DeserializeFromFile(filename, config); + } + + std::unique_ptr + CreateConfig() const override { + return index_node_->CreateConfig(); + } + + int64_t + Dim() const override { + return index_node_->Dim(); + } + + int64_t + Size() const override { + return index_node_->Size(); + } + + int64_t + Count() const override { + return index_node_->Count(); + } + + std::string + Type() const override { + return index_node_->Type(); + } + + private: + std::unique_ptr index_node_; +}; + +} // namespace knowhere + +#endif /* INDEX_NODE_DATA_MOCK_WRAPPER_H */ diff --git a/include/knowhere/operands.h b/include/knowhere/operands.h new file mode 100644 index 000000000..9a65c1935 --- /dev/null +++ b/include/knowhere/operands.h @@ -0,0 +1,161 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#ifndef OPERANDS_H +#define OPERANDS_H +#include + +namespace { +union fp32_bits { + uint32_t as_bits; + float as_value; +}; + +inline float +fp32_from_bits(const uint32_t& w) { + return fp32_bits{.as_bits = w}.as_value; +} + +inline uint32_t +fp32_to_bits(const float& f) { + return fp32_bits{.as_value = f}.as_bits; +} +}; // namespace + +namespace knowhere { +using fp32 = float; +using bin1 = uint8_t; + +struct fp16 { + uint16_t bits = 0; + fp16() = default; + fp16(float f) { + bits = from_fp32(f); + }; + uint16_t + from_fp32(float f) const { + // const float scale_to_inf = 0x1.0p+112f; + // const float scale_to_zero = 0x1.0p-110f; + constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23; + constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23; + float scale_to_inf_val, scale_to_zero_val; + std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); + std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); + const float scale_to_inf = scale_to_inf_val; + const float scale_to_zero = scale_to_zero_val; + +#if defined(_MSC_VER) && _MSC_VER == 1916 + float base = ((f < 0.0 ? -f : f) * scale_to_inf) * scale_to_zero; +#else + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; +#endif + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return static_cast((sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign)); + } + + float + to_fp32(uint16_t h) const { + const uint32_t w = (uint32_t)h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23; + constexpr uint32_t scale_bits = (uint32_t)15 << 23; + + float exp_scale_val; + std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); + const float exp_scale = exp_scale_val; + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + constexpr uint32_t magic_mask = UINT32_C(126) << 23; + constexpr float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = + sign | (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); + } + operator float() const { + return to_fp32(bits); + } +}; + +struct bf16 { + uint16_t bits = 0; + bf16() = default; + bf16(float f) { + bits = from_fp32(f); + }; + uint16_t + from_fp32(float f) const { + uint32_t fp32Bits = fp32_to_bits(f); + uint16_t bf16Bits = (uint16_t)((fp32Bits >> 16) & 0x8000); + bf16Bits |= (uint16_t)((fp32Bits >> 23) & 0x7F); + bf16Bits |= (uint16_t)(fp32Bits >> 16); + return bf16Bits; + } + float + to_fp32(uint16_t h) const { + uint32_t bits = ((unsigned int)h) << 16; + bits &= 0xFFFF0000; + return fp32_from_bits(bits); + } + operator float() const { + return this->to_fp32(bits); + } +}; + +template +struct MockData { + using type = T; +}; + +template <> +struct MockData { + using type = knowhere::fp32; +}; + +template <> +struct MockData { + using type = knowhere::fp32; +}; + +// define result type, in case use double or uint8 data type in the future +template +struct ResultType { + using type = DataType; +}; + +template +struct ResultType< + DataType, std::enable_if_t || std::is_same_v || + std::is_same_v || std::is_same_v>> { + using type = knowhere::fp32; +}; + +} // namespace knowhere +#endif /* OPERANDS_H */ diff --git a/include/knowhere/utils.h b/include/knowhere/utils.h index a6ff6bed4..42e7a731e 100644 --- a/include/knowhere/utils.h +++ b/include/knowhere/utils.h @@ -68,6 +68,26 @@ hash_binary_vec(const uint8_t* x, size_t d) { return h; } +template +inline DataSetPtr +data_type_conversion(const DataSet& src) { + auto dim = src.GetDim(); + auto rows = src.GetRows(); + + auto des_data = new OutType[dim * rows]; + auto src_data = (InType*)src.GetTensor(); + for (auto i = 0; i < dim * rows; i++) { + des_data[i] = (OutType)src_data[i]; + } + + auto des = std::make_shared(); + des->SetRows(rows); + des->SetDim(dim); + des->SetTensor(des_data); + des->SetIsOwner(true); + return des; +} + template inline T round_down(const T value, const T align) { diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index 042581c8f..51503a27d 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -30,17 +30,22 @@ namespace knowhere { /* knowhere wrapper API to call faiss brute force search for all metric types */ class BruteForceConfig : public BaseConfig {}; - +template expected BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset) { - auto xb = base_dataset->GetTensor(); - auto nb = base_dataset->GetRows(); - auto dim = base_dataset->GetDim(); - - auto xq = query_dataset->GetTensor(); - auto nq = query_dataset->GetRows(); + DataSetPtr base(base_dataset); + DataSetPtr query(query_dataset); + if constexpr (!std::is_same_v::type>) { + base = data_type_conversion::type>(*base_dataset); + query = data_type_conversion::type>(*query_dataset); + } + auto xb = base->GetTensor(); + auto nb = base->GetRows(); + auto dim = base->GetDim(); + auto xq = query->GetTensor(); + auto nq = query->GetRows(); BruteForceConfig cfg; std::string msg; auto status = Config::Load(cfg, config, knowhere::SEARCH, &msg); @@ -133,15 +138,22 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset return GenResultDataSet(nq, cfg.k.value(), labels, distances); } +template Status BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis, const Json& config, const BitsetView& bitset) { - auto xb = base_dataset->GetTensor(); - auto nb = base_dataset->GetRows(); - auto dim = base_dataset->GetDim(); + DataSetPtr base(base_dataset); + DataSetPtr query(query_dataset); + if constexpr (!std::is_same_v::type>) { + base = data_type_conversion::type>(*base_dataset); + query = data_type_conversion::type>(*query_dataset); + } + auto xb = base->GetTensor(); + auto nb = base->GetRows(); + auto dim = base->GetDim(); - auto xq = query_dataset->GetTensor(); - auto nq = query_dataset->GetRows(); + auto xq = query->GetTensor(); + auto nq = query->GetRows(); BruteForceConfig cfg; RETURN_IF_ERROR(Config::Load(cfg, config, knowhere::SEARCH)); @@ -231,15 +243,22 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ /** knowhere wrapper API to call faiss brute force range search for all metric types */ +template expected BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset) { - auto xb = base_dataset->GetTensor(); - auto nb = base_dataset->GetRows(); - auto dim = base_dataset->GetDim(); + DataSetPtr base(base_dataset); + DataSetPtr query(query_dataset); + if constexpr (!std::is_same_v::type>) { + base = data_type_conversion::type>(*base_dataset); + query = data_type_conversion::type>(*query_dataset); + } + auto xb = base->GetTensor(); + auto nb = base->GetRows(); + auto dim = base->GetDim(); - auto xq = query_dataset->GetTensor(); - auto nq = query_dataset->GetRows(); + auto xq = query->GetTensor(); + auto nq = query->GetRows(); BruteForceConfig cfg; std::string msg; @@ -343,4 +362,42 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims); return GenResultDataSet(nq, ids, distances, lims); } +template expected +BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, + const BitsetView& bitset); +template expected +BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, + const BitsetView& bitset); +template expected +BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, + const BitsetView& bitset); +template expected +BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, + const BitsetView& bitset); + +template Status +BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, + float* dis, const Json& config, const BitsetView& bitset); +template Status +BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, + float* dis, const Json& config, const BitsetView& bitset); +template Status +BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, + float* dis, const Json& config, const BitsetView& bitset); +template Status +BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, + float* dis, const Json& config, const BitsetView& bitset); + +template expected +BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, + const Json& config, const BitsetView& bitset); +template expected +BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, + const Json& config, const BitsetView& bitset); +template expected +BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, + const Json& config, const BitsetView& bitset); +template expected +BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, + const Json& config, const BitsetView& bitset); } // namespace knowhere diff --git a/src/common/factory.cc b/src/common/factory.cc index 62d21f29d..0741c61f7 100644 --- a/src/common/factory.cc +++ b/src/common/factory.cc @@ -13,22 +13,43 @@ namespace knowhere { +template Index IndexFactory::Create(const std::string& name, const int32_t& version, const Object& object) { auto& func_mapping_ = MapInstance(); - assert(func_mapping_.find(name) != func_mapping_.end()); - LOG_KNOWHERE_INFO_ << "create knowhere index " << name << " with version " << version; - return func_mapping_[name](version, object); + auto key = GetMapKey(name); + assert(func_mapping_.find(key) != func_mapping_.end()); + auto fun_map_v = (FunMapValue>*)(func_mapping_[key]); + return fun_map_v->fun_value(version, object); } +template const IndexFactory& -IndexFactory::Register(const std::string& name, - std::function(const int32_t& version, const Object&)> func) { +IndexFactory::Register(const std::string& name, std::function(const int32_t&, const Object&)> func) { auto& func_mapping_ = MapInstance(); - func_mapping_[name] = func; + auto key = GetMapKey(name); + assert(func_mapping_.find(key) == func_mapping_.end()); + auto value = new FunMapValue>(func); + func_mapping_[key] = value; return *this; } +template +std::string +IndexFactory::GetMapKey(const std::string& name) { + if (std::is_same_v) { + return name + std::string("_fp32"); + } else if (std::is_same_v) { + return name + std::string("_fp16"); + } else if (std::is_same_v) { + return name + std::string("_bf16"); + } else if (std::is_same_v) { + return name + std::string("_bin1"); + } else { + assert(false && "invalid data type"); + } +} + IndexFactory& IndexFactory::Instance() { static IndexFactory factory; @@ -42,4 +63,33 @@ IndexFactory::MapInstance() { static FuncMap func_map; return func_map; } + +template class Index +IndexFactory::Create(const std::string&, const int32_t&, const Object&); +template class Index +IndexFactory::Create(const std::string&, const int32_t&, const Object&); +template class Index +IndexFactory::Create(const std::string&, const int32_t&, const Object&); +template class Index +IndexFactory::Create(const std::string&, const int32_t&, const Object&); +template const IndexFactory& +IndexFactory::Register(const std::string&, + std::function(const int32_t&, const Object&)>); +template const IndexFactory& +IndexFactory::Register(const std::string&, + std::function(const int32_t&, const Object&)>); +template const IndexFactory& +IndexFactory::Register(const std::string&, + std::function(const int32_t&, const Object&)>); +template const IndexFactory& +IndexFactory::Register(const std::string&, + std::function(const int32_t&, const Object&)>); +template std::string +IndexFactory::GetMapKey(const std::string&); +template std::string +IndexFactory::GetMapKey(const std::string&); +template std::string +IndexFactory::GetMapKey(const std::string&); +template std::string +IndexFactory::GetMapKey(const std::string&); } // namespace knowhere diff --git a/src/common/index_node_data_mock_wrapper.cc b/src/common/index_node_data_mock_wrapper.cc new file mode 100644 index 000000000..529f5a468 --- /dev/null +++ b/src/common/index_node_data_mock_wrapper.cc @@ -0,0 +1,108 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "knowhere/index_node_data_mock_wrapper.h" + +#include "knowhere/comp/thread_pool.h" +#include "knowhere/index_node.h" +#include "knowhere/utils.h" + +namespace knowhere { + +template +Status +IndexNodeDataMockWrapper::Build(const DataSet& dataset, const Config& cfg) { + std::shared_ptr ds_ptr = nullptr; + if (this->Type() != knowhere::IndexEnum::INDEX_DISKANN) { + ds_ptr = dataset.Get(); + if constexpr (!std::is_same_v::type>) { + ds_ptr = data_type_conversion::type>(dataset); + } + } + return index_node_->Build(*ds_ptr, cfg); +} + +template +Status +IndexNodeDataMockWrapper::Train(const DataSet& dataset, const Config& cfg) { + std::shared_ptr ds_ptr = nullptr; + ds_ptr = dataset.Get(); + if constexpr (!std::is_same_v::type>) { + ds_ptr = data_type_conversion::type>(dataset); + } + return index_node_->Train(*ds_ptr, cfg); +} + +template +Status +IndexNodeDataMockWrapper::Add(const DataSet& dataset, const Config& cfg) { + std::shared_ptr ds_ptr = nullptr; + ds_ptr = dataset.Get(); + if constexpr (!std::is_same_v::type>) { + ds_ptr = data_type_conversion::type>(dataset); + } + return index_node_->Add(*ds_ptr, cfg); +} + +template +expected +IndexNodeDataMockWrapper::Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { + auto ds_ptr = dataset.Get(); + if constexpr (!std::is_same_v::type>) { + ds_ptr = data_type_conversion::type>(dataset); + } + return index_node_->Search(*ds_ptr, cfg, bitset); +} + +template +expected +IndexNodeDataMockWrapper::RangeSearch(const DataSet& dataset, const Config& cfg, + const BitsetView& bitset) const { + auto ds_ptr = dataset.Get(); + if constexpr (!std::is_same_v::type>) { + ds_ptr = data_type_conversion::type>(dataset); + } + return index_node_->RangeSearch(*ds_ptr, cfg, bitset); +} + +template +expected>> +IndexNodeDataMockWrapper::AnnIterator(const DataSet& dataset, const Config& cfg, + const BitsetView& bitset) const { + auto ds_ptr = dataset.Get(); + if constexpr (!std::is_same_v::type>) { + ds_ptr = data_type_conversion::type>(dataset); + } + return index_node_->AnnIterator(*ds_ptr, cfg, bitset); +} + +template +expected +IndexNodeDataMockWrapper::GetVectorByIds(const DataSet& dataset) const { + auto res = index_node_->GetVectorByIds(dataset); + if constexpr (!std::is_same_v::type>) { + if (res.has_value()) { + auto res_v = data_type_conversion::type>(*res.value()); + return res_v; + } else { + return res; + } + } else { + return res; + } +} + +template class knowhere::IndexNodeDataMockWrapper; +template class knowhere::IndexNodeDataMockWrapper; +} // namespace knowhere diff --git a/src/index/diskann/diskann.cc b/src/index/diskann/diskann.cc index c1c9f11cc..c873898a9 100644 --- a/src/index/diskann/diskann.cc +++ b/src/index/diskann/diskann.cc @@ -27,16 +27,18 @@ #include "knowhere/expected.h" #include "knowhere/factory.h" #include "knowhere/file_manager.h" +#include "knowhere/index_node_data_mock_wrapper.h" #include "knowhere/log.h" #include "knowhere/utils.h" namespace knowhere { - -template +// TODO: Remove template FileDataType after supporting fp16/bf16. +template class DiskANNIndexNode : public IndexNode { - static_assert(std::is_same_v, "DiskANN only support float"); + static_assert(std::is_same_v, "DiskANN only support float"); public: + using DistType = typename ResultType::type; DiskANNIndexNode(const int32_t& version, const Object& object) : is_prepared_(false), dim_(-1), count_(-1) { assert(typeid(object) == typeid(Pack>)); auto diskann_index_pack = dynamic_cast>*>(&object); @@ -162,7 +164,7 @@ class DiskANNIndexNode : public IndexNode { mutable std::mutex preparation_lock_; std::atomic_bool is_prepared_; std::shared_ptr file_manager_; - std::unique_ptr> pq_flash_index_; + std::unique_ptr> pq_flash_index_; std::atomic_int64_t dim_; std::atomic_int64_t count_; std::shared_ptr search_pool_; @@ -241,7 +243,8 @@ inline bool CheckMetric(const std::string& diskann_metric) { if (diskann_metric != knowhere::metric::L2 && diskann_metric != knowhere::metric::IP && diskann_metric != knowhere::metric::COSINE) { - LOG_KNOWHERE_ERROR_ << "DiskANN currently only supports floating point data for Minimum Euclidean " + LOG_KNOWHERE_ERROR_ << "DiskANN currently only supports floating point " + "data for Minimum Euclidean " "distance(L2), Max Inner Product Search(IP) " "and Minimum Cosine Search(COSINE)." << std::endl; @@ -252,9 +255,9 @@ CheckMetric(const std::string& diskann_metric) { } } // namespace -template +template Status -DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { +DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { assert(file_manager_ != nullptr); std::lock_guard lock(preparation_lock_); auto build_conf = static_cast(cfg); @@ -274,7 +277,13 @@ DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { LOG_KNOWHERE_ERROR_ << "Failed load the raw data before building." << std::endl; return Status::disk_file_error; } - auto& data_path = build_conf.data_path.value(); + auto data_path = build_conf.data_path.value(); + + if constexpr (!std::is_same_v) { + std::string mock_data_path = data_path + "_mock"; + diskann::convert_types_in_file(data_path, mock_data_path); + data_path = mock_data_path; + } index_prefix_ = build_conf.index_prefix.value(); size_t count; @@ -308,7 +317,7 @@ DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { build_conf.accelerate_build.value(), static_cast(num_nodes_to_cache)}; RETURN_IF_ERROR(TryDiskANNCall([&]() { - int res = diskann::build_disk_index(diskann_internal_build_config); + int res = diskann::build_disk_index(diskann_internal_build_config); if (res != 0) throw diskann::ANNException("diskann::build_disk_index returned non-zero value: " + std::to_string(res), -1); @@ -329,12 +338,15 @@ DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { } is_prepared_.store(false); + if constexpr (!std::is_same_v) { + std::remove(data_path.c_str()); + } return Status::success; } -template +template Status -DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { +DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { std::lock_guard lock(preparation_lock_); auto prep_conf = static_cast(cfg); if (!CheckMetric(prep_conf.metric_type.value())) { @@ -388,7 +400,7 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { reader.reset(new LinuxAlignedFileReader()); - pq_flash_index_ = std::make_unique>(reader, diskann_metric); + pq_flash_index_ = std::make_unique>(reader, diskann_metric); auto disk_ann_call = [&]() { int res = pq_flash_index_->load(search_pool_->size(), index_prefix_.c_str()); if (res != 0) { @@ -466,15 +478,16 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { uint64_t warmup_num = 0; uint64_t warmup_dim = 0; uint64_t warmup_aligned_dim = 0; - T* warmup = nullptr; + DataType* warmup = nullptr; if (TryDiskANNCall([&]() { - diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim); + diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, warmup_dim, + warmup_aligned_dim); }) != Status::success) { LOG_KNOWHERE_ERROR_ << "Failed to load warmup file for DiskANN."; return Status::disk_file_error; } std::vector warmup_result_ids_64(warmup_num, 0); - std::vector warmup_result_dists(warmup_num, 0); + std::vector warmup_result_dists(warmup_num, 0); bool all_searches_are_good = true; @@ -506,9 +519,10 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { return Status::success; } -template +template expected -DiskANNIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { +DiskANNIndexNode::Search(const DataSet& dataset, const Config& cfg, + const BitsetView& bitset) const { if (!is_prepared_.load() || !pq_flash_index_) { LOG_KNOWHERE_ERROR_ << "Failed to load diskann."; return expected::Err(Status::empty_index, "DiskANN not loaded"); @@ -526,7 +540,7 @@ DiskANNIndexNode::Search(const DataSet& dataset, const Config& cfg, const Bit auto nq = dataset.GetRows(); auto dim = dataset.GetDim(); - auto xq = static_cast(dataset.GetTensor()); + auto xq = static_cast(dataset.GetTensor()); feder::diskann::FederResultUniq feder_result; if (search_conf.trace_visit.value()) { @@ -539,7 +553,7 @@ DiskANNIndexNode::Search(const DataSet& dataset, const Config& cfg, const Bit } auto p_id = new int64_t[k * nq]; - auto p_dist = new float[k * nq]; + auto p_dist = new DistType[k * nq]; bool all_searches_are_good = true; std::vector> futures; @@ -574,9 +588,10 @@ DiskANNIndexNode::Search(const DataSet& dataset, const Config& cfg, const Bit return res; } -template +template expected -DiskANNIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { +DiskANNIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, + const BitsetView& bitset) const { if (!is_prepared_.load() || !pq_flash_index_) { LOG_KNOWHERE_ERROR_ << "Failed to load diskann."; return expected::Err(Status::empty_index, "index not loaded"); @@ -603,14 +618,14 @@ DiskANNIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, cons auto dim = dataset.GetDim(); auto nq = dataset.GetRows(); - auto xq = static_cast(dataset.GetTensor()); + auto xq = static_cast(dataset.GetTensor()); int64_t* p_id = nullptr; - float* p_dist = nullptr; + DistType* p_dist = nullptr; size_t* p_lims = nullptr; std::vector> result_id_array(nq); - std::vector> result_dist_array(nq); + std::vector> result_dist_array(nq); std::vector> futures; futures.reserve(nq); @@ -647,18 +662,17 @@ DiskANNIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, cons * It first tries to get data from cache, if failed, it will try to get data from disk. * It reads as much as possible and it is thread-pool free, it totally depends on the outside to control concurrency. */ -template +template expected -DiskANNIndexNode::GetVectorByIds(const DataSet& dataset) const { +DiskANNIndexNode::GetVectorByIds(const DataSet& dataset) const { if (!is_prepared_.load() || !pq_flash_index_) { LOG_KNOWHERE_ERROR_ << "Failed to load diskann."; return expected::Err(Status::empty_index, "index not loaded"); } - auto dim = Dim(); auto rows = dataset.GetRows(); auto ids = dataset.GetIds(); - float* data = new float[dim * rows]; + auto* data = new DataType[dim * rows]; if (data == nullptr) { LOG_KNOWHERE_ERROR_ << "Failed to allocate memory for data."; return expected::Err(Status::malloc_error, "failed to allocate memory for data"); @@ -672,9 +686,9 @@ DiskANNIndexNode::GetVectorByIds(const DataSet& dataset) const { return GenResultDataSet(rows, dim, data); } -template +template expected -DiskANNIndexNode::GetIndexMeta(const Config& cfg) const { +DiskANNIndexNode::GetIndexMeta(const Config& cfg) const { std::vector entry_points; for (size_t i = 0; i < pq_flash_index_->get_num_medoids(); i++) { entry_points.push_back(pq_flash_index_->get_medoids()[i]); @@ -692,17 +706,17 @@ DiskANNIndexNode::GetIndexMeta(const Config& cfg) const { return GenResultDataSet(json_meta.dump(), json_id_set.dump()); } -template +template uint64_t -DiskANNIndexNode::GetCachedNodeNum(const float cache_dram_budget, const uint64_t data_dim, - const uint64_t max_degree) { - uint32_t one_cached_node_budget = (max_degree + 1) * sizeof(unsigned) + sizeof(T) * data_dim; +DiskANNIndexNode::GetCachedNodeNum(const float cache_dram_budget, const uint64_t data_dim, + const uint64_t max_degree) { + uint32_t one_cached_node_budget = (max_degree + 1) * sizeof(unsigned) + sizeof(DataType) * data_dim; auto num_nodes_to_cache = static_cast(1024 * 1024 * 1024 * cache_dram_budget) / (one_cached_node_budget * kCacheExpansionRate); return num_nodes_to_cache; } -KNOWHERE_REGISTER_GLOBAL(DISKANN, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, fp32); +KNOWHERE_MOCK_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, fp16, fp32, fp16); +KNOWHERE_MOCK_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, bf16, fp32, bf16); } // namespace knowhere diff --git a/src/index/flat/flat.cc b/src/index/flat/flat.cc index e7ef2af77..ac5ce03d6 100644 --- a/src/index/flat/flat.cc +++ b/src/index/flat/flat.cc @@ -19,17 +19,19 @@ #include "knowhere/bitsetview_idselector.h" #include "knowhere/comp/thread_pool.h" #include "knowhere/factory.h" +#include "knowhere/index_node_data_mock_wrapper.h" #include "knowhere/log.h" #include "knowhere/utils.h" namespace knowhere { -template +template class FlatIndexNode : public IndexNode { public: FlatIndexNode(const int32_t version, const Object& object) : index_(nullptr) { - static_assert(std::is_same::value || std::is_same::value, - "not support"); + static_assert( + std::is_same::value || std::is_same::value, + "not support"); search_pool_ = ThreadPool::GetGlobalSearchThreadPool(); } @@ -42,10 +44,10 @@ class FlatIndexNode : public IndexNode { LOG_KNOWHERE_WARNING_ << "please check metric type: " << f_cfg.metric_type.value(); return metric.error(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { index_ = std::make_unique(dataset.GetDim(), metric.value()); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { bool is_cosine = IsMetricType(f_cfg.metric_type.value(), knowhere::metric::COSINE); index_ = std::make_unique(dataset.GetDim(), metric.value(), is_cosine); } @@ -56,12 +58,7 @@ class FlatIndexNode : public IndexNode { Add(const DataSet& dataset, const Config& cfg) override { auto x = dataset.GetTensor(); auto n = dataset.GetRows(); - if constexpr (std::is_same::value) { - index_->add(n, (const float*)x); - } - if constexpr (std::is_same::value) { - index_->add(n, (const uint8_t*)x); - } + index_->add(n, (const DataType*)x); return Status::success; } @@ -98,9 +95,9 @@ class FlatIndexNode : public IndexNode { BitsetViewIDSelector bw_idselector(bitset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - if constexpr (std::is_same::value) { - auto cur_query = (const float*)x + dim * index; - std::unique_ptr copied_query = nullptr; + if constexpr (std::is_same::value) { + auto cur_query = (const DataType*)x + dim * index; + std::unique_ptr copied_query = nullptr; if (is_cosine) { copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); cur_query = copied_query.get(); @@ -111,13 +108,13 @@ class FlatIndexNode : public IndexNode { index_->search(1, cur_query, k, cur_dis, cur_ids, &search_params); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto cur_i_dis = reinterpret_cast(cur_dis); faiss::SearchParameters search_params; search_params.sel = id_selector; - index_->search(1, (const uint8_t*)x + index * dim / 8, k, cur_i_dis, cur_ids, &search_params); + index_->search(1, (const DataType*)x + index * dim / 8, k, cur_i_dis, cur_ids, &search_params); if (index_->metric_type == faiss::METRIC_Hamming) { for (int64_t j = 0; j < k; j++) { @@ -142,7 +139,6 @@ class FlatIndexNode : public IndexNode { LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what(); return expected::Err(Status::faiss_inner_error, e.what()); } - return GenResultDataSet(nq, k, ids, distances); } @@ -184,9 +180,9 @@ class FlatIndexNode : public IndexNode { BitsetViewIDSelector bw_idselector(bitset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - if constexpr (std::is_same::value) { - auto cur_query = (const float*)xq + dim * index; - std::unique_ptr copied_query = nullptr; + if constexpr (std::is_same::value) { + auto cur_query = (const DataType*)xq + dim * index; + std::unique_ptr copied_query = nullptr; if (is_cosine) { copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); cur_query = copied_query.get(); @@ -197,11 +193,11 @@ class FlatIndexNode : public IndexNode { index_->range_search(1, cur_query, radius, &res, &search_params); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::SearchParameters search_params; search_params.sel = id_selector; - index_->range_search(1, (const uint8_t*)xq + index * dim / 8, radius, &res, &search_params); + index_->range_search(1, (const DataType*)xq + index * dim / 8, radius, &res, &search_params); } auto elem_cnt = res.lims[1]; result_dist_array[index].resize(elem_cnt); @@ -241,30 +237,30 @@ class FlatIndexNode : public IndexNode { auto dim = Dim(); auto rows = dataset.GetRows(); auto ids = dataset.GetIds(); - if constexpr (std::is_same::value) { - float* data = nullptr; + if constexpr (std::is_same::value) { + DataType* data = nullptr; try { - data = new float[rows * dim]; + data = new DataType[rows * dim]; for (int64_t i = 0; i < rows; i++) { index_->reconstruct(ids[i], data + i * dim); } return GenResultDataSet(rows, dim, data); } catch (const std::exception& e) { - std::unique_ptr auto_del(data); + std::unique_ptr auto_del(data); LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return expected::Err(Status::faiss_inner_error, e.what()); } } - if constexpr (std::is_same::value) { - uint8_t* data = nullptr; + if constexpr (std::is_same::value) { + DataType* data = nullptr; try { - data = new uint8_t[rows * dim / 8]; + data = new DataType[rows * dim / 8]; for (int64_t i = 0; i < rows; i++) { index_->reconstruct(ids[i], data + i * dim / 8); } return GenResultDataSet(rows, dim, data); } catch (const std::exception& e) { - std::unique_ptr auto_del(data); + std::unique_ptr auto_del(data); LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what(); return expected::Err(Status::faiss_inner_error, e.what()); } @@ -273,14 +269,14 @@ class FlatIndexNode : public IndexNode { bool HasRawData(const std::string& metric_type) const override { - if constexpr (std::is_same::value) { - if (version_ <= Version::GetMinimalVersion()) { + if constexpr (std::is_same::value) { + if (this->version_ <= Version::GetMinimalVersion()) { return !IsMetricType(metric_type, metric::COSINE); } else { return true; } } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return true; } } @@ -298,10 +294,10 @@ class FlatIndexNode : public IndexNode { } try { MemoryIOWriter writer; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::write_index(index_.get(), &writer); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::write_index_binary(index_.get(), &writer); } std::shared_ptr data(writer.data()); @@ -325,13 +321,13 @@ class FlatIndexNode : public IndexNode { } MemoryIOReader reader(binary->data.get(), binary->size); - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::Index* index = faiss::read_index(&reader); - index_.reset(static_cast(index)); + index_.reset(static_cast(index)); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::IndexBinary* index = faiss::read_index_binary(&reader); - index_.reset(static_cast(index)); + index_.reset(static_cast(index)); } return Status::success; } @@ -345,13 +341,13 @@ class FlatIndexNode : public IndexNode { io_flags |= faiss::IO_FLAG_MMAP; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::Index* index = faiss::read_index(filename.data(), io_flags); - index_.reset(static_cast(index)); + index_.reset(static_cast(index)); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::IndexBinary* index = faiss::read_index_binary(filename.data(), io_flags); - index_.reset(static_cast(index)); + index_.reset(static_cast(index)); } return Status::success; } @@ -368,7 +364,7 @@ class FlatIndexNode : public IndexNode { int64_t Size() const override { - return index_->ntotal * index_->d * sizeof(float); + return index_->ntotal * index_->d * sizeof(DataType); } int64_t @@ -378,27 +374,22 @@ class FlatIndexNode : public IndexNode { std::string Type() const override { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_IDMAP; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP; } } private: - std::unique_ptr index_; + std::unique_ptr index_; std::shared_ptr search_pool_; }; -KNOWHERE_REGISTER_GLOBAL(FLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(BINFLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(BIN_FLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); - +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FLAT, FlatIndexNode, fp32, faiss::IndexFlat); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(BINFLAT, FlatIndexNode, bin1, faiss::IndexBinaryFlat); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(BIN_FLAT, FlatIndexNode, bin1, faiss::IndexBinaryFlat); +KNOWHERE_MOCK_REGISTER_GLOBAL(FLAT, FlatIndexNode, fp16, fp32, faiss::IndexFlat); +KNOWHERE_MOCK_REGISTER_GLOBAL(FLAT, FlatIndexNode, bf16, fp32, faiss::IndexFlat); } // namespace knowhere diff --git a/src/index/gpu/flat_gpu/flat_gpu.cc b/src/index/gpu/flat_gpu/flat_gpu.cc index 9436ff938..3ab51ab32 100644 --- a/src/index/gpu/flat_gpu/flat_gpu.cc +++ b/src/index/gpu/flat_gpu/flat_gpu.cc @@ -21,6 +21,7 @@ namespace knowhere { +template class GpuFlatIndexNode : public IndexNode { public: GpuFlatIndexNode(const int32_t& version, const Object& object) : index_(nullptr) { @@ -189,8 +190,11 @@ class GpuFlatIndexNode : public IndexNode { std::unique_ptr index_; }; -KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_FLAT, [](const int32_t& version, const Object& object) { - return Index::Create(version, object); -}); +KNOWHERE_REGISTER_GLOBAL( + GPU_FAISS_FLAT, + [](const int32_t& version, const Object& object) { + return (Index>::Create(version, object)); + }, + fp32); } // namespace knowhere diff --git a/src/index/gpu/ivf_gpu/ivf_gpu.cc b/src/index/gpu/ivf_gpu/ivf_gpu.cc index ec6b753f8..c6ac064e0 100644 --- a/src/index/gpu/ivf_gpu/ivf_gpu.cc +++ b/src/index/gpu/ivf_gpu/ivf_gpu.cc @@ -273,14 +273,23 @@ class GpuIvfIndexNode : public IndexNode { std::unique_ptr index_; }; -KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_IVF_FLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_IVF_PQ, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_IVF_SQ8, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); +KNOWHERE_REGISTER_GLOBAL( + GPU_FAISS_IVF_FLAT, + [](const int32_t& version, const Object& object) { + return (Index>>::Create(version, object)); + }, + fp32); +KNOWHERE_REGISTER_GLOBAL( + GPU_FAISS_IVF_PQ, + [](const int32_t& version, const Object& object) { + return (Index>>::Create(version, object)); + }, + fp32); +KNOWHERE_REGISTER_GLOBAL( + GPU_FAISS_IVF_SQ8, + [](const int32_t& version, const Object& object) { + return Index>>::Create(version, object); + }, + fp32); } // namespace knowhere diff --git a/src/index/hnsw/hnsw.cc b/src/index/hnsw/hnsw.cc index 9dc1d8255..118994e64 100644 --- a/src/index/hnsw/hnsw.cc +++ b/src/index/hnsw/hnsw.cc @@ -26,12 +26,18 @@ #include "knowhere/config.h" #include "knowhere/expected.h" #include "knowhere/factory.h" +#include "knowhere/index_node_data_mock_wrapper.h" #include "knowhere/log.h" #include "knowhere/utils.h" namespace knowhere { +template class HnswIndexNode : public IndexNode { + static_assert(std::is_same_v || std::is_same_v, + "HnswIndexNode only support float/bianry"); + public: + using DistType = typename ResultType::type; HnswIndexNode(const int32_t& /*version*/, const Object& object) : index_(nullptr) { search_pool_ = ThreadPool::GetGlobalSearchThreadPool(); } @@ -41,7 +47,7 @@ class HnswIndexNode : public IndexNode { auto rows = dataset.GetRows(); auto dim = dataset.GetDim(); auto hnsw_cfg = static_cast(cfg); - hnswlib::SpaceInterface* space = nullptr; + hnswlib::SpaceInterface* space = nullptr; if (IsMetricType(hnsw_cfg.metric_type.value(), metric::L2)) { space = new (std::nothrow) hnswlib::L2Space(dim); } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::IP)) { @@ -57,7 +63,7 @@ class HnswIndexNode : public IndexNode { return Status::invalid_metric_type; } auto index = new (std::nothrow) - hnswlib::HierarchicalNSW(space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value()); + hnswlib::HierarchicalNSW(space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value()); if (index == nullptr) { LOG_KNOWHERE_WARNING_ << "memory malloc error."; return Status::malloc_error; @@ -103,7 +109,7 @@ class HnswIndexNode : public IndexNode { expected Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override { if (!index_) { - LOG_KNOWHERE_WARNING_ << "search on empty index"; + LOG_KNOWHERE_WARNING_ << "search on empty indefloatx"; return expected::Err(Status::empty_index, "index not loaded"); } auto nq = dataset.GetRows(); @@ -121,7 +127,7 @@ class HnswIndexNode : public IndexNode { } auto p_id = new int64_t[k * nq]; - auto p_dist = new float[k * nq]; + auto p_dist = new DistType[k * nq]; hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value(), hnsw_cfg.for_tuning.value()}; bool transform = @@ -142,7 +148,7 @@ class HnswIndexNode : public IndexNode { p_single_id[idx] = id; } for (size_t idx = rst_size; idx < (size_t)k; idx++) { - p_single_dis[idx] = float(1.0 / 0.0); + p_single_dis[idx] = DistType(1.0 / 0.0); p_single_id[idx] = -1; } })); @@ -167,7 +173,7 @@ class HnswIndexNode : public IndexNode { private: class iterator : public IndexNode::iterator { public: - iterator(const hnswlib::HierarchicalNSW* index, const char* query, const bool transform, + iterator(const hnswlib::HierarchicalNSW* index, const char* query, const bool transform, const BitsetView& bitset, const bool for_tuning = false, const size_t seed_ef = kIteratorSeedEf) : index_(index), transform_(transform), @@ -175,7 +181,7 @@ class HnswIndexNode : public IndexNode { UpdateNext(); } - std::pair + std::pair Next() override { auto ret = std::make_pair(next_id_, next_dist_); UpdateNext(); @@ -200,16 +206,16 @@ class HnswIndexNode : public IndexNode { has_next_ = false; } } - const hnswlib::HierarchicalNSW* index_; + const hnswlib::HierarchicalNSW* index_; const bool transform_; std::unique_ptr workspace_; bool has_next_; - float next_dist_; + DistType next_dist_; int64_t next_id_; }; public: - expected>> + expected>> AnnIterator(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override { if (!index_) { LOG_KNOWHERE_WARNING_ << "creating iterator on empty index"; @@ -254,10 +260,10 @@ class HnswIndexNode : public IndexNode { auto hnsw_cfg = static_cast(cfg); bool is_ip = (index_->metric_type_ == hnswlib::Metric::INNER_PRODUCT || index_->metric_type_ == hnswlib::Metric::COSINE); - float range_filter = hnsw_cfg.range_filter.value(); + DistType range_filter = hnsw_cfg.range_filter.value(); - float radius_for_calc = (is_ip ? -hnsw_cfg.radius.value() : hnsw_cfg.radius.value()); - float radius_for_filter = hnsw_cfg.radius.value(); + DistType radius_for_calc = (is_ip ? -hnsw_cfg.radius.value() : hnsw_cfg.radius.value()); + DistType radius_for_filter = hnsw_cfg.radius.value(); feder::hnsw::FederResultUniq feder_result; if (hnsw_cfg.trace_visit.value()) { @@ -270,11 +276,11 @@ class HnswIndexNode : public IndexNode { hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value()}; int64_t* ids = nullptr; - float* dis = nullptr; + DistType* dis = nullptr; size_t* lims = nullptr; std::vector> result_id_array(nq); - std::vector> result_dist_array(nq); + std::vector> result_dist_array(nq); std::vector result_size(nq); std::vector result_lims(nq + 1); @@ -412,8 +418,8 @@ class HnswIndexNode : public IndexNode { MemoryIOReader reader(binary->data.get(), binary->size); - hnswlib::SpaceInterface* space = nullptr; - index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); + hnswlib::SpaceInterface* space = nullptr; + index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); index_->loadIndex(reader); LOG_KNOWHERE_INFO_ << "Loaded HNSW index. #points num:" << index_->max_elements_ << " #M:" << index_->M_ << " #max level:" << index_->maxlevel_ @@ -432,8 +438,8 @@ class HnswIndexNode : public IndexNode { delete index_; } try { - hnswlib::SpaceInterface* space = nullptr; - index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); + hnswlib::SpaceInterface* space = nullptr; + index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); index_->loadIndex(filename, config); } catch (std::exception& e) { LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what(); @@ -520,12 +526,12 @@ class HnswIndexNode : public IndexNode { } private: - hnswlib::HierarchicalNSW* index_; + hnswlib::HierarchicalNSW* index_; std::shared_ptr search_pool_; }; -KNOWHERE_REGISTER_GLOBAL(HNSW, [](const int32_t& version, const Object& object) { - return Index::Create(version, object); -}); - +KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, bin1); +KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16, fp32); +KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16, fp32); } // namespace knowhere diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 53c8276de..d3c805892 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -29,19 +29,34 @@ #include "knowhere/expected.h" #include "knowhere/factory.h" #include "knowhere/feder/IVFFlat.h" +#include "knowhere/index_node_data_mock_wrapper.h" #include "knowhere/log.h" #include "knowhere/utils.h" namespace knowhere { +struct IVFBaseTag {}; +struct IVFFlatTag {}; -template +template +struct IndexDispatch { + using Tag = IVFBaseTag; +}; + +template <> +struct IndexDispatch { + using Tag = IVFFlatTag; +}; + +template class IvfIndexNode : public IndexNode { public: IvfIndexNode(const int32_t version, const Object& object) : IndexNode(version), index_(nullptr) { - static_assert(std::is_same::value || std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || std::is_same::value, + static_assert(std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value, "not support"); search_pool_ = ThreadPool::GetGlobalSearchThreadPool(); } @@ -60,53 +75,55 @@ class IvfIndexNode : public IndexNode { if (!index_) { return false; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return true; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return true; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return false; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return index_->with_raw_data(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return false; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return true; } } expected GetIndexMeta(const Config& cfg) const override { - return expected::Err(Status::not_implemented, "GetIndexMeta not implemented"); + return this->GetIndexMetaImpl(cfg, typename IndexDispatch::Tag{}); } Status - Serialize(BinarySet& binset) const override; + Serialize(BinarySet& binset) const override { + return this->SerializeImpl(binset, typename IndexDispatch::Tag{}); + } Status Deserialize(const BinarySet& binset, const Config& config) override; Status DeserializeFromFile(const std::string& filename, const Config& config) override; std::unique_ptr CreateConfig() const override { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return std::make_unique(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return std::make_unique(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return std::make_unique(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return std::make_unique(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return std::make_unique(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return std::make_unique(); } }; @@ -122,19 +139,19 @@ class IvfIndexNode : public IndexNode { if (!index_) { return 0; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto nb = index_->invlists->compute_ntotal(); auto nlist = index_->nlist; auto code_size = index_->code_size; return ((nb + nlist) * (code_size + sizeof(int64_t))); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto nb = index_->invlists->compute_ntotal(); auto nlist = index_->nlist; auto code_size = index_->code_size; return (nb * code_size + nb * sizeof(int64_t) + nlist * code_size); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto nb = index_->invlists->compute_ntotal(); auto code_size = index_->code_size; auto pq = index_->pq; @@ -146,16 +163,16 @@ class IvfIndexNode : public IndexNode { auto precomputed_table = nlist * pq.M * pq.ksub * sizeof(float); return (capacity + centroid_table + precomputed_table); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return index_->size(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto nb = index_->invlists->compute_ntotal(); auto code_size = index_->code_size; auto nlist = index_->nlist; return (nb * code_size + nb * sizeof(int64_t) + 2 * code_size + nlist * code_size); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto nb = index_->invlists->compute_ntotal(); auto nlist = index_->nlist; auto code_size = index_->code_size; @@ -171,28 +188,42 @@ class IvfIndexNode : public IndexNode { }; std::string Type() const override { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_IVFPQ; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_SCANN; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_IVFSQ8; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT; } }; private: - std::unique_ptr index_; + expected + GetIndexMetaImpl(const Config& cfg, IVFBaseTag) const { + return expected::Err(Status::not_implemented, "GetIndexMeta not implemented"); + } + expected + GetIndexMetaImpl(const Config& cfg, IVFFlatTag) const; + + Status + SerializeImpl(BinarySet& binset, IVFBaseTag) const; + + Status + SerializeImpl(BinarySet& binset, IVFFlatTag) const; + + private: + std::unique_ptr index_; std::shared_ptr search_pool_; }; @@ -243,9 +274,9 @@ to_index_flat(std::unique_ptr&& index) { } // namespace -template +template Status -IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { +IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { const BaseConfig& base_cfg = static_cast(cfg); std::unique_ptr setter; if (base_cfg.num_build_thread.has_value()) { @@ -257,7 +288,8 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { bool is_cosine = IsMetricType(base_cfg.metric_type.value(), knowhere::metric::COSINE); // do normalize for COSINE metric type - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v) { if (is_cosine) { Normalize(dataset); } @@ -275,18 +307,18 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { // faiss scann needs at least 16 rows since nbits=4 constexpr int64_t SCANN_MIN_ROWS = 16; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { if (rows < SCANN_MIN_ROWS) { LOG_KNOWHERE_ERROR_ << rows << " rows is not enough, scann needs at least 16 rows to build index"; return Status::faiss_inner_error; } } - std::unique_ptr index; + std::unique_ptr index; // if cfg.use_elkan is used, then we'll use a temporary instance of // IndexFlatElkan for the training. try { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const IvfFlatConfig& ivf_flat_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_flat_cfg.nlist.value()); @@ -306,7 +338,7 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { qzr.release(); index->own_fields = true; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const IvfFlatCcConfig& ivf_flat_cc_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_flat_cc_cfg.nlist.value()); @@ -329,7 +361,7 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { // ivfflat_cc has no serialize stage, make map at build stage index->make_direct_map(true, faiss::DirectMap::ConcurrentArray); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const IvfPqConfig& ivf_pq_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_pq_cfg.nlist.value()); auto nbits = MatchNbits(rows, ivf_pq_cfg.nbits.value()); @@ -351,7 +383,7 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { qzr.release(); index->own_fields = true; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const ScannConfig& scann_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, scann_cfg.nlist.value()); bool is_cosine = base_cfg.metric_type.value() == metric::COSINE; @@ -384,7 +416,7 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { base_index.release(); index->own_fields = true; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const IvfSqConfig& ivf_sq_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_sq_cfg.nlist.value()); @@ -405,7 +437,7 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { qzr.release(); index->own_fields = true; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const IvfBinConfig& ivf_bin_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_bin_cfg.nlist.value()); @@ -428,9 +460,9 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { return Status::success; } -template +template Status -IvfIndexNode::Add(const DataSet& dataset, const Config& cfg) { +IvfIndexNode::Add(const DataSet& dataset, const Config& cfg) { if (!this->index_) { LOG_KNOWHERE_ERROR_ << "Can not add data to empty IVF index."; return Status::empty_index; @@ -445,7 +477,7 @@ IvfIndexNode::Add(const DataSet& dataset, const Config& cfg) { setter = std::make_unique(); } try { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { index_->add(rows, (const uint8_t*)data); } else { index_->add(rows, (const float*)data); @@ -457,9 +489,9 @@ IvfIndexNode::Add(const DataSet& dataset, const Config& cfg) { return Status::success; } -template +template expected -IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { +IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { if (!this->index_) { LOG_KNOWHERE_WARNING_ << "search on empty index"; return expected::Err(Status::empty_index, "index not loaded"); @@ -494,7 +526,7 @@ IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetV BitsetViewIDSelector bw_idselector(bitset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto cur_data = (const uint8_t*)data + index * dim / 8; faiss::IVFSearchParameters ivf_search_params; @@ -507,7 +539,7 @@ IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetV distances[i + offset] = static_cast(i_distances[i + offset]); } } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto cur_query = (const float*)data + index * dim; if (is_cosine) { copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); @@ -520,7 +552,7 @@ IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetV ivf_search_params.sel = id_selector; index_->search(1, cur_query, k, distances + offset, ids + offset, &ivf_search_params); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto cur_query = (const float*)data + index * dim; const ScannConfig& scann_cfg = static_cast(cfg); if (is_cosine) { @@ -574,9 +606,10 @@ IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetV return res; } -template +template expected -IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { +IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, + const BitsetView& bitset) const { if (!this->index_) { LOG_KNOWHERE_WARNING_ << "range search on empty index"; return expected::Err(Status::empty_index, "index not loaded"); @@ -618,7 +651,7 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi BitsetViewIDSelector bw_idselector(bitset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto cur_data = (const uint8_t*)xq + index * dim / 8; faiss::IVFSearchParameters ivf_search_params; @@ -626,7 +659,7 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi ivf_search_params.sel = id_selector; index_->range_search(1, cur_data, radius, &res, &ivf_search_params); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto cur_query = (const float*)xq + index * dim; if (is_cosine) { copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); @@ -639,7 +672,7 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi ivf_search_params.sel = id_selector; index_->range_search(1, cur_query, radius, &res, &ivf_search_params); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto cur_query = (const float*)xq + index * dim; if (is_cosine) { copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); @@ -700,16 +733,16 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi return GenResultDataSet(nq, ids, distances, lims); } -template +template expected -IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { +IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { if (!this->index_) { return expected::Err(Status::empty_index, "index not loaded"); } if (!this->index_->is_trained) { return expected::Err(Status::index_not_trained, "index not trained"); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto dim = Dim(); auto rows = dataset.GetRows(); auto ids = dataset.GetIds(); @@ -728,7 +761,8 @@ IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return expected::Err(Status::faiss_inner_error, e.what()); } - } else if constexpr (std::is_same::value || std::is_same::value) { + } else if constexpr (std::is_same::value || + std::is_same::value) { auto dim = Dim(); auto rows = dataset.GetRows(); auto ids = dataset.GetIds(); @@ -747,7 +781,7 @@ IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return expected::Err(Status::faiss_inner_error, e.what()); } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { // we should never go here since we should call HasRawData() first if (!index_->with_raw_data()) { return expected::Err(Status::not_implemented, "GetVectorByIds not implemented"); @@ -775,9 +809,9 @@ IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { } } -template <> +template expected -IvfIndexNode::GetIndexMeta(const Config& config) const { +IvfIndexNode::GetIndexMetaImpl(const Config& config, IVFFlatTag) const { if (!index_) { LOG_KNOWHERE_WARNING_ << "get index meta on empty index"; return expected::Err(Status::empty_index, "index not loaded"); @@ -814,12 +848,12 @@ IvfIndexNode::GetIndexMeta(const Config& config) const { return GenResultDataSet(json_meta.dump(), json_id_set.dump()); } -template +template Status -IvfIndexNode::Serialize(BinarySet& binset) const { +IvfIndexNode::SerializeImpl(BinarySet& binset, IVFBaseTag) const { try { MemoryIOWriter writer; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::write_index_binary(index_.get(), &writer); } else { faiss::write_index(index_.get(), &writer); @@ -833,13 +867,13 @@ IvfIndexNode::Serialize(BinarySet& binset) const { } } -template <> +template Status -IvfIndexNode::Serialize(BinarySet& binset) const { +IvfIndexNode::SerializeImpl(BinarySet& binset, IVFFlatTag) const { try { MemoryIOWriter writer; - LOG_KNOWHERE_INFO_ << "request version " << version_.VersionNumber(); - if (version_ <= Version::GetMinimalVersion()) { + LOG_KNOWHERE_INFO_ << "request version " << this->version_.VersionNumber(); + if (this->version_ <= Version::GetMinimalVersion()) { faiss::write_index_nm(index_.get(), &writer); LOG_KNOWHERE_INFO_ << "write IVF_FLAT_NM, file size " << writer.tellg(); } else { @@ -850,7 +884,7 @@ IvfIndexNode::Serialize(BinarySet& binset) const { binset.Append(Type(), index_data_ptr, writer.tellg()); // append raw data for backward compatible - if (version_ <= Version::GetMinimalVersion()) { + if (this->version_ <= Version::GetMinimalVersion()) { size_t dim = index_->d; size_t rows = index_->ntotal; size_t raw_data_size = dim * rows * sizeof(float); @@ -877,9 +911,9 @@ IvfIndexNode::Serialize(BinarySet& binset) const { } } -template +template Status -IvfIndexNode::Deserialize(const BinarySet& binset, const Config& config) { +IvfIndexNode::Deserialize(const BinarySet& binset, const Config& config) { std::vector names = {"IVF", // compatible with knowhere-1.x "BinaryIVF", // compatible with knowhere-1.x Type()}; @@ -891,8 +925,8 @@ IvfIndexNode::Deserialize(const BinarySet& binset, const Config& config) { MemoryIOReader reader(binary->data.get(), binary->size); try { - if constexpr (std::is_same::value) { - if (version_ <= Version::GetMinimalVersion()) { + if constexpr (std::is_same::value) { + if (this->version_ <= Version::GetMinimalVersion()) { auto raw_binary = binset.GetByName("RAW_DATA"); const BaseConfig& base_cfg = static_cast(config); ConvertIVFFlat(binset, base_cfg.metric_type.value(), raw_binary->data.get(), raw_binary->size); @@ -901,12 +935,12 @@ IvfIndexNode::Deserialize(const BinarySet& binset, const Config& config) { reader.total_ = binary->size; } index_.reset(static_cast(faiss::read_index(&reader))); - } else if constexpr (std::is_same::value) { - index_.reset(static_cast(faiss::read_index_binary(&reader))); + } else if constexpr (std::is_same::value) { + index_.reset(static_cast(faiss::read_index_binary(&reader))); } else { - index_.reset(static_cast(faiss::read_index(&reader))); + index_.reset(static_cast(faiss::read_index(&reader))); } - if constexpr (!std::is_same_v) { + if constexpr (!std::is_same_v) { const BaseConfig& base_cfg = static_cast(config); if (HasRawData(base_cfg.metric_type.value())) { index_->make_direct_map(true); @@ -919,9 +953,9 @@ IvfIndexNode::Deserialize(const BinarySet& binset, const Config& config) { return Status::success; } -template +template Status -IvfIndexNode::DeserializeFromFile(const std::string& filename, const Config& config) { +IvfIndexNode::DeserializeFromFile(const std::string& filename, const Config& config) { auto cfg = static_cast(config); int io_flags = 0; @@ -929,12 +963,12 @@ IvfIndexNode::DeserializeFromFile(const std::string& filename, const Config& io_flags |= faiss::IO_FLAG_MMAP; } try { - if constexpr (std::is_same::value) { - index_.reset(static_cast(faiss::read_index_binary(filename.data(), io_flags))); + if constexpr (std::is_same::value) { + index_.reset(static_cast(faiss::read_index_binary(filename.data(), io_flags))); } else { - index_.reset(static_cast(faiss::read_index(filename.data(), io_flags))); + index_.reset(static_cast(faiss::read_index(filename.data(), io_flags))); } - if constexpr (!std::is_same_v) { + if constexpr (!std::is_same_v) { const BaseConfig& base_cfg = static_cast(config); if (HasRawData(base_cfg.metric_type.value())) { index_->make_direct_map(true); @@ -946,42 +980,37 @@ IvfIndexNode::DeserializeFromFile(const std::string& filename, const Config& } return Status::success; } - -KNOWHERE_REGISTER_GLOBAL(IVFBIN, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); - -KNOWHERE_REGISTER_GLOBAL(BIN_IVF_FLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); - -KNOWHERE_REGISTER_GLOBAL(IVFFLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(IVF_FLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(IVFFLATCC, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(IVF_FLAT_CC, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(SCANN, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(IVFPQ, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(IVF_PQ, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); - -KNOWHERE_REGISTER_GLOBAL(IVFSQ, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(IVF_SQ8, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); - +// bin1 +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFBIN, IvfIndexNode, bin1, faiss::IndexBinaryIVF); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(BIN_IVF_FLAT, IvfIndexNode, bin1, faiss::IndexBinaryIVF); +// fp32 +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFFLAT, IvfIndexNode, fp32, faiss::IndexIVFFlat); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_FLAT, IvfIndexNode, fp32, faiss::IndexIVFFlat); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFFLATCC, IvfIndexNode, fp32, faiss::IndexIVFFlatCC); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_FLAT_CC, IvfIndexNode, fp32, faiss::IndexIVFFlatCC); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(SCANN, IvfIndexNode, fp32, faiss::IndexScaNN); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFPQ, IvfIndexNode, fp32, faiss::IndexIVFPQ); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_PQ, IvfIndexNode, fp32, faiss::IndexIVFPQ); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFSQ, IvfIndexNode, fp32, faiss::IndexIVFScalarQuantizer); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_SQ8, IvfIndexNode, fp32, faiss::IndexIVFScalarQuantizer); +// fp16 +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLAT, IvfIndexNode, fp16, fp32, faiss::IndexIVFFlat); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT, IvfIndexNode, fp16, fp32, faiss::IndexIVFFlat); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLATCC, IvfIndexNode, fp16, fp32, faiss::IndexIVFFlatCC); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT_CC, IvfIndexNode, fp16, fp32, faiss::IndexIVFFlatCC); +KNOWHERE_MOCK_REGISTER_GLOBAL(SCANN, IvfIndexNode, fp16, fp32, faiss::IndexScaNN); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFPQ, IvfIndexNode, fp16, fp32, faiss::IndexIVFPQ); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_PQ, IvfIndexNode, fp16, fp32, faiss::IndexIVFPQ); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFSQ, IvfIndexNode, fp16, fp32, faiss::IndexIVFScalarQuantizer); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_SQ8, IvfIndexNode, fp16, fp32, faiss::IndexIVFScalarQuantizer); +// bf16 +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLAT, IvfIndexNode, bf16, fp32, faiss::IndexIVFFlat); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT, IvfIndexNode, bf16, fp32, faiss::IndexIVFFlat); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLATCC, IvfIndexNode, bf16, fp32, faiss::IndexIVFFlatCC); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT_CC, IvfIndexNode, bf16, fp32, faiss::IndexIVFFlatCC); +KNOWHERE_MOCK_REGISTER_GLOBAL(SCANN, IvfIndexNode, bf16, fp32, faiss::IndexScaNN); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFPQ, IvfIndexNode, bf16, fp32, faiss::IndexIVFPQ); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_PQ, IvfIndexNode, bf16, fp32, faiss::IndexIVFPQ); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFSQ, IvfIndexNode, bf16, fp32, faiss::IndexIVFScalarQuantizer); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_SQ8, IvfIndexNode, bf16, fp32, faiss::IndexIVFScalarQuantizer); } // namespace knowhere diff --git a/src/index/ivf_raft/ivf_raft.cu b/src/index/ivf_raft/ivf_raft.cu index 622666c91..072e511f4 100644 --- a/src/index/ivf_raft/ivf_raft.cu +++ b/src/index/ivf_raft/ivf_raft.cu @@ -24,22 +24,22 @@ constexpr uint32_t cuda_concurrent_size = 32; namespace knowhere { KNOWHERE_REGISTER_GLOBAL(GPU_RAFT_IVF_FLAT, [](const int32_t& version, const Object& object) { - return Index::Create( - std::make_unique>(version, object), cuda_concurrent_size); -}); + return (Index::Create( + std::make_unique>(version, object), cuda_concurrent_size)); +}, fp32); KNOWHERE_REGISTER_GLOBAL(GPU_RAFT_IVF_PQ, [](const int32_t& version, const Object& object) { - return Index::Create( - std::make_unique>(version, object), cuda_concurrent_size); -}); + return (Index::Create( + std::make_unique>(version, object), cuda_concurrent_size)); +}, fp32); KNOWHERE_REGISTER_GLOBAL(GPU_IVF_FLAT, [](const int32_t& version, const Object& object) { - return Index::Create( - std::make_unique>(version, object), cuda_concurrent_size); -}); + return (Index::Create( + std::make_unique>(version, object), cuda_concurrent_size)); +}, fp32); KNOWHERE_REGISTER_GLOBAL(GPU_IVF_PQ, [](const int32_t& version, const Object& object) { - return Index::Create( - std::make_unique>(version, object), cuda_concurrent_size); -}); + return (Index::Create( + std::make_unique>(version, object), cuda_concurrent_size)); +}, fp32); } // namespace knowhere diff --git a/src/index/ivf_raft/ivf_raft.cuh b/src/index/ivf_raft/ivf_raft.cuh index 02eee67b7..74b40df7e 100644 --- a/src/index/ivf_raft/ivf_raft.cuh +++ b/src/index/ivf_raft/ivf_raft.cuh @@ -230,7 +230,7 @@ struct KnowhereConfigType { typedef RaftIvfPqConfig Type; }; -template +template class RaftIvfIndexNode : public IndexNode { public: RaftIvfIndexNode(const int32_t& /*version*/, const Object& object) : device_id_{-1}, gpu_index_{} { diff --git a/tests/ut/test_bruteforce.cc b/tests/ut/test_bruteforce.cc index fd5c16f78..4cdd80143 100644 --- a/tests/ut/test_bruteforce.cc +++ b/tests/ut/test_bruteforce.cc @@ -38,7 +38,7 @@ TEST_CASE("Test Brute Force", "[float vector]") { }; SECTION("Test Search") { - auto res = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto res = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); REQUIRE(res.has_value()); auto ids = res.value()->GetIds(); auto dist = res.value()->GetDistance(); @@ -55,7 +55,7 @@ TEST_CASE("Test Brute Force", "[float vector]") { SECTION("Test Search With Buf") { auto ids = new int64_t[nq * k]; auto dist = new float[nq * k]; - auto res = knowhere::BruteForce::SearchWithBuf(train_ds, query_ds, ids, dist, conf, nullptr); + auto res = knowhere::BruteForce::SearchWithBuf(train_ds, query_ds, ids, dist, conf, nullptr); REQUIRE(res == knowhere::Status::success); for (int64_t i = 0; i < nq; i++) { REQUIRE(ids[i * k] == i); @@ -70,7 +70,7 @@ TEST_CASE("Test Brute Force", "[float vector]") { } SECTION("Test Range Search") { - auto res = knowhere::BruteForce::RangeSearch(train_ds, query_ds, conf, nullptr); + auto res = knowhere::BruteForce::RangeSearch(train_ds, query_ds, conf, nullptr); REQUIRE(res.has_value()); auto ids = res.value()->GetIds(); auto dist = res.value()->GetDistance(); @@ -112,7 +112,7 @@ TEST_CASE("Test Brute Force", "[binary vector]") { }; SECTION("Test Search") { - auto res = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto res = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); REQUIRE(res.has_value()); auto ids = res.value()->GetIds(); auto dist = res.value()->GetDistance(); @@ -125,7 +125,7 @@ TEST_CASE("Test Brute Force", "[binary vector]") { SECTION("Test Search With Buf") { auto ids = new int64_t[nq * k]; auto dist = new float[nq * k]; - auto res = knowhere::BruteForce::SearchWithBuf(train_ds, query_ds, ids, dist, conf, nullptr); + auto res = knowhere::BruteForce::SearchWithBuf(train_ds, query_ds, ids, dist, conf, nullptr); REQUIRE(res == knowhere::Status::success); for (int64_t i = 0; i < nq; i++) { REQUIRE(ids[i * k] == i); @@ -144,7 +144,7 @@ TEST_CASE("Test Brute Force", "[binary vector]") { auto cfg = conf; cfg[knowhere::meta::RADIUS] = radius_map[metric]; - auto res = knowhere::BruteForce::RangeSearch(train_ds, query_ds, cfg, nullptr); + auto res = knowhere::BruteForce::RangeSearch(train_ds, query_ds, cfg, nullptr); REQUIRE(res.has_value()); auto ids = res.value()->GetIds(); auto dist = res.value()->GetDistance(); diff --git a/tests/ut/test_diskann.cc b/tests/ut/test_diskann.cc index 891299697..c2cb09b24 100644 --- a/tests/ut/test_diskann.cc +++ b/tests/ut/test_diskann.cc @@ -20,6 +20,7 @@ #include "knowhere/comp/local_file_manager.h" #include "knowhere/expected.h" #include "knowhere/factory.h" +#include "knowhere/utils.h" #include "utils.h" #if __has_include() #include @@ -52,12 +53,13 @@ constexpr float kL2RangeAp = 0.9; constexpr float kIpRangeAp = 0.9; constexpr float kCosineRangeAp = 0.9; +template void -WriteRawDataToDisk(const std::string data_path, const float* raw_data, const uint32_t num, const uint32_t dim) { +WriteRawDataToDisk(const std::string data_path, const T* raw_data, const uint32_t num, const uint32_t dim) { std::ofstream writer(data_path.c_str(), std::ios::binary); writer.write((char*)&num, sizeof(uint32_t)); writer.write((char*)&dim, sizeof(uint32_t)); - writer.write((char*)raw_data, sizeof(float) * num * dim); + writer.write((char*)raw_data, sizeof(T) * num * dim); writer.close(); } @@ -93,11 +95,11 @@ TEST_CASE("Invalid diskann params test", "[diskann]") { auto diskann_index_pack = knowhere::Pack(file_manager); auto base_ds = GenDataSet(rows_num, kDim, 30); auto base_ptr = static_cast(base_ds->GetTensor()); - WriteRawDataToDisk(kRawDataPath, base_ptr, rows_num, kDim); + WriteRawDataToDisk(kRawDataPath, base_ptr, rows_num, kDim); // build process SECTION("Invalid build params test") { knowhere::DataSet* ds_ptr = nullptr; - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); knowhere::Json test_json; knowhere::Status test_stat; // invalid metric type @@ -115,7 +117,7 @@ TEST_CASE("Invalid diskann params test", "[diskann]") { SECTION("Invalid search params test") { knowhere::DataSet* ds_ptr = nullptr; auto binarySet = knowhere::BinarySet(); - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); diskann.Build(*ds_ptr, test_gen()); diskann.Serialize(binarySet); diskann.Deserialize(binarySet, test_gen()); @@ -145,7 +147,10 @@ TEST_CASE("Invalid diskann params test", "[diskann]") { fs::remove(kDir); } -TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { +template +inline void +base_search() { + std::cout << "test data type" << typeid(data_type).name() << std::endl; fs::remove_all(kDir); fs::remove(kDir); REQUIRE_NOTHROW(fs::create_directory(kDir)); @@ -222,20 +227,26 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { return json; }; - auto query_ds = GenDataSet(kNumQueries, kDim, 42); + auto fp32_query_ds = GenDataSet(kNumQueries, kDim, 42); knowhere::DataSetPtr knn_gt_ptr = nullptr; knowhere::DataSetPtr range_search_gt_ptr = nullptr; - auto base_ds = GenDataSet(kNumRows, kDim, 30); + auto fp32_base_ds = GenDataSet(kNumRows, kDim, 30); + knowhere::DataSetPtr base_ds(fp32_base_ds); + knowhere::DataSetPtr query_ds(fp32_query_ds); + if (!std::is_same_v) { + base_ds = knowhere::data_type_conversion(*fp32_base_ds); + query_ds = knowhere::data_type_conversion(*fp32_query_ds); + } { - auto base_ptr = static_cast(base_ds->GetTensor()); - WriteRawDataToDisk(kRawDataPath, base_ptr, kNumRows, kDim); + auto base_ptr = static_cast(base_ds->GetTensor()); + WriteRawDataToDisk(kRawDataPath, base_ptr, kNumRows, kDim); // generate the gt of knn search and range search auto base_json = base_gen(); - auto result_knn = knowhere::BruteForce::Search(base_ds, query_ds, base_json, nullptr); + auto result_knn = knowhere::BruteForce::Search(base_ds, query_ds, base_json, nullptr); knn_gt_ptr = result_knn.value(); - auto result_range = knowhere::BruteForce::RangeSearch(base_ds, query_ds, base_json, nullptr); + auto result_range = knowhere::BruteForce::RangeSearch(base_ds, query_ds, base_json, nullptr); range_search_gt_ptr = result_range.value(); } @@ -247,7 +258,7 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { // build process { knowhere::DataSet* ds_ptr = nullptr; - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); auto build_json = build_gen().dump(); knowhere::Json json = knowhere::Json::parse(build_json); diskann.Build(*ds_ptr, json); @@ -255,7 +266,7 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { } { // knn search - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); diskann.Deserialize(binset, deserialize_json); auto knn_search_json = knn_search_gen().dump(); @@ -272,7 +283,8 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { if (fs::exists(cached_nodes_file_path)) { fs::remove(cached_nodes_file_path); } - auto diskann_tmp = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto diskann_tmp = + knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); diskann_tmp.Deserialize(binset, deserialize_json); auto knn_search_json = knn_search_gen().dump(); knowhere::Json knn_json = knowhere::Json::parse(knn_search_json); @@ -293,7 +305,7 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { auto bitset_data = gen_func(kNumRows, percentage * kNumRows); knowhere::BitsetView bitset(bitset_data.data(), kNumRows); auto results = diskann.Search(*query_ds, knn_json, bitset); - auto gt = knowhere::BruteForce::Search(base_ds, query_ds, knn_json, bitset); + auto gt = knowhere::BruteForce::Search(base_ds, query_ds, knn_json, bitset); float recall = GetKNNRecall(*gt.value(), *results.value()); if (percentage == 0.98f) { REQUIRE(recall >= 0.9f); @@ -318,6 +330,12 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { fs::remove(kDir); } +TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { + base_search(); + base_search(); + base_search(); +} + // This test case only check L2 TEST_CASE("Test DiskANN GetVectorByIds", "[diskann]") { auto version = GenTestVersionList(); @@ -349,13 +367,13 @@ TEST_CASE("Test DiskANN GetVectorByIds", "[diskann]") { auto query_ds = GenDataSet(kNumQueries, dim, 42); auto base_ds = GenDataSet(kNumRows, dim, 30); auto base_ptr = static_cast(base_ds->GetTensor()); - WriteRawDataToDisk(kRawDataPath, base_ptr, kNumRows, dim); + WriteRawDataToDisk(kRawDataPath, base_ptr, kNumRows, dim); std::shared_ptr file_manager = std::make_shared(); auto diskann_index_pack = knowhere::Pack(file_manager); knowhere::DataSet* ds_ptr = nullptr; - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); auto build_json = build_gen().dump(); knowhere::Json json = knowhere::Json::parse(build_json); diskann.Build(*ds_ptr, json); @@ -371,7 +389,7 @@ TEST_CASE("Test DiskANN GetVectorByIds", "[diskann]") { return json; }; knowhere::Json deserialize_json = knowhere::Json::parse(deserialize_gen().dump()); - auto index = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto index = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); auto ret = index.Deserialize(binset, deserialize_json); REQUIRE(ret == knowhere::Status::success); std::vector ids_sizes = {1, kNumRows * 0.2, kNumRows * 0.7, kNumRows}; diff --git a/tests/ut/test_feder.cc b/tests/ut/test_feder.cc index 905597edc..131dadb3f 100644 --- a/tests/ut/test_feder.cc +++ b/tests/ut/test_feder.cc @@ -173,11 +173,11 @@ TEST_CASE("Test Feder", "[feder]") { const auto query_ds = GenDataSet(nq, dim, seed); const knowhere::Json conf = base_gen(); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test HNSW Feder") { auto name = knowhere::IndexEnum::INDEX_HNSW; - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); REQUIRE(idx.Type() == name); auto json = hnsw_gen(); @@ -201,7 +201,7 @@ TEST_CASE("Test Feder", "[feder]") { SECTION("Test IVF_FLAT Feder") { auto name = knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); REQUIRE(idx.Type() == name); auto json = ivfflat_gen(); diff --git a/tests/ut/test_get_vector.cc b/tests/ut/test_get_vector.cc index 182be9702..585061eed 100644 --- a/tests/ut/test_get_vector.cc +++ b/tests/ut/test_get_vector.cc @@ -60,7 +60,7 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, bin_ivfflat_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, bin_hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -75,7 +75,7 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") { knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_new = knowhere::IndexFactory::Instance().Create(name, version); + auto idx_new = knowhere::IndexFactory::Instance().Create(name, version); idx_new.Deserialize(bs); auto retrieve_task = [&]() { @@ -173,7 +173,7 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -189,7 +189,7 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") { knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_new = knowhere::IndexFactory::Instance().Create(name, version); + auto idx_new = knowhere::IndexFactory::Instance().Create(name, version); idx_new.Deserialize(bs); auto retrieve_task = [&]() { diff --git a/tests/ut/test_gpu_search.cc b/tests/ut/test_gpu_search.cc index 22c430fb4..61109cf5b 100644 --- a/tests/ut/test_gpu_search.cc +++ b/tests/ut/test_gpu_search.cc @@ -71,7 +71,7 @@ TEST_CASE("Test All GPU Index", "[search]") { make_tuple(knowhere::IndexEnum::INDEX_RAFT_IVFFLAT, ivfflat_gen), make_tuple(knowhere::IndexEnum::INDEX_RAFT_IVFPQ, ivfpq_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -100,7 +100,7 @@ TEST_CASE("Test All GPU Index", "[search]") { make_tuple(knowhere::IndexEnum::INDEX_RAFT_IVFFLAT, ivfflat_gen), make_tuple(knowhere::IndexEnum::INDEX_RAFT_IVFPQ, ivfpq_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -119,7 +119,7 @@ TEST_CASE("Test All GPU Index", "[search]") { knowhere::BitsetView bitset(bitset_data.data(), nb); auto results = idx.Search(*query_ds, json, bitset); REQUIRE(results.has_value()); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results.value()); if (percentage == 0.98f) { REQUIRE(recall > 0.4f); @@ -142,7 +142,7 @@ TEST_CASE("Test All GPU Index", "[search]") { make_tuple(knowhere::IndexEnum::INDEX_RAFT_IVFFLAT, ivfflat_gen), make_tuple(knowhere::IndexEnum::INDEX_RAFT_IVFPQ, ivfpq_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -158,7 +158,7 @@ TEST_CASE("Test All GPU Index", "[search]") { json[knowhere::meta::TOPK] = std::get<0>(topKTuple); auto results = idx.Search(*query_ds, json, nullptr); REQUIRE(results.has_value()); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, nullptr); float recall = GetKNNRecall(*gt.value(), *results.value()); REQUIRE(recall >= std::get<1>(topKTuple)); } @@ -175,7 +175,7 @@ TEST_CASE("Test All GPU Index", "[search]") { make_tuple(knowhere::IndexEnum::INDEX_RAFT_IVFPQ, ivfpq_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -186,7 +186,7 @@ TEST_CASE("Test All GPU Index", "[search]") { REQUIRE(res == knowhere::Status::success); knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); + auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); idx_.Deserialize(bs); auto results = idx_.Search(*query_ds, json, nullptr); REQUIRE(results.has_value()); @@ -203,7 +203,7 @@ TEST_CASE("Test All GPU Index", "[search]") { make_tuple(knowhere::IndexEnum::INDEX_RAFT_IVFPQ, ivfpq_gen), })); auto rows = 16; - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -218,7 +218,7 @@ TEST_CASE("Test All GPU Index", "[search]") { knowhere::BitsetView bitset(bitset_data.data(), rows); auto results = idx.Search(*train_ds, json, bitset); REQUIRE(results.has_value()); - auto gt = knowhere::BruteForce::Search(train_ds, train_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, train_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results.value()); REQUIRE(recall == 1.0f); } diff --git a/tests/ut/test_half_presicion.cc b/tests/ut/test_half_presicion.cc new file mode 100644 index 000000000..cd534cd7a --- /dev/null +++ b/tests/ut/test_half_presicion.cc @@ -0,0 +1,258 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "catch2/catch_approx.hpp" +#include "catch2/catch_test_macros.hpp" +#include "catch2/generators/catch_generators.hpp" +#include "faiss/utils/binary_distances.h" +#include "hnswlib/hnswalg.h" +#include "knowhere/bitsetview.h" +#include "knowhere/comp/brute_force.h" +#include "knowhere/comp/index_param.h" +#include "knowhere/comp/knowhere_config.h" +#include "knowhere/factory.h" +#include "knowhere/log.h" +#include "knowhere/utils.h" +#include "utils.h" + +namespace { +constexpr float kKnnRecallThreshold = 0.6f; +constexpr float kBruteForceRecallThreshold = 0.99f; +} // namespace +template +void +BaseSearchTest() { + using Catch::Approx; + + const int64_t nb = 1000, nq = 10; + const int64_t dim = 128; + + auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::COSINE); + auto topk = GENERATE(as{}, 5, 120); + auto version = GenTestVersionList(); + + auto base_gen = [=]() { + knowhere::Json json; + json[knowhere::meta::DIM] = dim; + json[knowhere::meta::METRIC_TYPE] = metric; + json[knowhere::meta::TOPK] = topk; + json[knowhere::meta::RADIUS] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 10.0 : 0.99; + json[knowhere::meta::RANGE_FILTER] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 0.0 : 1.01; + return json; + }; + + auto ivfflat_gen = [base_gen]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::NLIST] = 16; + json[knowhere::indexparam::NPROBE] = 8; + return json; + }; + + auto ivfflatcc_gen = [ivfflat_gen]() { + knowhere::Json json = ivfflat_gen(); + json[knowhere::indexparam::SSIZE] = 48; + return json; + }; + + auto ivfsq_gen = ivfflat_gen; + + auto flat_gen = base_gen; + + auto ivfpq_gen = [ivfflat_gen]() { + knowhere::Json json = ivfflat_gen(); + json[knowhere::indexparam::M] = 4; + json[knowhere::indexparam::NBITS] = 8; + return json; + }; + + auto scann_gen = [ivfflat_gen]() { + knowhere::Json json = ivfflat_gen(); + json[knowhere::indexparam::NPROBE] = 14; + json[knowhere::indexparam::REORDER_K] = 500; + json[knowhere::indexparam::WITH_RAW_DATA] = true; + return json; + }; + + auto scann_gen2 = [scann_gen]() { + knowhere::Json json = scann_gen(); + json[knowhere::indexparam::WITH_RAW_DATA] = false; + return json; + }; + + auto hnsw_gen = [base_gen]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::HNSW_M] = 128; + json[knowhere::indexparam::EFCONSTRUCTION] = 200; + json[knowhere::indexparam::EF] = 200; + return json; + }; + + const auto fp32_train_ds = GenDataSet(nb, dim); + const auto fp32_query_ds = GenDataSet(nq, dim); + auto train_ds = knowhere::data_type_conversion(*fp32_train_ds); + auto query_ds = knowhere::data_type_conversion(*fp32_query_ds); + + const knowhere::Json conf = { + {knowhere::meta::METRIC_TYPE, metric}, + {knowhere::meta::TOPK, topk}, + }; + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + + SECTION("Test half-float Search") { + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + REQUIRE(idx.Size() > 0); + REQUIRE(idx.Count() == nb); + + knowhere::BinarySet bs; + REQUIRE(idx.Serialize(bs) == knowhere::Status::success); + REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success); + + auto results = idx.Search(*query_ds, json, nullptr); + REQUIRE(results.has_value()); + float recall = GetKNNRecall(*gt.value(), *results.value()); + bool scann_without_raw_data = + (name == knowhere::IndexEnum::INDEX_FAISS_SCANN && scann_gen2().dump() == cfg_json); + if (name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && !scann_without_raw_data) { + REQUIRE(recall > kKnnRecallThreshold); + } + + if (metric == knowhere::metric::COSINE) { + if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 && name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && + !scann_without_raw_data) { + REQUIRE(CheckDistanceInScope(*results.value(), topk, -1.00001, 1.00001)); + } + } + } + + SECTION("Test half-float Range Search") { + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + + knowhere::BinarySet bs; + REQUIRE(idx.Serialize(bs) == knowhere::Status::success); + REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success); + + auto results = idx.RangeSearch(*query_ds, json, nullptr); + REQUIRE(results.has_value()); + auto ids = results.value()->GetIds(); + auto lims = results.value()->GetLims(); + auto dis = results.value()->GetDistance(); + bool scann_without_raw_data = + (name == knowhere::IndexEnum::INDEX_FAISS_SCANN && scann_gen2().dump() == cfg_json); + if (name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && name != knowhere::IndexEnum::INDEX_FAISS_SCANN) { + for (int i = 0; i < nq; ++i) { + CHECK(ids[lims[i]] == i); + } + } + + if (metric == knowhere::metric::COSINE) { + if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 && name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && + !scann_without_raw_data) { + REQUIRE(CheckDistanceInScope(*results.value(), -1.00001, 1.00001)); + } + } + } + + SECTION("Test half-float Search with Bitset") { + using std::make_tuple; + auto [name, gen, threshold] = GENERATE_REF(table, float>({ + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + + std::vector(size_t, size_t)>> gen_bitset_funcs = { + GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet}; + const auto bitset_percentages = {0.4f, 0.98f}; + for (const float percentage : bitset_percentages) { + for (const auto& gen_func : gen_bitset_funcs) { + auto bitset_data = gen_func(nb, percentage * nb); + knowhere::BitsetView bitset(bitset_data.data(), nb); + auto results = idx.Search(*query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + float recall = GetKNNRecall(*gt.value(), *results.value()); + if (percentage > threshold) { + REQUIRE(recall > kBruteForceRecallThreshold); + } else { + REQUIRE(recall > kKnnRecallThreshold); + } + } + } + } + + SECTION("Test Serialize/Deserialize") { + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + })); + + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + knowhere::BinarySet bs; + idx.Serialize(bs); + + auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); + idx_.Deserialize(bs); + auto results = idx_.Search(*query_ds, json, nullptr); + REQUIRE(results.has_value()); + } +} + +TEST_CASE("Test Mem Index With fp16/bf16 Vector", "[float metrics]") { + BaseSearchTest(); + BaseSearchTest(); +} diff --git a/tests/ut/test_iterator.cc b/tests/ut/test_iterator.cc index 1f440f886..204099da3 100644 --- a/tests/ut/test_iterator.cc +++ b/tests/ut/test_iterator.cc @@ -83,14 +83,14 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { {knowhere::meta::METRIC_TYPE, metric}, {knowhere::meta::TOPK, topk}, }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search using iterator") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -114,7 +114,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -132,7 +132,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { auto its = idx.AnnIterator(*query_ds, json, bitset); REQUIRE(its.has_value()); auto results = GetKNNResult(its.value(), topk, &bitset); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results); REQUIRE(recall > kKnnRecallThreshold); } @@ -144,7 +144,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -165,7 +165,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { for (const auto& it : its.value()) { REQUIRE(!it->HasNext()); } - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results); REQUIRE(recall > kKnnRecallThreshold); } @@ -205,13 +205,13 @@ TEST_CASE("Test Iterator Mem Index With Binary Vector", "[float metrics]") { {knowhere::meta::TOPK, topk}, }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search using iterator") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); diff --git a/tests/ut/test_ivfflat_cc.cc b/tests/ut/test_ivfflat_cc.cc index 442dd1383..c21c57a13 100644 --- a/tests/ut/test_ivfflat_cc.cc +++ b/tests/ut/test_ivfflat_cc.cc @@ -141,7 +141,7 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") { auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -183,9 +183,10 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") { SECTION("Test Build & Search Correctness") { using std::make_tuple; - auto ivf_flat = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, version); + auto ivf_flat = + knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, version); auto ivf_flat_cc = - knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, version); + knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, version); knowhere::Json ivf_flat_json = knowhere::Json::parse(ivfflat_gen().dump()); knowhere::Json ivf_flat_cc_json = knowhere::Json::parse(ivfflatcc_gen().dump()); @@ -242,7 +243,7 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") { auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); diff --git a/tests/ut/test_mmap.cc b/tests/ut/test_mmap.cc index 5b635bc3b..04662189b 100644 --- a/tests/ut/test_mmap.cc +++ b/tests/ut/test_mmap.cc @@ -116,7 +116,7 @@ TEST_CASE("Search mmap", "[float metrics]") { {knowhere::meta::METRIC_TYPE, metric}, {knowhere::meta::TOPK, topk}, }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search") { using std::make_tuple; @@ -127,7 +127,7 @@ TEST_CASE("Search mmap", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -154,7 +154,7 @@ TEST_CASE("Search mmap", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -178,7 +178,7 @@ TEST_CASE("Search mmap", "[float metrics]") { auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -194,7 +194,7 @@ TEST_CASE("Search mmap", "[float metrics]") { auto bitset_data = gen_func(nb, percentage * nb); knowhere::BitsetView bitset(bitset_data.data(), nb); auto results = idx.Search(*query_ds, json, bitset); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results.value()); if (percentage > threshold) { REQUIRE(recall > kBruteForceRecallThreshold); @@ -262,7 +262,7 @@ TEST_CASE("Search binary mmap", "[float metrics]") { REQUIRE(index.DeserializeFromFile(path, conf) == knowhere::Status::success); }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>({ @@ -270,7 +270,7 @@ TEST_CASE("Search binary mmap", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -292,7 +292,7 @@ TEST_CASE("Search binary mmap", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -374,7 +374,7 @@ TEST_CASE("Search binary mmap", "[bool metrics]") { std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -417,7 +417,7 @@ TEST_CASE("Search binary mmap", "[bool metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index f5e7e2ab1..2c3f96535 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -100,7 +100,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { {knowhere::meta::METRIC_TYPE, metric}, {knowhere::meta::TOPK, topk}, }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search") { using std::make_tuple; @@ -114,7 +114,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -156,7 +156,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -192,7 +192,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFThreshold), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -207,7 +207,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { auto bitset_data = gen_func(nb, percentage * nb); knowhere::BitsetView bitset(bitset_data.data(), nb); auto results = idx.Search(*query_ds, json, bitset); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results.value()); if (percentage > threshold) { REQUIRE(recall > kBruteForceRecallThreshold); @@ -231,7 +231,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -240,14 +240,14 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); + auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); idx_.Deserialize(bs); auto results = idx_.Search(*query_ds, json, nullptr); REQUIRE(results.has_value()); } SECTION("Test IVFPQ with invalid params") { - auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, version); + auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, version); uint32_t nb = 1000; uint32_t dim = 128; auto ivf_pq_gen = [&]() { @@ -307,7 +307,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { {knowhere::meta::TOPK, topk}, }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>({ @@ -315,7 +315,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -335,7 +335,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -357,7 +357,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -366,7 +366,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); + auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); idx_.Deserialize(bs); auto results = idx_.Search(*query_ds, json, nullptr); REQUIRE(results.has_value()); @@ -478,7 +478,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") { std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -525,7 +525,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); diff --git a/tests/ut/test_simd.cc b/tests/ut/test_simd.cc index fba56c32c..cbce8f37e 100644 --- a/tests/ut/test_simd.cc +++ b/tests/ut/test_simd.cc @@ -38,7 +38,7 @@ TEST_CASE("Test BruteForce Search SIMD", "[bf]") { auto test_search_with_simd = [&](knowhere::KnowhereConfig::SimdType simd_type) { knowhere::KnowhereConfig::SetSimdType(simd_type); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); REQUIRE(gt.has_value()); auto gt_ids = gt.value()->GetIds(); auto gt_dist = gt.value()->GetDistance(); @@ -83,7 +83,7 @@ TEST_CASE("Test PQ Search SIMD", "[pq]") { conf[knowhere::indexparam::M] = m; knowhere::KnowhereConfig::SetSimdType(simd_type); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); REQUIRE(gt.has_value()); auto gt_ids = gt.value()->GetIds(); auto gt_dist = gt.value()->GetDistance(); @@ -97,7 +97,7 @@ TEST_CASE("Test PQ Search SIMD", "[pq]") { } } - auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, version); + auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, version); REQUIRE(idx.Build(*train_ds, conf) == knowhere::Status::success); auto res = idx.Search(*query_ds, conf, nullptr); REQUIRE(res.has_value()); diff --git a/thirdparty/DiskANN/include/diskann/utils.h b/thirdparty/DiskANN/include/diskann/utils.h index f61dd19d1..cd9fbb4c3 100644 --- a/thirdparty/DiskANN/include/diskann/utils.h +++ b/thirdparty/DiskANN/include/diskann/utils.h @@ -36,6 +36,7 @@ typedef int FileHandle; #include "ann_exception.h" #include "common_includes.h" #include "knowhere/comp/thread_pool.h" +#include "knowhere/operands.h" // taken from // https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h @@ -564,7 +565,6 @@ namespace diskann { template void convert_types(const InType* srcmat, OutType* destmat, size_t npts, size_t dim) { -#pragma omp parallel for schedule(static, 65536) for (int64_t i = 0; i < (_s64) npts; i++) { for (uint64_t j = 0; j < dim; j++) { destmat[i * dim + j] = (OutType) srcmat[i * dim + j]; @@ -835,6 +835,10 @@ namespace diskann { // NOTE: Implementation in utils.cpp. void block_convert(std::ofstream& writr, std::ifstream& readr, float* read_buf, _u64 npts, _u64 ndims); + + template + DISKANN_DLLEXPORT void convert_types_in_file(const std::string& inFileName, + const std::string& outFileName); void normalize_data_file(const std::string& inFileName, const std::string& outFileName); diff --git a/thirdparty/DiskANN/src/utils.cpp b/thirdparty/DiskANN/src/utils.cpp index 989b8ca83..ddfbf8dfc 100644 --- a/thirdparty/DiskANN/src/utils.cpp +++ b/thirdparty/DiskANN/src/utils.cpp @@ -58,4 +58,39 @@ namespace diskann { LOG_KNOWHERE_DEBUG_ << "Wrote normalized points to file: " << outFileName; } + + template + DISKANN_DLLEXPORT void convert_types_in_file(const std::string& inFileName, + const std::string& outFileName) { + std::ifstream readr(inFileName, std::ios::binary); + std::ofstream writr(outFileName, std::ios::binary); + + int npts_s32, ndims_s32; + readr.read((char*) &npts_s32, sizeof(_s32)); + readr.read((char*) &ndims_s32, sizeof(_s32)); + + writr.write((char*) &npts_s32, sizeof(_s32)); + writr.write((char*) &ndims_s32, sizeof(_s32)); + + _u64 npts = (_u64) npts_s32, ndims = (_u64) ndims_s32; + + _u64 blk_size = 131072; + _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; + LOG_KNOWHERE_DEBUG_ << "# blks: " << nblks; + + IN* read_buf = new IN[npts * ndims]; + OUT* write_buf = new OUT[npts * ndims]; + for (_u64 i = 0; i < nblks; i++) { + _u64 cblk_size = std::min(npts - i * blk_size, blk_size); + readr.read((char*) read_buf, npts * ndims * sizeof(IN)); + convert_types(read_buf, write_buf, cblk_size, ndims); + writr.write((char*) write_buf, npts * ndims * sizeof(OUT)); + } + + delete[] read_buf; + delete[] write_buf; + } + + template void convert_types_in_file(const std::string& , const std::string&); + template void convert_types_in_file(const std::string& , const std::string&); } // namespace diskann