diff --git a/.github/workflows/release-python.yml b/.github/workflows/release-python.yml index cdd4c1f5b..72829e348 100644 --- a/.github/workflows/release-python.yml +++ b/.github/workflows/release-python.yml @@ -44,7 +44,7 @@ jobs: CIBW_ARCHS: ${{ matrix.arch }} CIBW_BEFORE_ALL_LINUX: "bash scripts/python_deps.sh && rm -rf build && mkdir build && cd build && conan install .. --build=missing -o with_diskann=True -s compiler.libcxx=libstdc++11 -s build_type=Release && conan build .. && cd -" # CIBW_BEFORE_ALL_MACOS: "bash scripts/python_deps.sh && rm -rf build && mkdir build && cd build && CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ conan install .. --build=missing -s build_type=Release && CC=$(brew --prefix llvm)/bin/clang CXX=$(brew --prefix llvm)/bin/clang++ conan build .. && cd -" - CIBW_BEFORE_BUILD: "pip3 install pytest numpy faiss-cpu" + CIBW_BEFORE_BUILD: "pip3 install pytest numpy faiss-cpu bfloat16" # CIBW_ENVIRONMENT_MACOS: > # _PYTHON_HOST_PLATFORM=macosx-10.15-${{ matrix.arch }} # CIBW_BEFORE_BUILD: "bash scripts/python_deps.sh && pip3 install pytest numpy faiss-cpu" diff --git a/.github/workflows/ut.yaml b/.github/workflows/ut.yaml index f0d9fa170..4444c5144 100644 --- a/.github/workflows/ut.yaml +++ b/.github/workflows/ut.yaml @@ -102,6 +102,7 @@ jobs: sudo apt update \ && sudo apt install -y cmake g++ gcc libopenblas-dev libaio-dev libcurl4-openssl-dev libevent-dev libgflags-dev python3 python3-pip python3-setuptools \ && pip3 install conan==1.58.0 pytest faiss-cpu numpy wheel \ + && pip3 install bfloat16 \ && conan remote add default-conan-local https://milvus01.jfrog.io/artifactory/api/conan/default-conan-local - name: Build run: | diff --git a/README.md b/README.md index b15d3f151..aabfdad3d 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,7 @@ install dependency: ``` sudo apt install swig python3-dev +pip3 install bfloat16 ``` after build knowhere: diff --git a/benchmark/hdf5/benchmark_float.cpp b/benchmark/hdf5/benchmark_float.cpp index 5f42d7e24..a735a42f8 100644 --- a/benchmark/hdf5/benchmark_float.cpp +++ b/benchmark/hdf5/benchmark_float.cpp @@ -273,7 +273,7 @@ TEST_F(Benchmark_float, TEST_DISKANN) { std::shared_ptr file_manager = std::make_shared(); auto diskann_index_pack = knowhere::Pack(file_manager); - index_ = knowhere::IndexFactory::Instance().Create( + index_ = knowhere::IndexFactory::Instance().Create( index_type_, knowhere::Version::GetCurrentVersion().VersionNumber(), diskann_index_pack); printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_); knowhere::DataSetPtr ds_ptr = nullptr; diff --git a/benchmark/hdf5/benchmark_float_bitset.cpp b/benchmark/hdf5/benchmark_float_bitset.cpp index beb34eff3..70e083043 100644 --- a/benchmark/hdf5/benchmark_float_bitset.cpp +++ b/benchmark/hdf5/benchmark_float_bitset.cpp @@ -234,7 +234,7 @@ TEST_F(Benchmark_float_bitset, TEST_DISKANN) { auto diskann_index_pack = knowhere::Pack(file_manager); auto version = knowhere::Version::GetCurrentVersion().VersionNumber(); - index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack); + index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack); printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_); knowhere::DataSetPtr ds_ptr = nullptr; index_.Build(*ds_ptr, conf); diff --git a/benchmark/hdf5/benchmark_float_range_bitset.cpp b/benchmark/hdf5/benchmark_float_range_bitset.cpp index f5c406ce7..763f0dd31 100644 --- 
a/benchmark/hdf5/benchmark_float_range_bitset.cpp +++ b/benchmark/hdf5/benchmark_float_range_bitset.cpp @@ -235,7 +235,7 @@ TEST_F(Benchmark_float_range_bitset, TEST_DISKANN) { auto diskann_index_pack = knowhere::Pack(file_manager); auto version = knowhere::Version::GetCurrentVersion().VersionNumber(); - index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack); + index_ = knowhere::IndexFactory::Instance().Create(index_type_, version, diskann_index_pack); printf("[%.3f s] Building all on %d vectors\n", get_time_diff(), nb_); knowhere::DataSetPtr ds_ptr = nullptr; index_.Build(*ds_ptr, conf); diff --git a/benchmark/hdf5/benchmark_knowhere.h b/benchmark/hdf5/benchmark_knowhere.h index ea1faa39a..5eca6c0f4 100644 --- a/benchmark/hdf5/benchmark_knowhere.h +++ b/benchmark/hdf5/benchmark_knowhere.h @@ -97,7 +97,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 { create_index(const std::string& index_file_name, const knowhere::Json& conf) { auto version = knowhere::Version::GetCurrentVersion().VersionNumber(); printf("[%.3f s] Creating index \"%s\"\n", get_time_diff(), index_type_.c_str()); - index_ = knowhere::IndexFactory::Instance().Create(index_type_, version); + index_ = knowhere::IndexFactory::Instance().Create(index_type_, version); try { printf("[%.3f s] Reading index file: %s\n", get_time_diff(), index_file_name.c_str()); @@ -120,7 +120,7 @@ class Benchmark_knowhere : public Benchmark_hdf5 { std::string golden_index_file_name = ann_test_name_ + "_" + golden_index_type_ + "_GOLDEN" + ".index"; printf("[%.3f s] Creating golden index \"%s\"\n", get_time_diff(), golden_index_type_.c_str()); - golden_index_ = knowhere::IndexFactory::Instance().Create(golden_index_type_, version); + golden_index_ = knowhere::IndexFactory::Instance().Create(golden_index_type_, version); try { printf("[%.3f s] Reading golden index file: %s\n", get_time_diff(), golden_index_file_name.c_str()); diff --git a/include/knowhere/comp/brute_force.h b/include/knowhere/comp/brute_force.h index 240aa0f4b..e66217d61 100644 --- a/include/knowhere/comp/brute_force.h +++ b/include/knowhere/comp/brute_force.h @@ -14,18 +14,22 @@ #include "knowhere/bitsetview.h" #include "knowhere/dataset.h" #include "knowhere/factory.h" +#include "knowhere/operands.h" namespace knowhere { class BruteForce { public: + template static expected Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset); + template static Status SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis, const Json& config, const BitsetView& bitset); + template static expected RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset); diff --git a/include/knowhere/dataset.h b/include/knowhere/dataset.h index c77decaf3..64324ade9 100644 --- a/include/knowhere/dataset.h +++ b/include/knowhere/dataset.h @@ -24,7 +24,7 @@ namespace knowhere { -class DataSet { +class DataSet : public std::enable_shared_from_this { public: typedef std::variant Var; DataSet() = default; @@ -227,7 +227,6 @@ class DataSet { bool is_owner = true; }; using DataSetPtr = std::shared_ptr; - inline DataSetPtr GenDataSet(const int64_t nb, const int64_t dim, const void* xb) { auto ret_ds = std::make_shared(); diff --git a/include/knowhere/factory.h b/include/knowhere/factory.h index 3138f7494..3cc5e63f4 100644 --- a/include/knowhere/factory.h +++ b/include/knowhere/factory.h @@ -21,23 +21,54 @@ 
namespace knowhere { class IndexFactory { public: + template Index Create(const std::string& name, const int32_t& version, const Object& object = nullptr); + template const IndexFactory& Register(const std::string& name, std::function(const int32_t& version, const Object&)> func); static IndexFactory& Instance(); private: - typedef std::map(const int32_t&, const Object&)>> FuncMap; + struct FunMapValueBase { + virtual ~FunMapValueBase() = default; + }; + template + struct FunMapValue : FunMapValueBase { + public: + FunMapValue(std::function& input) : fun_value(input) { + } + std::function fun_value; + }; + typedef std::map> FuncMap; IndexFactory(); static FuncMap& MapInstance(); + template + std::string + GetMapKey(const std::string& name); }; -#define KNOWHERE_CONCAT(x, y) x##y -#define KNOWHERE_REGISTER_GLOBAL(name, func) \ - const IndexFactory& KNOWHERE_CONCAT(index_factory_ref_, name) = IndexFactory::Instance().Register(#name, func) +#define KNOWHERE_CONCAT(x, y) index_factory_ref_##x##y +#define KNOWHERE_CONCAT_STR(x, y) #x "_" #y +#define KNOWHERE_REGISTER_GLOBAL(name, func, data_type) \ + const IndexFactory& KNOWHERE_CONCAT(name, data_type) = IndexFactory::Instance().Register(#name, func) +#define KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, data_type, ...) \ + KNOWHERE_REGISTER_GLOBAL( \ + name, \ + [](const int32_t& version, const Object& object) { \ + return (Index>::Create(version, object)); \ + }, \ + data_type) +#define KNOWHERE_MOCK_REGISTER_GLOBAL(name, index_node, data_type, ...) \ + KNOWHERE_REGISTER_GLOBAL( \ + name, \ + [](const int32_t& version, const Object& object) { \ + return (Index>::Create( \ + std::make_unique::type, ##__VA_ARGS__>>(version, object))); \ + }, \ + data_type) } // namespace knowhere #endif /* INDEX_FACTORY_H */ diff --git a/include/knowhere/index.h b/include/knowhere/index.h index 164eea88c..c9062c1fe 100644 --- a/include/knowhere/index.h +++ b/include/knowhere/index.h @@ -19,7 +19,6 @@ #include "knowhere/index_node.h" namespace knowhere { - template class Index { public: diff --git a/include/knowhere/index_node.h b/include/knowhere/index_node.h index 4613d37ca..6af39b2b5 100644 --- a/include/knowhere/index_node.h +++ b/include/knowhere/index_node.h @@ -18,10 +18,10 @@ #include "knowhere/dataset.h" #include "knowhere/expected.h" #include "knowhere/object.h" +#include "knowhere/operands.h" #include "knowhere/version.h" namespace knowhere { - class IndexNode : public Object { public: IndexNode(const int32_t ver) : version_(ver) { diff --git a/include/knowhere/index_node_data_mock_wrapper.h b/include/knowhere/index_node_data_mock_wrapper.h new file mode 100644 index 000000000..9c008f727 --- /dev/null +++ b/include/knowhere/index_node_data_mock_wrapper.h @@ -0,0 +1,100 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
+ +#ifndef INDEX_NODE_DATA_MOCK_WRAPPER_H +#define INDEX_NODE_DATA_MOCK_WRAPPER_H + +#include "knowhere/index_node.h" +namespace knowhere { + +template +class IndexNodeDataMockWrapper : public IndexNode { + public: + IndexNodeDataMockWrapper(std::unique_ptr index_node) : index_node_(std::move(index_node)) { + } + + Status + Build(const DataSet& dataset, const Config& cfg) override; + + Status + Train(const DataSet& dataset, const Config& cfg) override; + + Status + Add(const DataSet& dataset, const Config& cfg) override; + expected + Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override; + + expected + RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override; + + expected>> + AnnIterator(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override; + + expected + GetVectorByIds(const DataSet& dataset) const override; + + bool + HasRawData(const std::string& metric_type) const override { + return index_node_->HasRawData(metric_type); + } + + expected + GetIndexMeta(const Config& cfg) const override { + return index_node_->GetIndexMeta(cfg); + } + + Status + Serialize(BinarySet& binset) const override { + return index_node_->Serialize(binset); + } + + Status + Deserialize(const BinarySet& binset, const Config& config) override { + return index_node_->Deserialize(binset, config); + } + + Status + DeserializeFromFile(const std::string& filename, const Config& config) override { + return index_node_->DeserializeFromFile(filename, config); + } + + std::unique_ptr + CreateConfig() const override { + return index_node_->CreateConfig(); + } + + int64_t + Dim() const override { + return index_node_->Dim(); + } + + int64_t + Size() const override { + return index_node_->Size(); + } + + int64_t + Count() const override { + return index_node_->Count(); + } + + std::string + Type() const override { + return index_node_->Type(); + } + + private: + std::unique_ptr index_node_; +}; + +} // namespace knowhere + +#endif /* INDEX_NODE_DATA_MOCK_WRAPPER_H */ diff --git a/include/knowhere/operands.h b/include/knowhere/operands.h new file mode 100644 index 000000000..becb7aa07 --- /dev/null +++ b/include/knowhere/operands.h @@ -0,0 +1,165 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. 
+ +#ifndef OPERANDS_H +#define OPERANDS_H +#include + +namespace { +union fp32_bits { + uint32_t as_bits; + float as_value; +}; + +inline float +fp32_from_bits(const uint32_t& w) { + return fp32_bits{.as_bits = w}.as_value; +} + +inline uint32_t +fp32_to_bits(const float& f) { + return fp32_bits{.as_value = f}.as_bits; +} +}; // namespace + +namespace knowhere { +using fp32 = float; +using bin1 = uint8_t; + +struct fp16 { + public: + fp16() = default; + fp16(const float& f) { + from_fp32(f); + }; + operator float() const { + return to_fp32(bits); + } + + private: + uint16_t bits = 0; + void + from_fp32(const float f) { + // const float scale_to_inf = 0x1.0p+112f; + // const float scale_to_zero = 0x1.0p-110f; + constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23; + constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23; + float scale_to_inf_val, scale_to_zero_val; + std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); + std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); + const float scale_to_inf = scale_to_inf_val; + const float scale_to_zero = scale_to_zero_val; + +#if defined(_MSC_VER) && _MSC_VER == 1916 + float base = ((f < 0.0 ? -f : f) * scale_to_inf) * scale_to_zero; +#else + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; +#endif + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + this->bits = static_cast((sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign)); + } + + float + to_fp32(const uint16_t h) const { + const uint32_t w = (uint32_t)h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + constexpr uint32_t exp_offset = UINT32_C(0xE0) << 23; + constexpr uint32_t scale_bits = (uint32_t)15 << 23; + + float exp_scale_val; + std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); + const float exp_scale = exp_scale_val; + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + constexpr uint32_t magic_mask = UINT32_C(126) << 23; + constexpr float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + constexpr uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = + sign | (two_w < denormalized_cutoff ? 
fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); + } +}; + +struct bf16 { + public: + bf16() = default; + bf16(const float& f) { + from_fp32(f); + }; + operator float() const { + return this->to_fp32(bits); + } + + private: + uint16_t bits = 0; + void + from_fp32(const float f) { + volatile uint32_t fp32Bits = fp32_to_bits(f); + volatile uint16_t bf16Bits = (uint16_t)(fp32Bits >> 16); + this->bits = bf16Bits; + } + float + to_fp32(const uint16_t h) const { + uint32_t bits = ((unsigned int)h) << 16; + bits &= 0xFFFF0000; + return fp32_from_bits(bits); + } +}; + +template +struct MockData { + using type = T; +}; + +template <> +struct MockData { + using type = knowhere::fp32; +}; + +template <> +struct MockData { + using type = knowhere::fp32; +}; + +// define result type, in case use double or uint8 data type in the future +template +struct ResultType { + using type = DataType; +}; + +template +struct ResultType< + DataType, std::enable_if_t || std::is_same_v || + std::is_same_v || std::is_same_v>> { + using type = knowhere::fp32; +}; + +} // namespace knowhere +#endif /* OPERANDS_H */ diff --git a/include/knowhere/utils.h b/include/knowhere/utils.h index a6ff6bed4..42e7a731e 100644 --- a/include/knowhere/utils.h +++ b/include/knowhere/utils.h @@ -68,6 +68,26 @@ hash_binary_vec(const uint8_t* x, size_t d) { return h; } +template +inline DataSetPtr +data_type_conversion(const DataSet& src) { + auto dim = src.GetDim(); + auto rows = src.GetRows(); + + auto des_data = new OutType[dim * rows]; + auto src_data = (InType*)src.GetTensor(); + for (auto i = 0; i < dim * rows; i++) { + des_data[i] = (OutType)src_data[i]; + } + + auto des = std::make_shared(); + des->SetRows(rows); + des->SetDim(dim); + des->SetTensor(des_data); + des->SetIsOwner(true); + return des; +} + template inline T round_down(const T value, const T align) { diff --git a/python/knowhere/__init__.py b/python/knowhere/__init__.py index 815db8e31..a8ba9d009 100644 --- a/python/knowhere/__init__.py +++ b/python/knowhere/__init__.py @@ -1,12 +1,20 @@ from . import swigknowhere from .swigknowhere import Status from .swigknowhere import GetBinarySet, GetNullDataSet, GetNullBitSetView -from .swigknowhere import BruteForceSearch, BruteForceRangeSearch +# from .swigknowhere import BruteForceSearch, BruteForceRangeSearch import numpy as np +from bfloat16 import bfloat16 -def CreateIndex(name, version): - return swigknowhere.IndexWrap(name, version) +def CreateIndex(name, version, type=np.float32): + if type == np.float32: + return swigknowhere.IndexWrapFloat(name, version) + if type == np.float16: + return swigknowhere.IndexWrapFP16(name, version) + if type == bfloat16: + return swigknowhere.IndexWrapBF16(name, version) + if type == np.uint8: + return swigknowhere.IndexWrapBin(name, version) def GetCurrentVersion(): @@ -33,9 +41,15 @@ def ArrayToDataSet(arr): return swigknowhere.Array2DataSetI(arr) if arr.dtype == np.float32: return swigknowhere.Array2DataSetF(arr) + if arr.dtype == np.float16: + arr = arr.astype(np.float32) + return swigknowhere.Array2DataSetFP16(arr) + if arr.dtype == bfloat16: + arr = arr.astype(np.float32) + return swigknowhere.Array2DataSetBF16(arr) raise ValueError( """ - ArrayToDataSet only support numpy array dtype float32 and int32. + ArrayToDataSet only support numpy array dtype float32,int32,float16 and bfloat16. 
""" ) diff --git a/python/knowhere/knowhere.i b/python/knowhere/knowhere.i index 70dd7f7ed..b57d0e0ff 100644 --- a/python/knowhere/knowhere.i +++ b/python/knowhere/knowhere.i @@ -30,6 +30,7 @@ typedef uint64_t size_t; #include #include #include +#include #include #include #include @@ -137,6 +138,7 @@ class AnnIteratorWrap { std::shared_ptr it_; }; +template class IndexWrap { public: IndexWrap(const std::string& name, const int32_t& version) { @@ -144,9 +146,9 @@ class IndexWrap { if (knowhere::UseDiskLoad(name, version)) { std::shared_ptr file_manager = std::make_shared(); auto diskann_pack = knowhere::Pack(file_manager); - idx = IndexFactory::Instance().Create(name, version, diskann_pack); + idx = IndexFactory::Instance().Create(name, version, diskann_pack); } else { - idx = IndexFactory::Instance().Create(name, version); + idx = IndexFactory::Instance().Create(name, version); } } @@ -301,6 +303,39 @@ Array2DataSetF(float* xb, int nb, int dim) { return ds; }; +knowhere::DataSetPtr +Array2DataSetFP16(float* xb, int nb, int dim) { + auto ds = std::make_shared(); + ds->SetIsOwner(true); + ds->SetRows(nb); + ds->SetDim(dim); + // float to fp16 + auto fp16_data = new knowhere::fp16[nb * dim]; + for (int i = 0; i < nb * dim; ++i) { + fp16_data[i] = knowhere::fp16(xb[i]); + } + ds->SetTensor(fp16_data); + return ds; +}; +#pragma GCC push_options +#pragma GCC optimize("O0") +knowhere::DataSetPtr +Array2DataSetBF16(float* xb, int nb, int dim) { + using bf16 = knowhere::bf16; + auto ds = std::make_shared(); + ds->SetIsOwner(true); + ds->SetRows(nb); + ds->SetDim(dim); + bf16* bf16_data = new bf16[nb * dim]; + for (int i = 0; i < nb * dim; ++i) { + bf16_data[i] = knowhere::bf16(xb[i]); + } + + ds->SetTensor(bf16_data); + return ds; +}; +#pragma GCC pop_options + int32_t CurrentVersion() { return knowhere::Version::GetCurrentVersion().VersionNumber(); @@ -464,11 +499,12 @@ Load(knowhere::BinarySetPtr binset, const std::string& file_name) { } } +template knowhere::DataSetPtr BruteForceSearch(knowhere::DataSetPtr base_dataset, knowhere::DataSetPtr query_dataset, const std::string& json, const knowhere::BitsetView& bitset, knowhere::Status& status) { GILReleaser rel; - auto res = knowhere::BruteForce::Search(base_dataset, query_dataset, knowhere::Json::parse(json), bitset); + auto res = knowhere::BruteForce::Search(base_dataset, query_dataset, knowhere::Json::parse(json), bitset); if (res.has_value()) { status = knowhere::Status::success; return res.value(); @@ -478,11 +514,12 @@ BruteForceSearch(knowhere::DataSetPtr base_dataset, knowhere::DataSetPtr query_d } } +template knowhere::DataSetPtr BruteForceRangeSearch(knowhere::DataSetPtr base_dataset, knowhere::DataSetPtr query_dataset, const std::string& json, const knowhere::BitsetView& bitset, knowhere::Status& status) { GILReleaser rel; - auto res = knowhere::BruteForce::RangeSearch(base_dataset, query_dataset, knowhere::Json::parse(json), bitset); + auto res = knowhere::BruteForce::RangeSearch(base_dataset, query_dataset, knowhere::Json::parse(json), bitset); if (res.has_value()) { status = knowhere::Status::success; return res.value(); @@ -510,3 +547,7 @@ SetSimdType(const std::string type) { %} %template(AnnIteratorWrapVector) std::vector; +%template(IndexWrapFloat) IndexWrap; +%template(IndexWrapFP16) IndexWrap; +%template(IndexWrapBF16) IndexWrap; +%template(IndexWrapBin) IndexWrap; diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index 042581c8f..51503a27d 100644 --- a/src/common/comp/brute_force.cc +++ 
b/src/common/comp/brute_force.cc @@ -30,17 +30,22 @@ namespace knowhere { /* knowhere wrapper API to call faiss brute force search for all metric types */ class BruteForceConfig : public BaseConfig {}; - +template expected BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset) { - auto xb = base_dataset->GetTensor(); - auto nb = base_dataset->GetRows(); - auto dim = base_dataset->GetDim(); - - auto xq = query_dataset->GetTensor(); - auto nq = query_dataset->GetRows(); + DataSetPtr base(base_dataset); + DataSetPtr query(query_dataset); + if constexpr (!std::is_same_v::type>) { + base = data_type_conversion::type>(*base_dataset); + query = data_type_conversion::type>(*query_dataset); + } + auto xb = base->GetTensor(); + auto nb = base->GetRows(); + auto dim = base->GetDim(); + auto xq = query->GetTensor(); + auto nq = query->GetRows(); BruteForceConfig cfg; std::string msg; auto status = Config::Load(cfg, config, knowhere::SEARCH, &msg); @@ -133,15 +138,22 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset return GenResultDataSet(nq, cfg.k.value(), labels, distances); } +template Status BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, float* dis, const Json& config, const BitsetView& bitset) { - auto xb = base_dataset->GetTensor(); - auto nb = base_dataset->GetRows(); - auto dim = base_dataset->GetDim(); + DataSetPtr base(base_dataset); + DataSetPtr query(query_dataset); + if constexpr (!std::is_same_v::type>) { + base = data_type_conversion::type>(*base_dataset); + query = data_type_conversion::type>(*query_dataset); + } + auto xb = base->GetTensor(); + auto nb = base->GetRows(); + auto dim = base->GetDim(); - auto xq = query_dataset->GetTensor(); - auto nq = query_dataset->GetRows(); + auto xq = query->GetTensor(); + auto nq = query->GetRows(); BruteForceConfig cfg; RETURN_IF_ERROR(Config::Load(cfg, config, knowhere::SEARCH)); @@ -231,15 +243,22 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ /** knowhere wrapper API to call faiss brute force range search for all metric types */ +template expected BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, const BitsetView& bitset) { - auto xb = base_dataset->GetTensor(); - auto nb = base_dataset->GetRows(); - auto dim = base_dataset->GetDim(); + DataSetPtr base(base_dataset); + DataSetPtr query(query_dataset); + if constexpr (!std::is_same_v::type>) { + base = data_type_conversion::type>(*base_dataset); + query = data_type_conversion::type>(*query_dataset); + } + auto xb = base->GetTensor(); + auto nb = base->GetRows(); + auto dim = base->GetDim(); - auto xq = query_dataset->GetTensor(); - auto nq = query_dataset->GetRows(); + auto xq = query->GetTensor(); + auto nq = query->GetRows(); BruteForceConfig cfg; std::string msg; @@ -343,4 +362,42 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims); return GenResultDataSet(nq, ids, distances, lims); } +template expected +BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, + const BitsetView& bitset); +template expected +BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, + const BitsetView& bitset); +template 
expected +BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, + const BitsetView& bitset); +template expected +BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset, const Json& config, + const BitsetView& bitset); + +template Status +BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, + float* dis, const Json& config, const BitsetView& bitset); +template Status +BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, + float* dis, const Json& config, const BitsetView& bitset); +template Status +BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, + float* dis, const Json& config, const BitsetView& bitset); +template Status +BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_dataset, int64_t* ids, + float* dis, const Json& config, const BitsetView& bitset); + +template expected +BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, + const Json& config, const BitsetView& bitset); +template expected +BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, + const Json& config, const BitsetView& bitset); +template expected +BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, + const Json& config, const BitsetView& bitset); +template expected +BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_dataset, + const Json& config, const BitsetView& bitset); } // namespace knowhere diff --git a/src/common/factory.cc b/src/common/factory.cc index 62d21f29d..fd075b8d5 100644 --- a/src/common/factory.cc +++ b/src/common/factory.cc @@ -13,22 +13,42 @@ namespace knowhere { +template Index IndexFactory::Create(const std::string& name, const int32_t& version, const Object& object) { auto& func_mapping_ = MapInstance(); - assert(func_mapping_.find(name) != func_mapping_.end()); - LOG_KNOWHERE_INFO_ << "create knowhere index " << name << " with version " << version; - return func_mapping_[name](version, object); + auto key = GetMapKey(name); + assert(func_mapping_.find(key) != func_mapping_.end()); + auto fun_map_v = (FunMapValue>*)(func_mapping_[key].get()); + return fun_map_v->fun_value(version, object); } +template const IndexFactory& -IndexFactory::Register(const std::string& name, - std::function(const int32_t& version, const Object&)> func) { +IndexFactory::Register(const std::string& name, std::function(const int32_t&, const Object&)> func) { auto& func_mapping_ = MapInstance(); - func_mapping_[name] = func; + auto key = GetMapKey(name); + assert(func_mapping_.find(key) == func_mapping_.end()); + func_mapping_[key] = std::make_unique>>(func); return *this; } +template +std::string +IndexFactory::GetMapKey(const std::string& name) { + if (std::is_same_v) { + return name + std::string("_fp32"); + } else if (std::is_same_v) { + return name + std::string("_fp16"); + } else if (std::is_same_v) { + return name + std::string("_bf16"); + } else if (std::is_same_v) { + return name + std::string("_bin1"); + } else { + assert(false && "invalid data type"); + } +} + IndexFactory& IndexFactory::Instance() { static IndexFactory factory; @@ -42,4 +62,33 @@ IndexFactory::MapInstance() { static FuncMap func_map; return func_map; } + +template class Index +IndexFactory::Create(const std::string&, const int32_t&, const Object&); +template class 
Index +IndexFactory::Create(const std::string&, const int32_t&, const Object&); +template class Index +IndexFactory::Create(const std::string&, const int32_t&, const Object&); +template class Index +IndexFactory::Create(const std::string&, const int32_t&, const Object&); +template const IndexFactory& +IndexFactory::Register(const std::string&, + std::function(const int32_t&, const Object&)>); +template const IndexFactory& +IndexFactory::Register(const std::string&, + std::function(const int32_t&, const Object&)>); +template const IndexFactory& +IndexFactory::Register(const std::string&, + std::function(const int32_t&, const Object&)>); +template const IndexFactory& +IndexFactory::Register(const std::string&, + std::function(const int32_t&, const Object&)>); +template std::string +IndexFactory::GetMapKey(const std::string&); +template std::string +IndexFactory::GetMapKey(const std::string&); +template std::string +IndexFactory::GetMapKey(const std::string&); +template std::string +IndexFactory::GetMapKey(const std::string&); } // namespace knowhere diff --git a/src/common/index_node_data_mock_wrapper.cc b/src/common/index_node_data_mock_wrapper.cc new file mode 100644 index 000000000..b36211e0c --- /dev/null +++ b/src/common/index_node_data_mock_wrapper.cc @@ -0,0 +1,103 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. 
+ +#include "knowhere/index_node_data_mock_wrapper.h" + +#include "knowhere/comp/thread_pool.h" +#include "knowhere/index_node.h" +#include "knowhere/utils.h" + +namespace knowhere { + +template +Status +IndexNodeDataMockWrapper::Build(const DataSet& dataset, const Config& cfg) { + std::shared_ptr ds_ptr = dataset.shared_from_this(); + if constexpr (!std::is_same_v::type>) { + ds_ptr = data_type_conversion::type>(dataset); + } + return index_node_->Build(*ds_ptr, cfg); +} + +template +Status +IndexNodeDataMockWrapper::Train(const DataSet& dataset, const Config& cfg) { + std::shared_ptr ds_ptr = dataset.shared_from_this(); + if constexpr (!std::is_same_v::type>) { + ds_ptr = data_type_conversion::type>(dataset); + } + return index_node_->Train(*ds_ptr, cfg); +} + +template +Status +IndexNodeDataMockWrapper::Add(const DataSet& dataset, const Config& cfg) { + std::shared_ptr ds_ptr = dataset.shared_from_this(); + if constexpr (!std::is_same_v::type>) { + ds_ptr = data_type_conversion::type>(dataset); + } + return index_node_->Add(*ds_ptr, cfg); +} + +template +expected +IndexNodeDataMockWrapper::Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { + auto ds_ptr = dataset.shared_from_this(); + if constexpr (!std::is_same_v::type>) { + ds_ptr = data_type_conversion::type>(dataset); + } + return index_node_->Search(*ds_ptr, cfg, bitset); +} + +template +expected +IndexNodeDataMockWrapper::RangeSearch(const DataSet& dataset, const Config& cfg, + const BitsetView& bitset) const { + auto ds_ptr = dataset.shared_from_this(); + if constexpr (!std::is_same_v::type>) { + ds_ptr = data_type_conversion::type>(dataset); + } + return index_node_->RangeSearch(*ds_ptr, cfg, bitset); +} + +template +expected>> +IndexNodeDataMockWrapper::AnnIterator(const DataSet& dataset, const Config& cfg, + const BitsetView& bitset) const { + auto ds_ptr = dataset.shared_from_this(); + if constexpr (!std::is_same_v::type>) { + ds_ptr = data_type_conversion::type>(dataset); + } + return index_node_->AnnIterator(*ds_ptr, cfg, bitset); +} + +template +expected +IndexNodeDataMockWrapper::GetVectorByIds(const DataSet& dataset) const { + auto res = index_node_->GetVectorByIds(dataset); + if constexpr (!std::is_same_v::type>) { + if (res.has_value()) { + auto res_v = data_type_conversion::type>(*res.value()); + return res_v; + } else { + return res; + } + } else { + return res; + } +} + +template class knowhere::IndexNodeDataMockWrapper; +template class knowhere::IndexNodeDataMockWrapper; +} // namespace knowhere diff --git a/src/index/diskann/diskann.cc b/src/index/diskann/diskann.cc index aa05610e7..66651c0e1 100644 --- a/src/index/diskann/diskann.cc +++ b/src/index/diskann/diskann.cc @@ -25,16 +25,17 @@ #include "knowhere/expected.h" #include "knowhere/factory.h" #include "knowhere/file_manager.h" +#include "knowhere/index_node_data_mock_wrapper.h" #include "knowhere/log.h" #include "knowhere/utils.h" namespace knowhere { - -template +template class DiskANNIndexNode : public IndexNode { - static_assert(std::is_same_v, "DiskANN only support float"); + static_assert(std::is_same_v, "DiskANN only support float"); public: + using DistType = typename ResultType::type; DiskANNIndexNode(const int32_t& version, const Object& object) : is_prepared_(false), dim_(-1), count_(-1) { assert(typeid(object) == typeid(Pack>)); auto diskann_index_pack = dynamic_cast>*>(&object); @@ -160,7 +161,7 @@ class DiskANNIndexNode : public IndexNode { mutable std::mutex preparation_lock_; std::atomic_bool is_prepared_; 
std::shared_ptr file_manager_; - std::unique_ptr> pq_flash_index_; + std::unique_ptr> pq_flash_index_; std::atomic_int64_t dim_; std::atomic_int64_t count_; std::shared_ptr search_pool_; @@ -239,7 +240,8 @@ inline bool CheckMetric(const std::string& diskann_metric) { if (diskann_metric != knowhere::metric::L2 && diskann_metric != knowhere::metric::IP && diskann_metric != knowhere::metric::COSINE) { - LOG_KNOWHERE_ERROR_ << "DiskANN currently only supports floating point data for Minimum Euclidean " + LOG_KNOWHERE_ERROR_ << "DiskANN currently only supports floating point " + "data for Minimum Euclidean " "distance(L2), Max Inner Product Search(IP) " "and Minimum Cosine Search(COSINE)." << std::endl; @@ -250,9 +252,9 @@ CheckMetric(const std::string& diskann_metric) { } } // namespace -template +template Status -DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { +DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { assert(file_manager_ != nullptr); std::lock_guard lock(preparation_lock_); auto build_conf = static_cast(cfg); @@ -272,7 +274,8 @@ DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { LOG_KNOWHERE_ERROR_ << "Failed load the raw data before building." << std::endl; return Status::disk_file_error; } - auto& data_path = build_conf.data_path.value(); + auto data_path = build_conf.data_path.value(); + index_prefix_ = build_conf.index_prefix.value(); size_t count; @@ -306,7 +309,7 @@ DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { build_conf.accelerate_build.value(), static_cast(num_nodes_to_cache)}; RETURN_IF_ERROR(TryDiskANNCall([&]() { - int res = diskann::build_disk_index(diskann_internal_build_config); + int res = diskann::build_disk_index(diskann_internal_build_config); if (res != 0) throw diskann::ANNException("diskann::build_disk_index returned non-zero value: " + std::to_string(res), -1); @@ -330,9 +333,9 @@ DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { return Status::success; } -template +template Status -DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { +DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { std::lock_guard lock(preparation_lock_); auto prep_conf = static_cast(cfg); if (!CheckMetric(prep_conf.metric_type.value())) { @@ -386,7 +389,7 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { reader.reset(new LinuxAlignedFileReader()); - pq_flash_index_ = std::make_unique>(reader, diskann_metric); + pq_flash_index_ = std::make_unique>(reader, diskann_metric); auto disk_ann_call = [&]() { int res = pq_flash_index_->load(search_pool_->size(), index_prefix_.c_str()); if (res != 0) { @@ -464,15 +467,16 @@ DiskANNIndexNode::Deserialize(const BinarySet& binset, const Config& cfg) { uint64_t warmup_num = 0; uint64_t warmup_dim = 0; uint64_t warmup_aligned_dim = 0; - T* warmup = nullptr; + DataType* warmup = nullptr; if (TryDiskANNCall([&]() { - diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim); + diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, warmup_dim, + warmup_aligned_dim); }) != Status::success) { LOG_KNOWHERE_ERROR_ << "Failed to load warmup file for DiskANN."; return Status::disk_file_error; } std::vector warmup_result_ids_64(warmup_num, 0); - std::vector warmup_result_dists(warmup_num, 0); + std::vector warmup_result_dists(warmup_num, 0); bool all_searches_are_good = true; @@ -504,9 +508,9 @@ DiskANNIndexNode::Deserialize(const BinarySet& 
binset, const Config& cfg) { return Status::success; } -template +template expected -DiskANNIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { +DiskANNIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { if (!is_prepared_.load() || !pq_flash_index_) { LOG_KNOWHERE_ERROR_ << "Failed to load diskann."; return expected::Err(Status::empty_index, "DiskANN not loaded"); @@ -524,7 +528,7 @@ DiskANNIndexNode::Search(const DataSet& dataset, const Config& cfg, const Bit auto nq = dataset.GetRows(); auto dim = dataset.GetDim(); - auto xq = static_cast(dataset.GetTensor()); + auto xq = static_cast(dataset.GetTensor()); feder::diskann::FederResultUniq feder_result; if (search_conf.trace_visit.value()) { @@ -537,7 +541,7 @@ DiskANNIndexNode::Search(const DataSet& dataset, const Config& cfg, const Bit } auto p_id = new int64_t[k * nq]; - auto p_dist = new float[k * nq]; + auto p_dist = new DistType[k * nq]; bool all_searches_are_good = true; std::vector> futures; @@ -572,9 +576,9 @@ DiskANNIndexNode::Search(const DataSet& dataset, const Config& cfg, const Bit return res; } -template +template expected -DiskANNIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { +DiskANNIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { if (!is_prepared_.load() || !pq_flash_index_) { LOG_KNOWHERE_ERROR_ << "Failed to load diskann."; return expected::Err(Status::empty_index, "index not loaded"); @@ -600,14 +604,14 @@ DiskANNIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, cons auto dim = dataset.GetDim(); auto nq = dataset.GetRows(); - auto xq = static_cast(dataset.GetTensor()); + auto xq = static_cast(dataset.GetTensor()); int64_t* p_id = nullptr; - float* p_dist = nullptr; + DistType* p_dist = nullptr; size_t* p_lims = nullptr; std::vector> result_id_array(nq); - std::vector> result_dist_array(nq); + std::vector> result_dist_array(nq); std::vector> futures; futures.reserve(nq); @@ -644,18 +648,17 @@ DiskANNIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, cons * It first tries to get data from cache, if failed, it will try to get data from disk. * It reads as much as possible and it is thread-pool free, it totally depends on the outside to control concurrency. 
*/ -template +template expected -DiskANNIndexNode::GetVectorByIds(const DataSet& dataset) const { +DiskANNIndexNode::GetVectorByIds(const DataSet& dataset) const { if (!is_prepared_.load() || !pq_flash_index_) { LOG_KNOWHERE_ERROR_ << "Failed to load diskann."; return expected::Err(Status::empty_index, "index not loaded"); } - auto dim = Dim(); auto rows = dataset.GetRows(); auto ids = dataset.GetIds(); - float* data = new float[dim * rows]; + auto* data = new DataType[dim * rows]; if (data == nullptr) { LOG_KNOWHERE_ERROR_ << "Failed to allocate memory for data."; return expected::Err(Status::malloc_error, "failed to allocate memory for data"); @@ -669,9 +672,9 @@ DiskANNIndexNode::GetVectorByIds(const DataSet& dataset) const { return GenResultDataSet(rows, dim, data); } -template +template expected -DiskANNIndexNode::GetIndexMeta(const Config& cfg) const { +DiskANNIndexNode::GetIndexMeta(const Config& cfg) const { std::vector entry_points; for (size_t i = 0; i < pq_flash_index_->get_num_medoids(); i++) { entry_points.push_back(pq_flash_index_->get_medoids()[i]); @@ -689,17 +692,15 @@ DiskANNIndexNode::GetIndexMeta(const Config& cfg) const { return GenResultDataSet(json_meta.dump(), json_id_set.dump()); } -template +template uint64_t -DiskANNIndexNode::GetCachedNodeNum(const float cache_dram_budget, const uint64_t data_dim, - const uint64_t max_degree) { - uint32_t one_cached_node_budget = (max_degree + 1) * sizeof(unsigned) + sizeof(T) * data_dim; +DiskANNIndexNode::GetCachedNodeNum(const float cache_dram_budget, const uint64_t data_dim, + const uint64_t max_degree) { + uint32_t one_cached_node_budget = (max_degree + 1) * sizeof(unsigned) + sizeof(DataType) * data_dim; auto num_nodes_to_cache = static_cast(1024 * 1024 * 1024 * cache_dram_budget) / (one_cached_node_budget * kCacheExpansionRate); return num_nodes_to_cache; } -KNOWHERE_REGISTER_GLOBAL(DISKANN, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(DISKANN, DiskANNIndexNode, fp32); } // namespace knowhere diff --git a/src/index/flat/flat.cc b/src/index/flat/flat.cc index 79f10eca3..4a17d2832 100644 --- a/src/index/flat/flat.cc +++ b/src/index/flat/flat.cc @@ -19,17 +19,19 @@ #include "knowhere/bitsetview_idselector.h" #include "knowhere/comp/thread_pool.h" #include "knowhere/factory.h" +#include "knowhere/index_node_data_mock_wrapper.h" #include "knowhere/log.h" #include "knowhere/utils.h" namespace knowhere { -template +template class FlatIndexNode : public IndexNode { public: FlatIndexNode(const int32_t version, const Object& object) : index_(nullptr) { - static_assert(std::is_same::value || std::is_same::value, - "not support"); + static_assert( + std::is_same::value || std::is_same::value, + "not support"); search_pool_ = ThreadPool::GetGlobalSearchThreadPool(); } @@ -42,10 +44,10 @@ class FlatIndexNode : public IndexNode { LOG_KNOWHERE_ERROR_ << "unsupported metric type: " << f_cfg.metric_type.value(); return metric.error(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { index_ = std::make_unique(dataset.GetDim(), metric.value()); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { bool is_cosine = IsMetricType(f_cfg.metric_type.value(), knowhere::metric::COSINE); index_ = std::make_unique(dataset.GetDim(), metric.value(), is_cosine); } @@ -56,12 +58,7 @@ class FlatIndexNode : public IndexNode { Add(const DataSet& dataset, const Config& cfg) override { auto x = dataset.GetTensor(); 
auto n = dataset.GetRows(); - if constexpr (std::is_same::value) { - index_->add(n, (const float*)x); - } - if constexpr (std::is_same::value) { - index_->add(n, (const uint8_t*)x); - } + index_->add(n, (const DataType*)x); return Status::success; } @@ -98,9 +95,9 @@ class FlatIndexNode : public IndexNode { BitsetViewIDSelector bw_idselector(bitset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - if constexpr (std::is_same::value) { - auto cur_query = (const float*)x + dim * index; - std::unique_ptr copied_query = nullptr; + if constexpr (std::is_same::value) { + auto cur_query = (const DataType*)x + dim * index; + std::unique_ptr copied_query = nullptr; if (is_cosine) { copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); cur_query = copied_query.get(); @@ -111,7 +108,7 @@ class FlatIndexNode : public IndexNode { index_->search(1, cur_query, k, cur_dis, cur_ids, &search_params); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto cur_i_dis = reinterpret_cast(cur_dis); faiss::SearchParameters search_params; @@ -142,7 +139,6 @@ class FlatIndexNode : public IndexNode { LOG_KNOWHERE_WARNING_ << "error inner faiss: " << e.what(); return expected::Err(Status::faiss_inner_error, e.what()); } - return GenResultDataSet(nq, k, ids, distances); } @@ -184,9 +180,9 @@ class FlatIndexNode : public IndexNode { BitsetViewIDSelector bw_idselector(bitset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - if constexpr (std::is_same::value) { - auto cur_query = (const float*)xq + dim * index; - std::unique_ptr copied_query = nullptr; + if constexpr (std::is_same::value) { + auto cur_query = (const DataType*)xq + dim * index; + std::unique_ptr copied_query = nullptr; if (is_cosine) { copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); cur_query = copied_query.get(); @@ -197,7 +193,7 @@ class FlatIndexNode : public IndexNode { index_->range_search(1, cur_query, radius, &res, &search_params); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::SearchParameters search_params; search_params.sel = id_selector; @@ -241,22 +237,22 @@ class FlatIndexNode : public IndexNode { auto dim = Dim(); auto rows = dataset.GetRows(); auto ids = dataset.GetIds(); - if constexpr (std::is_same::value) { - float* data = nullptr; + if constexpr (std::is_same::value) { + DataType* data = nullptr; try { - data = new float[rows * dim]; + data = new DataType[rows * dim]; for (int64_t i = 0; i < rows; i++) { index_->reconstruct(ids[i], data + i * dim); } return GenResultDataSet(rows, dim, data); } catch (const std::exception& e) { - std::unique_ptr auto_del(data); + std::unique_ptr auto_del(data); LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return expected::Err(Status::faiss_inner_error, e.what()); } } - if constexpr (std::is_same::value) { - uint8_t* data = nullptr; + if constexpr (std::is_same::value) { + DataType* data = nullptr; try { data = new uint8_t[rows * dim / 8]; for (int64_t i = 0; i < rows; i++) { @@ -273,14 +269,14 @@ class FlatIndexNode : public IndexNode { bool HasRawData(const std::string& metric_type) const override { - if constexpr (std::is_same::value) { - if (version_ <= Version::GetMinimalVersion()) { + if constexpr (std::is_same::value) { + if (this->version_ <= Version::GetMinimalVersion()) { return !IsMetricType(metric_type, metric::COSINE); } else { return true; } } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return true; } } 
@@ -298,10 +294,10 @@ class FlatIndexNode : public IndexNode { } try { MemoryIOWriter writer; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::write_index(index_.get(), &writer); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::write_index_binary(index_.get(), &writer); } std::shared_ptr data(writer.data()); @@ -325,13 +321,13 @@ class FlatIndexNode : public IndexNode { } MemoryIOReader reader(binary->data.get(), binary->size); - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::Index* index = faiss::read_index(&reader); - index_.reset(static_cast(index)); + index_.reset(static_cast(index)); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::IndexBinary* index = faiss::read_index_binary(&reader); - index_.reset(static_cast(index)); + index_.reset(static_cast(index)); } return Status::success; } @@ -345,13 +341,13 @@ class FlatIndexNode : public IndexNode { io_flags |= faiss::IO_FLAG_MMAP; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::Index* index = faiss::read_index(filename.data(), io_flags); - index_.reset(static_cast(index)); + index_.reset(static_cast(index)); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::IndexBinary* index = faiss::read_index_binary(filename.data(), io_flags); - index_.reset(static_cast(index)); + index_.reset(static_cast(index)); } return Status::success; } @@ -368,7 +364,7 @@ class FlatIndexNode : public IndexNode { int64_t Size() const override { - return index_->ntotal * index_->d * sizeof(float); + return index_->ntotal * index_->d * sizeof(DataType); } int64_t @@ -378,27 +374,22 @@ class FlatIndexNode : public IndexNode { std::string Type() const override { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_IDMAP; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP; } } private: - std::unique_ptr index_; + std::unique_ptr index_; std::shared_ptr search_pool_; }; -KNOWHERE_REGISTER_GLOBAL(FLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(BINFLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(BIN_FLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); - +KNOWHERE_SIMPLE_REGISTER_GLOBAL(FLAT, FlatIndexNode, fp32, faiss::IndexFlat); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(BINFLAT, FlatIndexNode, bin1, faiss::IndexBinaryFlat); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(BIN_FLAT, FlatIndexNode, bin1, faiss::IndexBinaryFlat); +KNOWHERE_MOCK_REGISTER_GLOBAL(FLAT, FlatIndexNode, fp16, faiss::IndexFlat); +KNOWHERE_MOCK_REGISTER_GLOBAL(FLAT, FlatIndexNode, bf16, faiss::IndexFlat); } // namespace knowhere diff --git a/src/index/gpu/flat_gpu/flat_gpu.cc b/src/index/gpu/flat_gpu/flat_gpu.cc index 9436ff938..3ab51ab32 100644 --- a/src/index/gpu/flat_gpu/flat_gpu.cc +++ b/src/index/gpu/flat_gpu/flat_gpu.cc @@ -21,6 +21,7 @@ namespace knowhere { +template class GpuFlatIndexNode : public IndexNode { public: GpuFlatIndexNode(const int32_t& version, const Object& object) : index_(nullptr) { @@ -189,8 +190,11 @@ class GpuFlatIndexNode : public IndexNode { std::unique_ptr index_; }; -KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_FLAT, [](const 
int32_t& version, const Object& object) { - return Index::Create(version, object); -}); +KNOWHERE_REGISTER_GLOBAL( + GPU_FAISS_FLAT, + [](const int32_t& version, const Object& object) { + return (Index>::Create(version, object)); + }, + fp32); } // namespace knowhere diff --git a/src/index/gpu/ivf_gpu/ivf_gpu.cc b/src/index/gpu/ivf_gpu/ivf_gpu.cc index 2569f24a3..bf8d3cb00 100644 --- a/src/index/gpu/ivf_gpu/ivf_gpu.cc +++ b/src/index/gpu/ivf_gpu/ivf_gpu.cc @@ -273,14 +273,23 @@ class GpuIvfIndexNode : public IndexNode { std::unique_ptr index_; }; -KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_IVF_FLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_IVF_PQ, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(GPU_FAISS_IVF_SQ8, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); +KNOWHERE_REGISTER_GLOBAL( + GPU_FAISS_IVF_FLAT, + [](const int32_t& version, const Object& object) { + return (Index>>::Create(version, object)); + }, + fp32); +KNOWHERE_REGISTER_GLOBAL( + GPU_FAISS_IVF_PQ, + [](const int32_t& version, const Object& object) { + return (Index>>::Create(version, object)); + }, + fp32); +KNOWHERE_REGISTER_GLOBAL( + GPU_FAISS_IVF_SQ8, + [](const int32_t& version, const Object& object) { + return Index>>::Create(version, object); + }, + fp32); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft_cagra.cc b/src/index/gpu_raft/gpu_raft_cagra.cc index c6ea08661..2d7a2b9c7 100644 --- a/src/index/gpu_raft/gpu_raft_cagra.cc +++ b/src/index/gpu_raft/gpu_raft_cagra.cc @@ -25,9 +25,12 @@ namespace knowhere { template struct GpuRaftIndexNode; -KNOWHERE_REGISTER_GLOBAL(GPU_RAFT_CAGRA, [](const int32_t& version, const Object& object) { - return Index::Create(std::make_unique(version, object), - cuda_concurrent_size); -}); +KNOWHERE_REGISTER_GLOBAL( + GPU_RAFT_CAGRA, + [](const int32_t& version, const Object& object) { + return Index::Create(std::make_unique(version, object), + cuda_concurrent_size); + }, + fp32); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft_ivf_flat.cc b/src/index/gpu_raft/gpu_raft_ivf_flat.cc index ec39e1309..c65c06fd4 100644 --- a/src/index/gpu_raft/gpu_raft_ivf_flat.cc +++ b/src/index/gpu_raft/gpu_raft_ivf_flat.cc @@ -25,9 +25,12 @@ namespace knowhere { template struct GpuRaftIndexNode; -KNOWHERE_REGISTER_GLOBAL(GPU_RAFT_IVF_FLAT, [](const int32_t& version, const Object& object) { - return Index::Create(std::make_unique(version, object), - cuda_concurrent_size); -}); +KNOWHERE_REGISTER_GLOBAL( + GPU_RAFT_IVF_FLAT, + [](const int32_t& version, const Object& object) { + return Index::Create(std::make_unique(version, object), + cuda_concurrent_size); + }, + fp32); } // namespace knowhere diff --git a/src/index/gpu_raft/gpu_raft_ivf_pq.cc b/src/index/gpu_raft/gpu_raft_ivf_pq.cc index d59e647ea..c462bce75 100644 --- a/src/index/gpu_raft/gpu_raft_ivf_pq.cc +++ b/src/index/gpu_raft/gpu_raft_ivf_pq.cc @@ -25,9 +25,12 @@ namespace knowhere { template struct GpuRaftIndexNode; -KNOWHERE_REGISTER_GLOBAL(GPU_RAFT_IVF_PQ, [](const int32_t& version, const Object& object) { - return Index::Create(std::make_unique(version, object), - cuda_concurrent_size); -}); +KNOWHERE_REGISTER_GLOBAL( + GPU_RAFT_IVF_PQ, + [](const int32_t& version, const Object& object) { + return Index::Create(std::make_unique(version, object), + cuda_concurrent_size); + }, + fp32); } // namespace 
knowhere diff --git a/src/index/hnsw/hnsw.cc b/src/index/hnsw/hnsw.cc index d82dcd581..f13a835fe 100644 --- a/src/index/hnsw/hnsw.cc +++ b/src/index/hnsw/hnsw.cc @@ -23,12 +23,18 @@ #include "knowhere/config.h" #include "knowhere/expected.h" #include "knowhere/factory.h" +#include "knowhere/index_node_data_mock_wrapper.h" #include "knowhere/log.h" #include "knowhere/utils.h" namespace knowhere { +template class HnswIndexNode : public IndexNode { + static_assert(std::is_same_v || std::is_same_v, + "HnswIndexNode only supports float/binary"); + public: + using DistType = typename ResultType::type; HnswIndexNode(const int32_t& /*version*/, const Object& object) : index_(nullptr) { search_pool_ = ThreadPool::GetGlobalSearchThreadPool(); } @@ -38,7 +44,7 @@ class HnswIndexNode : public IndexNode { auto rows = dataset.GetRows(); auto dim = dataset.GetDim(); auto hnsw_cfg = static_cast(cfg); - hnswlib::SpaceInterface* space = nullptr; + hnswlib::SpaceInterface* space = nullptr; if (IsMetricType(hnsw_cfg.metric_type.value(), metric::L2)) { space = new (std::nothrow) hnswlib::L2Space(dim); } else if (IsMetricType(hnsw_cfg.metric_type.value(), metric::IP)) { @@ -54,7 +60,7 @@ class HnswIndexNode : public IndexNode { return Status::invalid_metric_type; } auto index = new (std::nothrow) - hnswlib::HierarchicalNSW(space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value()); + hnswlib::HierarchicalNSW(space, rows, hnsw_cfg.M.value(), hnsw_cfg.efConstruction.value()); if (index == nullptr) { LOG_KNOWHERE_WARNING_ << "memory malloc error."; return Status::malloc_error; } @@ -107,7 +113,7 @@ class HnswIndexNode : public IndexNode { expected Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override { if (!index_) { - LOG_KNOWHERE_WARNING_ << "search on empty index"; + LOG_KNOWHERE_WARNING_ << "search on empty index"; return expected::Err(Status::empty_index, "index not loaded"); } auto nq = dataset.GetRows(); @@ -125,7 +131,7 @@ } auto p_id = new int64_t[k * nq]; - auto p_dist = new float[k * nq]; + auto p_dist = new DistType[k * nq]; hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value(), hnsw_cfg.for_tuning.value()}; bool transform = @@ -146,7 +152,7 @@ p_single_id[idx] = id; } for (size_t idx = rst_size; idx < (size_t)k; idx++) { - p_single_dis[idx] = float(1.0 / 0.0); + p_single_dis[idx] = DistType(1.0 / 0.0); p_single_id[idx] = -1; } })); @@ -171,7 +177,7 @@ private: class iterator : public IndexNode::iterator { public: - iterator(const hnswlib::HierarchicalNSW* index, const char* query, const bool transform, + iterator(const hnswlib::HierarchicalNSW* index, const char* query, const bool transform, const BitsetView& bitset, const bool for_tuning = false, const size_t seed_ef = kIteratorSeedEf) : index_(index), transform_(transform), @@ -179,7 +185,7 @@ UpdateNext(); } - std::pair + std::pair Next() override { auto ret = std::make_pair(next_id_, next_dist_); UpdateNext(); @@ -204,16 +210,16 @@ has_next_ = false; } } - const hnswlib::HierarchicalNSW* index_; + const hnswlib::HierarchicalNSW* index_; const bool transform_; std::unique_ptr workspace_; bool has_next_; - float next_dist_; + DistType next_dist_; int64_t next_id_; }; public: - expected>> + expected>> AnnIterator(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const override { if
(!index_) { LOG_KNOWHERE_WARNING_ << "creating iterator on empty index"; @@ -258,10 +264,10 @@ class HnswIndexNode : public IndexNode { auto hnsw_cfg = static_cast(cfg); bool is_ip = (index_->metric_type_ == hnswlib::Metric::INNER_PRODUCT || index_->metric_type_ == hnswlib::Metric::COSINE); - float range_filter = hnsw_cfg.range_filter.value(); + DistType range_filter = hnsw_cfg.range_filter.value(); - float radius_for_calc = (is_ip ? -hnsw_cfg.radius.value() : hnsw_cfg.radius.value()); - float radius_for_filter = hnsw_cfg.radius.value(); + DistType radius_for_calc = (is_ip ? -hnsw_cfg.radius.value() : hnsw_cfg.radius.value()); + DistType radius_for_filter = hnsw_cfg.radius.value(); feder::hnsw::FederResultUniq feder_result; if (hnsw_cfg.trace_visit.value()) { @@ -274,11 +280,11 @@ class HnswIndexNode : public IndexNode { hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value()}; int64_t* ids = nullptr; - float* dis = nullptr; + DistType* dis = nullptr; size_t* lims = nullptr; std::vector> result_id_array(nq); - std::vector> result_dist_array(nq); + std::vector> result_dist_array(nq); std::vector result_size(nq); std::vector result_lims(nq + 1); @@ -416,8 +422,8 @@ class HnswIndexNode : public IndexNode { MemoryIOReader reader(binary->data.get(), binary->size); - hnswlib::SpaceInterface* space = nullptr; - index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); + hnswlib::SpaceInterface* space = nullptr; + index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); index_->loadIndex(reader); LOG_KNOWHERE_INFO_ << "Loaded HNSW index. #points num:" << index_->max_elements_ << " #M:" << index_->M_ << " #max level:" << index_->maxlevel_ @@ -436,8 +442,8 @@ class HnswIndexNode : public IndexNode { delete index_; } try { - hnswlib::SpaceInterface* space = nullptr; - index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); + hnswlib::SpaceInterface* space = nullptr; + index_ = new (std::nothrow) hnswlib::HierarchicalNSW(space); index_->loadIndex(filename, config); } catch (std::exception& e) { LOG_KNOWHERE_WARNING_ << "hnsw inner error: " << e.what(); @@ -524,12 +530,12 @@ class HnswIndexNode : public IndexNode { } private: - hnswlib::HierarchicalNSW* index_; + hnswlib::HierarchicalNSW* index_; std::shared_ptr search_pool_; }; -KNOWHERE_REGISTER_GLOBAL(HNSW, [](const int32_t& version, const Object& object) { - return Index::Create(version, object); -}); - +KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp32); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(HNSW, HnswIndexNode, bin1); +KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, fp16); +KNOWHERE_MOCK_REGISTER_GLOBAL(HNSW, HnswIndexNode, bf16); } // namespace knowhere diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 53c8276de..eff554d69 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -29,19 +29,34 @@ #include "knowhere/expected.h" #include "knowhere/factory.h" #include "knowhere/feder/IVFFlat.h" +#include "knowhere/index_node_data_mock_wrapper.h" #include "knowhere/log.h" #include "knowhere/utils.h" namespace knowhere { +struct IVFBaseTag {}; +struct IVFFlatTag {}; -template +template +struct IndexDispatch { + using Tag = IVFBaseTag; +}; + +template <> +struct IndexDispatch { + using Tag = IVFFlatTag; +}; + +template class IvfIndexNode : public IndexNode { public: IvfIndexNode(const int32_t version, const Object& object) : IndexNode(version), index_(nullptr) { - static_assert(std::is_same::value || std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || 
std::is_same::value, + static_assert(std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value, "not support"); search_pool_ = ThreadPool::GetGlobalSearchThreadPool(); } @@ -60,53 +75,55 @@ class IvfIndexNode : public IndexNode { if (!index_) { return false; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return true; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return true; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return false; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return index_->with_raw_data(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return false; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return true; } } expected GetIndexMeta(const Config& cfg) const override { - return expected::Err(Status::not_implemented, "GetIndexMeta not implemented"); + return this->GetIndexMetaImpl(cfg, typename IndexDispatch::Tag{}); } Status - Serialize(BinarySet& binset) const override; + Serialize(BinarySet& binset) const override { + return this->SerializeImpl(binset, typename IndexDispatch::Tag{}); + } Status Deserialize(const BinarySet& binset, const Config& config) override; Status DeserializeFromFile(const std::string& filename, const Config& config) override; std::unique_ptr CreateConfig() const override { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return std::make_unique(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return std::make_unique(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return std::make_unique(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return std::make_unique(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return std::make_unique(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return std::make_unique(); } }; @@ -122,19 +139,19 @@ class IvfIndexNode : public IndexNode { if (!index_) { return 0; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto nb = index_->invlists->compute_ntotal(); auto nlist = index_->nlist; auto code_size = index_->code_size; return ((nb + nlist) * (code_size + sizeof(int64_t))); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto nb = index_->invlists->compute_ntotal(); auto nlist = index_->nlist; auto code_size = index_->code_size; return (nb * code_size + nb * sizeof(int64_t) + nlist * code_size); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto nb = index_->invlists->compute_ntotal(); auto code_size = index_->code_size; auto pq = index_->pq; @@ -146,16 +163,16 @@ class IvfIndexNode : public IndexNode { auto precomputed_table = nlist * pq.M * pq.ksub * sizeof(float); return (capacity + centroid_table + precomputed_table); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return index_->size(); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto nb = index_->invlists->compute_ntotal(); auto code_size = index_->code_size; auto nlist = index_->nlist; return (nb * code_size + nb * sizeof(int64_t) + 2 * code_size + nlist * code_size); } - if constexpr (std::is_same::value) { + if 
constexpr (std::is_same::value) { auto nb = index_->invlists->compute_ntotal(); auto nlist = index_->nlist; auto code_size = index_->code_size; @@ -171,28 +188,42 @@ class IvfIndexNode : public IndexNode { }; std::string Type() const override { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_IVFPQ; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_SCANN; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_IVFSQ8; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT; } }; private: - std::unique_ptr index_; + expected + GetIndexMetaImpl(const Config& cfg, IVFBaseTag) const { + return expected::Err(Status::not_implemented, "GetIndexMeta not implemented"); + } + expected + GetIndexMetaImpl(const Config& cfg, IVFFlatTag) const; + + Status + SerializeImpl(BinarySet& binset, IVFBaseTag) const; + + Status + SerializeImpl(BinarySet& binset, IVFFlatTag) const; + + private: + std::unique_ptr index_; std::shared_ptr search_pool_; }; @@ -243,9 +274,9 @@ to_index_flat(std::unique_ptr&& index) { } // namespace -template +template Status -IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { +IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { const BaseConfig& base_cfg = static_cast(cfg); std::unique_ptr setter; if (base_cfg.num_build_thread.has_value()) { @@ -257,7 +288,8 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { bool is_cosine = IsMetricType(base_cfg.metric_type.value(), knowhere::metric::COSINE); // do normalize for COSINE metric type - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v) { if (is_cosine) { Normalize(dataset); } @@ -275,18 +307,18 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { // faiss scann needs at least 16 rows since nbits=4 constexpr int64_t SCANN_MIN_ROWS = 16; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { if (rows < SCANN_MIN_ROWS) { LOG_KNOWHERE_ERROR_ << rows << " rows is not enough, scann needs at least 16 rows to build index"; return Status::faiss_inner_error; } } - std::unique_ptr index; + std::unique_ptr index; // if cfg.use_elkan is used, then we'll use a temporary instance of // IndexFlatElkan for the training. 
try { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const IvfFlatConfig& ivf_flat_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_flat_cfg.nlist.value()); @@ -306,7 +338,7 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { qzr.release(); index->own_fields = true; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const IvfFlatCcConfig& ivf_flat_cc_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_flat_cc_cfg.nlist.value()); @@ -329,7 +361,7 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { // ivfflat_cc has no serialize stage, make map at build stage index->make_direct_map(true, faiss::DirectMap::ConcurrentArray); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const IvfPqConfig& ivf_pq_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_pq_cfg.nlist.value()); auto nbits = MatchNbits(rows, ivf_pq_cfg.nbits.value()); @@ -351,7 +383,7 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { qzr.release(); index->own_fields = true; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const ScannConfig& scann_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, scann_cfg.nlist.value()); bool is_cosine = base_cfg.metric_type.value() == metric::COSINE; @@ -384,7 +416,7 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { base_index.release(); index->own_fields = true; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const IvfSqConfig& ivf_sq_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_sq_cfg.nlist.value()); @@ -405,7 +437,7 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { qzr.release(); index->own_fields = true; } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { const IvfBinConfig& ivf_bin_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_bin_cfg.nlist.value()); @@ -428,9 +460,9 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { return Status::success; } -template +template Status -IvfIndexNode::Add(const DataSet& dataset, const Config& cfg) { +IvfIndexNode::Add(const DataSet& dataset, const Config& cfg) { if (!this->index_) { LOG_KNOWHERE_ERROR_ << "Can not add data to empty IVF index."; return Status::empty_index; @@ -445,7 +477,7 @@ IvfIndexNode::Add(const DataSet& dataset, const Config& cfg) { setter = std::make_unique(); } try { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { index_->add(rows, (const uint8_t*)data); } else { index_->add(rows, (const float*)data); @@ -457,9 +489,9 @@ IvfIndexNode::Add(const DataSet& dataset, const Config& cfg) { return Status::success; } -template +template expected -IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { +IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { if (!this->index_) { LOG_KNOWHERE_WARNING_ << "search on empty index"; return expected::Err(Status::empty_index, "index not loaded"); @@ -494,7 +526,7 @@ IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetV BitsetViewIDSelector bw_idselector(bitset); faiss::IDSelector* id_selector = (bitset.empty()) ? 
nullptr : &bw_idselector; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto cur_data = (const uint8_t*)data + index * dim / 8; faiss::IVFSearchParameters ivf_search_params; @@ -507,7 +539,7 @@ IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetV distances[i + offset] = static_cast(i_distances[i + offset]); } } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto cur_query = (const float*)data + index * dim; if (is_cosine) { copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); @@ -520,7 +552,7 @@ IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetV ivf_search_params.sel = id_selector; index_->search(1, cur_query, k, distances + offset, ids + offset, &ivf_search_params); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto cur_query = (const float*)data + index * dim; const ScannConfig& scann_cfg = static_cast(cfg); if (is_cosine) { @@ -574,9 +606,10 @@ IvfIndexNode::Search(const DataSet& dataset, const Config& cfg, const BitsetV return res; } -template +template expected -IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const BitsetView& bitset) const { +IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, + const BitsetView& bitset) const { if (!this->index_) { LOG_KNOWHERE_WARNING_ << "range search on empty index"; return expected::Err(Status::empty_index, "index not loaded"); @@ -618,7 +651,7 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi BitsetViewIDSelector bw_idselector(bitset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto cur_data = (const uint8_t*)xq + index * dim / 8; faiss::IVFSearchParameters ivf_search_params; @@ -626,7 +659,7 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi ivf_search_params.sel = id_selector; index_->range_search(1, cur_data, radius, &res, &ivf_search_params); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto cur_query = (const float*)xq + index * dim; if (is_cosine) { copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); @@ -639,7 +672,7 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi ivf_search_params.sel = id_selector; index_->range_search(1, cur_query, radius, &res, &ivf_search_params); - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { auto cur_query = (const float*)xq + index * dim; if (is_cosine) { copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); @@ -700,16 +733,16 @@ IvfIndexNode::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi return GenResultDataSet(nq, ids, distances, lims); } -template +template expected -IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { +IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { if (!this->index_) { return expected::Err(Status::empty_index, "index not loaded"); } if (!this->index_->is_trained) { return expected::Err(Status::index_not_trained, "index not trained"); } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { auto dim = Dim(); auto rows = dataset.GetRows(); auto ids = dataset.GetIds(); @@ -728,7 +761,8 @@ IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return 
expected::Err(Status::faiss_inner_error, e.what()); } - } else if constexpr (std::is_same::value || std::is_same::value) { + } else if constexpr (std::is_same::value || + std::is_same::value) { auto dim = Dim(); auto rows = dataset.GetRows(); auto ids = dataset.GetIds(); @@ -747,7 +781,7 @@ IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return expected::Err(Status::faiss_inner_error, e.what()); } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value) { // we should never go here since we should call HasRawData() first if (!index_->with_raw_data()) { return expected::Err(Status::not_implemented, "GetVectorByIds not implemented"); @@ -775,9 +809,9 @@ IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { } } -template <> +template expected -IvfIndexNode::GetIndexMeta(const Config& config) const { +IvfIndexNode::GetIndexMetaImpl(const Config& config, IVFFlatTag) const { if (!index_) { LOG_KNOWHERE_WARNING_ << "get index meta on empty index"; return expected::Err(Status::empty_index, "index not loaded"); @@ -814,12 +848,12 @@ IvfIndexNode::GetIndexMeta(const Config& config) const { return GenResultDataSet(json_meta.dump(), json_id_set.dump()); } -template +template Status -IvfIndexNode::Serialize(BinarySet& binset) const { +IvfIndexNode::SerializeImpl(BinarySet& binset, IVFBaseTag) const { try { MemoryIOWriter writer; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { faiss::write_index_binary(index_.get(), &writer); } else { faiss::write_index(index_.get(), &writer); @@ -833,13 +867,13 @@ IvfIndexNode::Serialize(BinarySet& binset) const { } } -template <> +template Status -IvfIndexNode::Serialize(BinarySet& binset) const { +IvfIndexNode::SerializeImpl(BinarySet& binset, IVFFlatTag) const { try { MemoryIOWriter writer; - LOG_KNOWHERE_INFO_ << "request version " << version_.VersionNumber(); - if (version_ <= Version::GetMinimalVersion()) { + LOG_KNOWHERE_INFO_ << "request version " << this->version_.VersionNumber(); + if (this->version_ <= Version::GetMinimalVersion()) { faiss::write_index_nm(index_.get(), &writer); LOG_KNOWHERE_INFO_ << "write IVF_FLAT_NM, file size " << writer.tellg(); } else { @@ -850,7 +884,7 @@ IvfIndexNode::Serialize(BinarySet& binset) const { binset.Append(Type(), index_data_ptr, writer.tellg()); // append raw data for backward compatible - if (version_ <= Version::GetMinimalVersion()) { + if (this->version_ <= Version::GetMinimalVersion()) { size_t dim = index_->d; size_t rows = index_->ntotal; size_t raw_data_size = dim * rows * sizeof(float); @@ -877,9 +911,9 @@ IvfIndexNode::Serialize(BinarySet& binset) const { } } -template +template Status -IvfIndexNode::Deserialize(const BinarySet& binset, const Config& config) { +IvfIndexNode::Deserialize(const BinarySet& binset, const Config& config) { std::vector names = {"IVF", // compatible with knowhere-1.x "BinaryIVF", // compatible with knowhere-1.x Type()}; @@ -891,8 +925,8 @@ IvfIndexNode::Deserialize(const BinarySet& binset, const Config& config) { MemoryIOReader reader(binary->data.get(), binary->size); try { - if constexpr (std::is_same::value) { - if (version_ <= Version::GetMinimalVersion()) { + if constexpr (std::is_same::value) { + if (this->version_ <= Version::GetMinimalVersion()) { auto raw_binary = binset.GetByName("RAW_DATA"); const BaseConfig& base_cfg = static_cast(config); ConvertIVFFlat(binset, base_cfg.metric_type.value(), raw_binary->data.get(), raw_binary->size); 
@@ -901,12 +935,12 @@ IvfIndexNode::Deserialize(const BinarySet& binset, const Config& config) { reader.total_ = binary->size; } index_.reset(static_cast(faiss::read_index(&reader))); - } else if constexpr (std::is_same::value) { - index_.reset(static_cast(faiss::read_index_binary(&reader))); + } else if constexpr (std::is_same::value) { + index_.reset(static_cast(faiss::read_index_binary(&reader))); } else { - index_.reset(static_cast(faiss::read_index(&reader))); + index_.reset(static_cast(faiss::read_index(&reader))); } - if constexpr (!std::is_same_v) { + if constexpr (!std::is_same_v) { const BaseConfig& base_cfg = static_cast(config); if (HasRawData(base_cfg.metric_type.value())) { index_->make_direct_map(true); @@ -919,9 +953,9 @@ IvfIndexNode::Deserialize(const BinarySet& binset, const Config& config) { return Status::success; } -template +template Status -IvfIndexNode::DeserializeFromFile(const std::string& filename, const Config& config) { +IvfIndexNode::DeserializeFromFile(const std::string& filename, const Config& config) { auto cfg = static_cast(config); int io_flags = 0; @@ -929,12 +963,12 @@ IvfIndexNode::DeserializeFromFile(const std::string& filename, const Config& io_flags |= faiss::IO_FLAG_MMAP; } try { - if constexpr (std::is_same::value) { - index_.reset(static_cast(faiss::read_index_binary(filename.data(), io_flags))); + if constexpr (std::is_same::value) { + index_.reset(static_cast(faiss::read_index_binary(filename.data(), io_flags))); } else { - index_.reset(static_cast(faiss::read_index(filename.data(), io_flags))); + index_.reset(static_cast(faiss::read_index(filename.data(), io_flags))); } - if constexpr (!std::is_same_v) { + if constexpr (!std::is_same_v) { const BaseConfig& base_cfg = static_cast(config); if (HasRawData(base_cfg.metric_type.value())) { index_->make_direct_map(true); @@ -946,42 +980,37 @@ IvfIndexNode::DeserializeFromFile(const std::string& filename, const Config& } return Status::success; } - -KNOWHERE_REGISTER_GLOBAL(IVFBIN, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); - -KNOWHERE_REGISTER_GLOBAL(BIN_IVF_FLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); - -KNOWHERE_REGISTER_GLOBAL(IVFFLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(IVF_FLAT, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(IVFFLATCC, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(IVF_FLAT_CC, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(SCANN, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(IVFPQ, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(IVF_PQ, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); - -KNOWHERE_REGISTER_GLOBAL(IVFSQ, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); -KNOWHERE_REGISTER_GLOBAL(IVF_SQ8, [](const int32_t& version, const Object& object) { - return Index>::Create(version, object); -}); - +// bin1 +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFBIN, IvfIndexNode, bin1, faiss::IndexBinaryIVF); 
+KNOWHERE_SIMPLE_REGISTER_GLOBAL(BIN_IVF_FLAT, IvfIndexNode, bin1, faiss::IndexBinaryIVF); +// fp32 +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFFLAT, IvfIndexNode, fp32, faiss::IndexIVFFlat); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_FLAT, IvfIndexNode, fp32, faiss::IndexIVFFlat); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFFLATCC, IvfIndexNode, fp32, faiss::IndexIVFFlatCC); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_FLAT_CC, IvfIndexNode, fp32, faiss::IndexIVFFlatCC); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(SCANN, IvfIndexNode, fp32, faiss::IndexScaNN); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFPQ, IvfIndexNode, fp32, faiss::IndexIVFPQ); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_PQ, IvfIndexNode, fp32, faiss::IndexIVFPQ); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVFSQ, IvfIndexNode, fp32, faiss::IndexIVFScalarQuantizer); +KNOWHERE_SIMPLE_REGISTER_GLOBAL(IVF_SQ8, IvfIndexNode, fp32, faiss::IndexIVFScalarQuantizer); +// fp16 +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLAT, IvfIndexNode, fp16, faiss::IndexIVFFlat); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT, IvfIndexNode, fp16, faiss::IndexIVFFlat); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLATCC, IvfIndexNode, fp16, faiss::IndexIVFFlatCC); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT_CC, IvfIndexNode, fp16, faiss::IndexIVFFlatCC); +KNOWHERE_MOCK_REGISTER_GLOBAL(SCANN, IvfIndexNode, fp16, faiss::IndexScaNN); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFPQ, IvfIndexNode, fp16, faiss::IndexIVFPQ); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_PQ, IvfIndexNode, fp16, faiss::IndexIVFPQ); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFSQ, IvfIndexNode, fp16, faiss::IndexIVFScalarQuantizer); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_SQ8, IvfIndexNode, fp16, faiss::IndexIVFScalarQuantizer); +// bf16 +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLAT, IvfIndexNode, bf16, faiss::IndexIVFFlat); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT, IvfIndexNode, bf16, faiss::IndexIVFFlat); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFFLATCC, IvfIndexNode, bf16, faiss::IndexIVFFlatCC); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_FLAT_CC, IvfIndexNode, bf16, faiss::IndexIVFFlatCC); +KNOWHERE_MOCK_REGISTER_GLOBAL(SCANN, IvfIndexNode, bf16, faiss::IndexScaNN); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFPQ, IvfIndexNode, bf16, faiss::IndexIVFPQ); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_PQ, IvfIndexNode, bf16, faiss::IndexIVFPQ); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVFSQ, IvfIndexNode, bf16, faiss::IndexIVFScalarQuantizer); +KNOWHERE_MOCK_REGISTER_GLOBAL(IVF_SQ8, IvfIndexNode, bf16, faiss::IndexIVFScalarQuantizer); } // namespace knowhere diff --git a/tests/python/conftest.py b/tests/python/conftest.py index 4411997fb..1e1d0d637 100644 --- a/tests/python/conftest.py +++ b/tests/python/conftest.py @@ -1,6 +1,7 @@ import pytest import numpy as np import faiss +from bfloat16 import bfloat16 @pytest.fixture() @@ -13,6 +14,26 @@ def wrap(xb_rows, xq_rows, dim): return wrap +@pytest.fixture() +def gen_data_with_type(): + def wrap(xb_rows, xq_rows, dim, type): + if type == np.float16 or type == bfloat16: + xb = np.random.randn(xb_rows, dim).astype(type) + xq = np.random.randn(xq_rows, dim).astype(type) + # To avoid nan or inf when dtype is float16 or bfloat16 + min_value = -10.0 + max_value = 10.0 + xb = np.clip(xb, min_value, max_value).astype(type) + xq = np.clip(xq, min_value, max_value).astype(type) + return xb, xq + else: + return ( + np.random.randn(xb_rows, dim).astype(type), + np.random.randn(xq_rows, dim).astype(type), + ) + + return wrap + @pytest.fixture() def faiss_ans(): diff --git a/tests/python/test_index_load_and_save.py b/tests/python/test_index_load_and_save.py index bc6e18568..3ebc5c0f8 100644 --- a/tests/python/test_index_load_and_save.py 
+++ b/tests/python/test_index_load_and_save.py @@ -2,6 +2,8 @@ import json import pytest import os +import numpy as np +from bfloat16 import bfloat16 test_data = [ ( @@ -21,7 +23,7 @@ def test_save_and_load(gen_data, faiss_ans, recall, error, name, config): # simple load and save not work for ivf nm print(name, config) version = knowhere.GetCurrentVersion() - build_idx = knowhere.CreateIndex(name, version) + build_idx = knowhere.CreateIndex(name, version, np.float32) xb, xq = gen_data(10_000, 100, 256) # build, serialize and dump @@ -36,7 +38,77 @@ def test_save_and_load(gen_data, faiss_ans, recall, error, name, config): # load and deserialize new_binset = knowhere.GetBinarySet() knowhere.Load(new_binset, index_file) - search_idx = knowhere.CreateIndex(name, version) + search_idx = knowhere.CreateIndex(name, version, np.float32) + search_idx.Deserialize(new_binset) + + # test the loaded index + ans, _ = search_idx.Search( + knowhere.ArrayToDataSet(xq), json.dumps(config), knowhere.GetNullBitSetView() + ) + k_dis, k_ids = knowhere.DataSetToArray(ans) + f_dis, f_ids = faiss_ans(xb, xq, config["metric_type"], config["k"]) + assert recall(f_ids, k_ids) >= 0.99 + assert error(f_dis, f_dis) <= 0.01 + + # delete the index_file + os.remove(index_file) + +@pytest.mark.parametrize("name,config", test_data) +def test_float16_save_and_load(gen_data_with_type, faiss_ans, recall, error, name, config): + # simple load and save not work for ivf nm + print(name, config) + version = knowhere.GetCurrentVersion() + build_idx = knowhere.CreateIndex(name, version, np.float16) + xb, xq = gen_data_with_type(10_000, 100, 256, np.float16) + + # build, serialize and dump + build_idx.Build( + knowhere.ArrayToDataSet(xb), + json.dumps(config), + ) + binset = knowhere.GetBinarySet() + build_idx.Serialize(binset) + knowhere.Dump(binset, index_file) + + # load and deserialize + new_binset = knowhere.GetBinarySet() + knowhere.Load(new_binset, index_file) + search_idx = knowhere.CreateIndex(name, version, np.float16) + search_idx.Deserialize(new_binset) + + # test the loaded index + ans, _ = search_idx.Search( + knowhere.ArrayToDataSet(xq), json.dumps(config), knowhere.GetNullBitSetView() + ) + k_dis, k_ids = knowhere.DataSetToArray(ans) + f_dis, f_ids = faiss_ans(xb, xq, config["metric_type"], config["k"]) + assert recall(f_ids, k_ids) >= 0.99 + assert error(f_dis, f_dis) <= 0.01 + + # delete the index_file + os.remove(index_file) + +@pytest.mark.parametrize("name,config", test_data) +def test_bfloat16_save_and_load(gen_data_with_type, faiss_ans, recall, error, name, config): + # simple load and save not work for ivf nm + print(name, config) + version = knowhere.GetCurrentVersion() + build_idx = knowhere.CreateIndex(name, version, bfloat16) + xb, xq = gen_data_with_type(10_000, 100, 256, bfloat16) + + # build, serialize and dump + build_idx.Build( + knowhere.ArrayToDataSet(xb), + json.dumps(config), + ) + binset = knowhere.GetBinarySet() + build_idx.Serialize(binset) + knowhere.Dump(binset, index_file) + + # load and deserialize + new_binset = knowhere.GetBinarySet() + knowhere.Load(new_binset, index_file) + search_idx = knowhere.CreateIndex(name, version, bfloat16) search_idx.Deserialize(new_binset) # test the loaded index diff --git a/tests/python/test_index_random.py b/tests/python/test_index_random.py index e17c1904e..d0f18b901 100644 --- a/tests/python/test_index_random.py +++ b/tests/python/test_index_random.py @@ -1,6 +1,8 @@ import knowhere import json import pytest +import numpy as np +from bfloat16 import 
bfloat16 test_data = [ ( @@ -99,3 +101,89 @@ def test_index(gen_data, faiss_ans, recall, error, name, config): else: assert recall(f_ids, k_ids) >= 0.5 assert error(f_dis, f_dis) <= 0.01 + +@pytest.mark.parametrize("name,config", test_data) +def test_float16_index(gen_data_with_type, faiss_ans, recall, error, name, config): + print(name, config) + version = knowhere.GetCurrentVersion() + idx = knowhere.CreateIndex(name, version, np.float16) + xb, xq = gen_data_with_type(10000, 100, 256, np.float16) + + idx.Build( + knowhere.ArrayToDataSet(xb), + json.dumps(config), + ) + + ans, _ = idx.Search( + knowhere.ArrayToDataSet(xq), + json.dumps(config), + knowhere.GetNullBitSetView() + ) + k_dis, k_ids = knowhere.DataSetToArray(ans) + f_dis, f_ids = faiss_ans(xb, xq, config["metric_type"], config["k"]) + if (name != "IVFSQ"): + assert recall(f_ids, k_ids) >= 0.99 + else: + assert recall(f_ids, k_ids) >= 0.70 + assert error(f_dis, f_dis) <= 0.01 + + bitset = knowhere.CreateBitSet(xb.shape[0]) + for id in k_ids[:10,:1].ravel(): + if id < 0: + continue + bitset.SetBit(int(id)) + ans, _ = idx.Search( + knowhere.ArrayToDataSet(xq), + json.dumps(config), + bitset.GetBitSetView() + ) + + k_dis, k_ids = knowhere.DataSetToArray(ans) + if (name != "IVFSQ"): + assert recall(f_ids, k_ids) >= 0.7 + else: + assert recall(f_ids, k_ids) >= 0.5 + assert error(f_dis, f_dis) <= 0.01 + +@pytest.mark.parametrize("name,config", test_data) +def test_bfloat16_index(gen_data_with_type, faiss_ans, recall, error, name, config): + print(name, config) + version = knowhere.GetCurrentVersion() + idx = knowhere.CreateIndex(name, version, bfloat16) + xb, xq = gen_data_with_type(10000, 100, 256, bfloat16) + + idx.Build( + knowhere.ArrayToDataSet(xb), + json.dumps(config), + ) + + ans, _ = idx.Search( + knowhere.ArrayToDataSet(xq), + json.dumps(config), + knowhere.GetNullBitSetView() + ) + k_dis, k_ids = knowhere.DataSetToArray(ans) + f_dis, f_ids = faiss_ans(xb, xq, config["metric_type"], config["k"]) + if (name != "IVFSQ"): + assert recall(f_ids, k_ids) >= 0.99 + else: + assert recall(f_ids, k_ids) >= 0.70 + assert error(f_dis, f_dis) <= 0.01 + + bitset = knowhere.CreateBitSet(xb.shape[0]) + for id in k_ids[:10,:1].ravel(): + if id < 0: + continue + bitset.SetBit(int(id)) + ans, _ = idx.Search( + knowhere.ArrayToDataSet(xq), + json.dumps(config), + bitset.GetBitSetView() + ) + + k_dis, k_ids = knowhere.DataSetToArray(ans) + if (name != "IVFSQ"): + assert recall(f_ids, k_ids) >= 0.7 + else: + assert recall(f_ids, k_ids) >= 0.5 + assert error(f_dis, f_dis) <= 0.01 diff --git a/tests/ut/test_bruteforce.cc b/tests/ut/test_bruteforce.cc index fd5c16f78..4cdd80143 100644 --- a/tests/ut/test_bruteforce.cc +++ b/tests/ut/test_bruteforce.cc @@ -38,7 +38,7 @@ TEST_CASE("Test Brute Force", "[float vector]") { }; SECTION("Test Search") { - auto res = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto res = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); REQUIRE(res.has_value()); auto ids = res.value()->GetIds(); auto dist = res.value()->GetDistance(); @@ -55,7 +55,7 @@ TEST_CASE("Test Brute Force", "[float vector]") { SECTION("Test Search With Buf") { auto ids = new int64_t[nq * k]; auto dist = new float[nq * k]; - auto res = knowhere::BruteForce::SearchWithBuf(train_ds, query_ds, ids, dist, conf, nullptr); + auto res = knowhere::BruteForce::SearchWithBuf(train_ds, query_ds, ids, dist, conf, nullptr); REQUIRE(res == knowhere::Status::success); for (int64_t i = 0; i < nq; i++) { REQUIRE(ids[i * k] == i); 
@@ -70,7 +70,7 @@ TEST_CASE("Test Brute Force", "[float vector]") { } SECTION("Test Range Search") { - auto res = knowhere::BruteForce::RangeSearch(train_ds, query_ds, conf, nullptr); + auto res = knowhere::BruteForce::RangeSearch(train_ds, query_ds, conf, nullptr); REQUIRE(res.has_value()); auto ids = res.value()->GetIds(); auto dist = res.value()->GetDistance(); @@ -112,7 +112,7 @@ TEST_CASE("Test Brute Force", "[binary vector]") { }; SECTION("Test Search") { - auto res = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto res = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); REQUIRE(res.has_value()); auto ids = res.value()->GetIds(); auto dist = res.value()->GetDistance(); @@ -125,7 +125,7 @@ TEST_CASE("Test Brute Force", "[binary vector]") { SECTION("Test Search With Buf") { auto ids = new int64_t[nq * k]; auto dist = new float[nq * k]; - auto res = knowhere::BruteForce::SearchWithBuf(train_ds, query_ds, ids, dist, conf, nullptr); + auto res = knowhere::BruteForce::SearchWithBuf(train_ds, query_ds, ids, dist, conf, nullptr); REQUIRE(res == knowhere::Status::success); for (int64_t i = 0; i < nq; i++) { REQUIRE(ids[i * k] == i); @@ -144,7 +144,7 @@ TEST_CASE("Test Brute Force", "[binary vector]") { auto cfg = conf; cfg[knowhere::meta::RADIUS] = radius_map[metric]; - auto res = knowhere::BruteForce::RangeSearch(train_ds, query_ds, cfg, nullptr); + auto res = knowhere::BruteForce::RangeSearch(train_ds, query_ds, cfg, nullptr); REQUIRE(res.has_value()); auto ids = res.value()->GetIds(); auto dist = res.value()->GetDistance(); diff --git a/tests/ut/test_diskann.cc b/tests/ut/test_diskann.cc index 70f0f914d..5f8006f04 100644 --- a/tests/ut/test_diskann.cc +++ b/tests/ut/test_diskann.cc @@ -14,12 +14,12 @@ #include "catch2/catch_approx.hpp" #include "catch2/catch_test_macros.hpp" #include "catch2/generators/catch_generators.hpp" -#include "index/diskann/diskann.cc" #include "index/diskann/diskann_config.h" #include "knowhere/comp/brute_force.h" #include "knowhere/comp/local_file_manager.h" #include "knowhere/expected.h" #include "knowhere/factory.h" +#include "knowhere/utils.h" #include "utils.h" #if __has_include() #include @@ -52,12 +52,13 @@ constexpr float kL2RangeAp = 0.9; constexpr float kIpRangeAp = 0.9; constexpr float kCosineRangeAp = 0.9; +template void -WriteRawDataToDisk(const std::string data_path, const float* raw_data, const uint32_t num, const uint32_t dim) { +WriteRawDataToDisk(const std::string data_path, const T* raw_data, const uint32_t num, const uint32_t dim) { std::ofstream writer(data_path.c_str(), std::ios::binary); writer.write((char*)&num, sizeof(uint32_t)); writer.write((char*)&dim, sizeof(uint32_t)); - writer.write((char*)raw_data, sizeof(float) * num * dim); + writer.write((char*)raw_data, sizeof(T) * num * dim); writer.close(); } @@ -92,11 +93,11 @@ TEST_CASE("Invalid diskann params test", "[diskann]") { auto diskann_index_pack = knowhere::Pack(file_manager); auto base_ds = GenDataSet(rows_num, kDim, 30); auto base_ptr = static_cast(base_ds->GetTensor()); - WriteRawDataToDisk(kRawDataPath, base_ptr, rows_num, kDim); + WriteRawDataToDisk(kRawDataPath, base_ptr, rows_num, kDim); // build process SECTION("Invalid build params test") { knowhere::DataSet* ds_ptr = nullptr; - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); knowhere::Json test_json; knowhere::Status test_stat; // 
invalid metric type @@ -114,7 +115,7 @@ TEST_CASE("Invalid diskann params test", "[diskann]") { SECTION("Invalid search params test") { knowhere::DataSet* ds_ptr = nullptr; auto binarySet = knowhere::BinarySet(); - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); diskann.Build(*ds_ptr, test_gen()); diskann.Serialize(binarySet); diskann.Deserialize(binarySet, test_gen()); @@ -144,7 +145,10 @@ TEST_CASE("Invalid diskann params test", "[diskann]") { fs::remove(kDir); } -TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { +template +inline void +base_search() { + std::cout << "test data type" << typeid(data_type).name() << std::endl; fs::remove_all(kDir); fs::remove(kDir); REQUIRE_NOTHROW(fs::create_directory(kDir)); @@ -220,20 +224,26 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { return json; }; - auto query_ds = GenDataSet(kNumQueries, kDim, 42); + auto fp32_query_ds = GenDataSet(kNumQueries, kDim, 42); knowhere::DataSetPtr knn_gt_ptr = nullptr; knowhere::DataSetPtr range_search_gt_ptr = nullptr; - auto base_ds = GenDataSet(kNumRows, kDim, 30); + auto fp32_base_ds = GenDataSet(kNumRows, kDim, 30); + knowhere::DataSetPtr base_ds(fp32_base_ds); + knowhere::DataSetPtr query_ds(fp32_query_ds); + if (!std::is_same_v) { + base_ds = knowhere::data_type_conversion(*fp32_base_ds); + query_ds = knowhere::data_type_conversion(*fp32_query_ds); + } { - auto base_ptr = static_cast(base_ds->GetTensor()); - WriteRawDataToDisk(kRawDataPath, base_ptr, kNumRows, kDim); + auto base_ptr = static_cast(base_ds->GetTensor()); + WriteRawDataToDisk(kRawDataPath, base_ptr, kNumRows, kDim); // generate the gt of knn search and range search auto base_json = base_gen(); - auto result_knn = knowhere::BruteForce::Search(base_ds, query_ds, base_json, nullptr); + auto result_knn = knowhere::BruteForce::Search(base_ds, query_ds, base_json, nullptr); knn_gt_ptr = result_knn.value(); - auto result_range = knowhere::BruteForce::RangeSearch(base_ds, query_ds, base_json, nullptr); + auto result_range = knowhere::BruteForce::RangeSearch(base_ds, query_ds, base_json, nullptr); range_search_gt_ptr = result_range.value(); } @@ -245,7 +255,7 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { // build process { knowhere::DataSet* ds_ptr = nullptr; - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); auto build_json = build_gen().dump(); knowhere::Json json = knowhere::Json::parse(build_json); diskann.Build(*ds_ptr, json); @@ -253,7 +263,7 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { } { // knn search - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); diskann.Deserialize(binset, deserialize_json); auto knn_search_json = knn_search_gen().dump(); @@ -270,7 +280,8 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { if (fs::exists(cached_nodes_file_path)) { fs::remove(cached_nodes_file_path); } - auto diskann_tmp = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto diskann_tmp = + knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); diskann_tmp.Deserialize(binset, deserialize_json); auto knn_search_json = 
knn_search_gen().dump(); knowhere::Json knn_json = knowhere::Json::parse(knn_search_json); @@ -291,7 +302,7 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { auto bitset_data = gen_func(kNumRows, percentage * kNumRows); knowhere::BitsetView bitset(bitset_data.data(), kNumRows); auto results = diskann.Search(*query_ds, knn_json, bitset); - auto gt = knowhere::BruteForce::Search(base_ds, query_ds, knn_json, bitset); + auto gt = knowhere::BruteForce::Search(base_ds, query_ds, knn_json, bitset); float recall = GetKNNRecall(*gt.value(), *results.value()); if (percentage == 0.98f) { REQUIRE(recall >= 0.9f); @@ -316,6 +327,10 @@ TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { fs::remove(kDir); } +TEST_CASE("Test DiskANNIndexNode.", "[diskann]") { + base_search(); +} + // This test case only check L2 TEST_CASE("Test DiskANN GetVectorByIds", "[diskann]") { auto version = GenTestVersionList(); @@ -347,13 +362,13 @@ TEST_CASE("Test DiskANN GetVectorByIds", "[diskann]") { auto query_ds = GenDataSet(kNumQueries, dim, 42); auto base_ds = GenDataSet(kNumRows, dim, 30); auto base_ptr = static_cast(base_ds->GetTensor()); - WriteRawDataToDisk(kRawDataPath, base_ptr, kNumRows, dim); + WriteRawDataToDisk(kRawDataPath, base_ptr, kNumRows, dim); std::shared_ptr file_manager = std::make_shared(); auto diskann_index_pack = knowhere::Pack(file_manager); knowhere::DataSet* ds_ptr = nullptr; - auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto diskann = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); auto build_json = build_gen().dump(); knowhere::Json json = knowhere::Json::parse(build_json); diskann.Build(*ds_ptr, json); @@ -369,7 +384,7 @@ TEST_CASE("Test DiskANN GetVectorByIds", "[diskann]") { return json; }; knowhere::Json deserialize_json = knowhere::Json::parse(deserialize_gen().dump()); - auto index = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); + auto index = knowhere::IndexFactory::Instance().Create("DISKANN", version, diskann_index_pack); auto ret = index.Deserialize(binset, deserialize_json); REQUIRE(ret == knowhere::Status::success); std::vector ids_sizes = {1, kNumRows * 0.2, kNumRows * 0.7, kNumRows}; diff --git a/tests/ut/test_feder.cc b/tests/ut/test_feder.cc index 905597edc..131dadb3f 100644 --- a/tests/ut/test_feder.cc +++ b/tests/ut/test_feder.cc @@ -173,11 +173,11 @@ TEST_CASE("Test Feder", "[feder]") { const auto query_ds = GenDataSet(nq, dim, seed); const knowhere::Json conf = base_gen(); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test HNSW Feder") { auto name = knowhere::IndexEnum::INDEX_HNSW; - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); REQUIRE(idx.Type() == name); auto json = hnsw_gen(); @@ -201,7 +201,7 @@ TEST_CASE("Test Feder", "[feder]") { SECTION("Test IVF_FLAT Feder") { auto name = knowhere::IndexEnum::INDEX_FAISS_IVFFLAT; - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); REQUIRE(idx.Type() == name); auto json = ivfflat_gen(); diff --git a/tests/ut/test_get_vector.cc b/tests/ut/test_get_vector.cc index 182be9702..585061eed 100644 --- a/tests/ut/test_get_vector.cc +++ b/tests/ut/test_get_vector.cc @@ -60,7 +60,7 @@ TEST_CASE("Test Binary Get Vector By 
Ids", "[Binary GetVectorByIds]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, bin_ivfflat_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, bin_hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -75,7 +75,7 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") { knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_new = knowhere::IndexFactory::Instance().Create(name, version); + auto idx_new = knowhere::IndexFactory::Instance().Create(name, version); idx_new.Deserialize(bs); auto retrieve_task = [&]() { @@ -173,7 +173,7 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -189,7 +189,7 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") { knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_new = knowhere::IndexFactory::Instance().Create(name, version); + auto idx_new = knowhere::IndexFactory::Instance().Create(name, version); idx_new.Deserialize(bs); auto retrieve_task = [&]() { diff --git a/tests/ut/test_gpu_search.cc b/tests/ut/test_gpu_search.cc index 6c4e7cb56..2614b5385 100644 --- a/tests/ut/test_gpu_search.cc +++ b/tests/ut/test_gpu_search.cc @@ -85,7 +85,7 @@ TEST_CASE("Test All GPU Index", "[search]") { make_tuple(knowhere::IndexEnum::INDEX_RAFT_IVFPQ, refined_gen(ivfpq_gen)), make_tuple(knowhere::IndexEnum::INDEX_RAFT_CAGRA, cagra_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -111,7 +111,7 @@ TEST_CASE("Test All GPU Index", "[search]") { make_tuple(knowhere::IndexEnum::INDEX_RAFT_IVFPQ, refined_gen(ivfpq_gen)), make_tuple(knowhere::IndexEnum::INDEX_RAFT_CAGRA, cagra_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -130,7 +130,7 @@ TEST_CASE("Test All GPU Index", "[search]") { knowhere::BitsetView bitset(bitset_data.data(), nb); auto results = idx.Search(*query_ds, json, bitset); REQUIRE(results.has_value()); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results.value()); if (percentage == 0.98f) { REQUIRE(recall > 0.4f); @@ -150,7 +150,7 @@ TEST_CASE("Test All GPU Index", "[search]") { make_tuple(knowhere::IndexEnum::INDEX_RAFT_IVFPQ, refined_gen(ivfpq_gen)), make_tuple(knowhere::IndexEnum::INDEX_RAFT_CAGRA, cagra_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = 
knowhere::Json::parse(cfg_json); @@ -166,7 +166,7 @@ TEST_CASE("Test All GPU Index", "[search]") { json[knowhere::meta::TOPK] = std::get<0>(topKTuple); auto results = idx.Search(*query_ds, json, nullptr); REQUIRE(results.has_value()); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, nullptr); float recall = GetKNNRecall(*gt.value(), *results.value()); REQUIRE(recall >= std::get<1>(topKTuple)); } @@ -182,7 +182,7 @@ TEST_CASE("Test All GPU Index", "[search]") { make_tuple(knowhere::IndexEnum::INDEX_RAFT_CAGRA, cagra_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -193,7 +193,7 @@ TEST_CASE("Test All GPU Index", "[search]") { REQUIRE(res == knowhere::Status::success); knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); + auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); idx_.Deserialize(bs); auto results = idx_.Search(*query_ds, json, nullptr); REQUIRE(results.has_value()); @@ -213,7 +213,7 @@ TEST_CASE("Test All GPU Index", "[search]") { make_tuple(knowhere::IndexEnum::INDEX_RAFT_CAGRA, cagra_gen), })); auto rows = 16; - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -228,7 +228,7 @@ TEST_CASE("Test All GPU Index", "[search]") { knowhere::BitsetView bitset(bitset_data.data(), rows); auto results = idx.Search(*train_ds, json, bitset); REQUIRE(results.has_value()); - auto gt = knowhere::BruteForce::Search(train_ds, train_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, train_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results.value()); REQUIRE(recall == 1.0f); } diff --git a/tests/ut/test_half_presicion.cc b/tests/ut/test_half_presicion.cc new file mode 100644 index 000000000..23c5bab88 --- /dev/null +++ b/tests/ut/test_half_presicion.cc @@ -0,0 +1,258 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. 
+ +#include "catch2/catch_approx.hpp" +#include "catch2/catch_test_macros.hpp" +#include "catch2/generators/catch_generators.hpp" +#include "faiss/utils/binary_distances.h" +#include "hnswlib/hnswalg.h" +#include "knowhere/bitsetview.h" +#include "knowhere/comp/brute_force.h" +#include "knowhere/comp/index_param.h" +#include "knowhere/comp/knowhere_config.h" +#include "knowhere/factory.h" +#include "knowhere/log.h" +#include "knowhere/utils.h" +#include "utils.h" + +namespace { +constexpr float kKnnRecallThreshold = 0.6f; +constexpr float kBruteForceRecallThreshold = 0.99f; +} // namespace +template +void +BaseSearchTest() { + using Catch::Approx; + + const int64_t nb = 1000, nq = 10; + const int64_t dim = 128; + + auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::COSINE); + auto topk = GENERATE(as{}, 5, 120); + auto version = GenTestVersionList(); + + auto base_gen = [=]() { + knowhere::Json json; + json[knowhere::meta::DIM] = dim; + json[knowhere::meta::METRIC_TYPE] = metric; + json[knowhere::meta::TOPK] = topk; + json[knowhere::meta::RADIUS] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 10.0 : 0.99; + json[knowhere::meta::RANGE_FILTER] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 0.0 : 1.01; + return json; + }; + + auto ivfflat_gen = [base_gen]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::NLIST] = 16; + json[knowhere::indexparam::NPROBE] = 8; + return json; + }; + + auto ivfflatcc_gen = [ivfflat_gen]() { + knowhere::Json json = ivfflat_gen(); + json[knowhere::indexparam::SSIZE] = 48; + return json; + }; + + auto ivfsq_gen = ivfflat_gen; + + auto flat_gen = base_gen; + + auto ivfpq_gen = [ivfflat_gen]() { + knowhere::Json json = ivfflat_gen(); + json[knowhere::indexparam::M] = 4; + json[knowhere::indexparam::NBITS] = 8; + return json; + }; + + auto scann_gen = [ivfflat_gen]() { + knowhere::Json json = ivfflat_gen(); + json[knowhere::indexparam::NPROBE] = 14; + json[knowhere::indexparam::REORDER_K] = 500; + json[knowhere::indexparam::WITH_RAW_DATA] = true; + return json; + }; + + auto scann_gen2 = [scann_gen]() { + knowhere::Json json = scann_gen(); + json[knowhere::indexparam::WITH_RAW_DATA] = false; + return json; + }; + + auto hnsw_gen = [base_gen]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::HNSW_M] = 128; + json[knowhere::indexparam::EFCONSTRUCTION] = 200; + json[knowhere::indexparam::EF] = 200; + return json; + }; + + const auto fp32_train_ds = GenDataSet(nb, dim); + const auto fp32_query_ds = GenDataSet(nq, dim); + auto train_ds = knowhere::data_type_conversion(*fp32_train_ds); + auto query_ds = knowhere::data_type_conversion(*fp32_query_ds); + + const knowhere::Json conf = { + {knowhere::meta::METRIC_TYPE, metric}, + {knowhere::meta::TOPK, topk}, + }; + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + + SECTION("Test half-float Search") { + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + })); + auto idx = 
knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + REQUIRE(idx.Size() > 0); + REQUIRE(idx.Count() == nb); + + knowhere::BinarySet bs; + REQUIRE(idx.Serialize(bs) == knowhere::Status::success); + REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success); + + auto results = idx.Search(*query_ds, json, nullptr); + REQUIRE(results.has_value()); + float recall = GetKNNRecall(*gt.value(), *results.value()); + bool scann_without_raw_data = + (name == knowhere::IndexEnum::INDEX_FAISS_SCANN && scann_gen2().dump() == cfg_json); + if (name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && !scann_without_raw_data) { + REQUIRE(recall > kKnnRecallThreshold); + } + + if (metric == knowhere::metric::COSINE) { + if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 && name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && + !scann_without_raw_data) { + REQUIRE(CheckDistanceInScope(*results.value(), topk, -1.00001, 1.00001)); + } + } + } + + SECTION("Test half-float Range Search") { + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + + knowhere::BinarySet bs; + REQUIRE(idx.Serialize(bs) == knowhere::Status::success); + REQUIRE(idx.Deserialize(bs, json) == knowhere::Status::success); + + auto results = idx.RangeSearch(*query_ds, json, nullptr); + REQUIRE(results.has_value()); + auto ids = results.value()->GetIds(); + auto lims = results.value()->GetLims(); + auto dis = results.value()->GetDistance(); + bool scann_without_raw_data = + (name == knowhere::IndexEnum::INDEX_FAISS_SCANN && scann_gen2().dump() == cfg_json); + if (name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && name != knowhere::IndexEnum::INDEX_FAISS_SCANN) { + for (int i = 0; i < nq; ++i) { + CHECK(ids[lims[i]] == i); + } + } + + if (metric == knowhere::metric::COSINE) { + if (name != knowhere::IndexEnum::INDEX_FAISS_IVFSQ8 && name != knowhere::IndexEnum::INDEX_FAISS_IVFPQ && + !scann_without_raw_data) { + REQUIRE(CheckDistanceInScope(*results.value(), -1.00001, 1.00001)); + } + } + } + + SECTION("Test half-float Search with Bitset") { + using std::make_tuple; + auto [name, gen, threshold] = GENERATE_REF(table, float>({ + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), + })); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + + 
std::vector(size_t, size_t)>> gen_bitset_funcs = { + GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet}; + const auto bitset_percentages = {0.4f, 0.98f}; + for (const float percentage : bitset_percentages) { + for (const auto& gen_func : gen_bitset_funcs) { + auto bitset_data = gen_func(nb, percentage * nb); + knowhere::BitsetView bitset(bitset_data.data(), nb); + auto results = idx.Search(*query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + float recall = GetKNNRecall(*gt.value(), *results.value()); + if (percentage > threshold) { + REQUIRE(recall > kBruteForceRecallThreshold); + } else { + REQUIRE(recall > kKnnRecallThreshold); + } + } + } + } + + SECTION("Test Serialize/Deserialize") { + using std::make_tuple; + auto [name, gen] = GENERATE_REF(table>({ + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IDMAP, flat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivfflat_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivfsq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), + make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), + })); + + auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto cfg_json = gen().dump(); + CAPTURE(name, cfg_json); + knowhere::Json json = knowhere::Json::parse(cfg_json); + REQUIRE(idx.Type() == name); + REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); + knowhere::BinarySet bs; + idx.Serialize(bs); + + auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); + idx_.Deserialize(bs); + auto results = idx_.Search(*query_ds, json, nullptr); + REQUIRE(results.has_value()); + } +} + +TEST_CASE("Test Mem Index With fp16/bf16 Vector", "[float metrics]") { + BaseSearchTest(); + BaseSearchTest(); +} diff --git a/tests/ut/test_iterator.cc b/tests/ut/test_iterator.cc index 7a92e7e0a..5cb07f317 100644 --- a/tests/ut/test_iterator.cc +++ b/tests/ut/test_iterator.cc @@ -83,14 +83,14 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { {knowhere::meta::METRIC_TYPE, metric}, {knowhere::meta::TOPK, topk}, }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search using iterator") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -114,7 +114,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -132,7 +132,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { auto its = 
idx.AnnIterator(*query_ds, json, bitset); REQUIRE(its.has_value()); auto results = GetKNNResult(its.value(), topk, &bitset); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results); REQUIRE(recall > kKnnRecallThreshold); } @@ -144,7 +144,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -165,7 +165,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { for (const auto& it : its.value()) { REQUIRE(!it->HasNext()); } - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results); REQUIRE(recall > kKnnRecallThreshold); } @@ -205,13 +205,13 @@ TEST_CASE("Test Iterator Mem Index With Binary Vector", "[float metrics]") { {knowhere::meta::TOPK, topk}, }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search using iterator") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); diff --git a/tests/ut/test_ivfflat_cc.cc b/tests/ut/test_ivfflat_cc.cc index 442dd1383..c21c57a13 100644 --- a/tests/ut/test_ivfflat_cc.cc +++ b/tests/ut/test_ivfflat_cc.cc @@ -141,7 +141,7 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") { auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -183,9 +183,10 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") { SECTION("Test Build & Search Correctness") { using std::make_tuple; - auto ivf_flat = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, version); + auto ivf_flat = + knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, version); auto ivf_flat_cc = - knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, version); + knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, version); knowhere::Json ivf_flat_json = knowhere::Json::parse(ivfflat_gen().dump()); knowhere::Json ivf_flat_cc_json = knowhere::Json::parse(ivfflatcc_gen().dump()); @@ -242,7 +243,7 @@ TEST_CASE("Test Build Search Concurrency", "[Concurrency]") { auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, 
ivfflatcc_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); diff --git a/tests/ut/test_mmap.cc b/tests/ut/test_mmap.cc index 8e9d5bb08..4a4eee2db 100644 --- a/tests/ut/test_mmap.cc +++ b/tests/ut/test_mmap.cc @@ -116,7 +116,7 @@ TEST_CASE("Search mmap", "[float metrics]") { {knowhere::meta::METRIC_TYPE, metric}, {knowhere::meta::TOPK, topk}, }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search") { using std::make_tuple; @@ -127,7 +127,7 @@ TEST_CASE("Search mmap", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -154,7 +154,7 @@ TEST_CASE("Search mmap", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, ivfpq_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -178,7 +178,7 @@ TEST_CASE("Search mmap", "[float metrics]") { auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -194,7 +194,7 @@ TEST_CASE("Search mmap", "[float metrics]") { auto bitset_data = gen_func(nb, percentage * nb); knowhere::BitsetView bitset(bitset_data.data(), nb); auto results = idx.Search(*query_ds, json, bitset); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results.value()); if (percentage > threshold) { REQUIRE(recall > kBruteForceRecallThreshold); @@ -262,7 +262,7 @@ TEST_CASE("Search binary mmap", "[float metrics]") { REQUIRE(index.DeserializeFromFile(path, conf) == knowhere::Status::success); }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>({ @@ -270,7 +270,7 @@ TEST_CASE("Search binary mmap", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -292,7 +292,7 @@ TEST_CASE("Search binary mmap", "[float metrics]") 
{ make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -374,7 +374,7 @@ TEST_CASE("Search binary mmap", "[bool metrics]") { std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -417,7 +417,7 @@ TEST_CASE("Search binary mmap", "[bool metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index 4d177376d..6a25c403b 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -100,7 +100,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { {knowhere::meta::METRIC_TYPE, metric}, {knowhere::meta::TOPK, topk}, }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search") { using std::make_tuple; @@ -114,7 +114,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -156,7 +156,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -199,7 +199,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen_), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -207,7 +207,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { REQUIRE(idx.Build(*train_ds, json) == knowhere::Status::success); auto results = idx.Search(*query_ds, json, nullptr); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, 
nullptr); float recall = GetKNNRecall(*gt.value(), *results.value()); REQUIRE(recall > kBruteForceRecallThreshold); } @@ -217,7 +217,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { auto [name, gen, threshold] = GENERATE_REF(table, float>({ make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen, hnswlib::kHnswSearchKnnBFFilterThreshold), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -232,7 +232,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { auto bitset_data = gen_func(nb, percentage * nb); knowhere::BitsetView bitset(bitset_data.data(), nb); auto results = idx.Search(*query_ds, json, bitset); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); float recall = GetKNNRecall(*gt.value(), *results.value()); if (percentage > threshold || json[knowhere::meta::TOPK] > (1 - percentage) * nb * hnswlib::kHnswSearchBFTopkThreshold) { @@ -257,7 +257,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -266,14 +266,14 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); + auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); idx_.Deserialize(bs); auto results = idx_.Search(*query_ds, json, nullptr); REQUIRE(results.has_value()); } SECTION("Test IVFPQ with invalid params") { - auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, version); + auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, version); uint32_t nb = 1000; uint32_t dim = 128; auto ivf_pq_gen = [&]() { @@ -333,7 +333,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { {knowhere::meta::TOPK, topk}, }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>({ @@ -341,7 +341,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -361,7 +361,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, 
cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -383,7 +383,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -392,7 +392,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { knowhere::BinarySet bs; idx.Serialize(bs); - auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); + auto idx_ = knowhere::IndexFactory::Instance().Create(name, version); idx_.Deserialize(bs); auto results = idx_.Search(*query_ds, json, nullptr); REQUIRE(results.has_value()); @@ -434,13 +434,13 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics][special case 1]" {knowhere::meta::TOPK, topk}, }; - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); SECTION("Test Search") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>({ make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -504,7 +504,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") { std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), std::make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name, version); + auto idx = knowhere::IndexFactory::Instance().Create(name, version); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); @@ -551,7 +551,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") { make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, flat_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, ivfflat_gen), })); - auto idx = knowhere::IndexFactory::Instance().Create(name); + auto idx = knowhere::IndexFactory::Instance().Create(name); auto cfg_json = gen().dump(); CAPTURE(name, cfg_json); knowhere::Json json = knowhere::Json::parse(cfg_json); diff --git a/tests/ut/test_simd.cc b/tests/ut/test_simd.cc index fba56c32c..cbce8f37e 100644 --- a/tests/ut/test_simd.cc +++ b/tests/ut/test_simd.cc @@ -38,7 +38,7 @@ TEST_CASE("Test BruteForce Search SIMD", "[bf]") { auto test_search_with_simd = [&](knowhere::KnowhereConfig::SimdType simd_type) { knowhere::KnowhereConfig::SetSimdType(simd_type); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); REQUIRE(gt.has_value()); auto gt_ids = gt.value()->GetIds(); auto gt_dist = gt.value()->GetDistance(); @@ -83,7 +83,7 @@ TEST_CASE("Test PQ Search SIMD", "[pq]") { conf[knowhere::indexparam::M] = m; knowhere::KnowhereConfig::SetSimdType(simd_type); - auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); REQUIRE(gt.has_value()); auto gt_ids = gt.value()->GetIds(); auto gt_dist = gt.value()->GetDistance(); @@ -97,7 +97,7 @@ 
TEST_CASE("Test PQ Search SIMD", "[pq]") { } } - auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, version); + auto idx = knowhere::IndexFactory::Instance().Create(knowhere::IndexEnum::INDEX_FAISS_IVFPQ, version); REQUIRE(idx.Build(*train_ds, conf) == knowhere::Status::success); auto res = idx.Search(*query_ds, conf, nullptr); REQUIRE(res.has_value()); diff --git a/thirdparty/DiskANN/include/diskann/utils.h b/thirdparty/DiskANN/include/diskann/utils.h index 50aaaf17d..17056bcdc 100644 --- a/thirdparty/DiskANN/include/diskann/utils.h +++ b/thirdparty/DiskANN/include/diskann/utils.h @@ -36,6 +36,7 @@ typedef int FileHandle; #include "ann_exception.h" #include "common_includes.h" #include "knowhere/comp/thread_pool.h" +#include "knowhere/operands.h" // taken from // https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h @@ -565,7 +566,6 @@ namespace diskann { template void convert_types(const InType* srcmat, OutType* destmat, size_t npts, size_t dim) { -#pragma omp parallel for schedule(static, 65536) for (int64_t i = 0; i < (_s64) npts; i++) { for (uint64_t j = 0; j < dim; j++) { destmat[i * dim + j] = (OutType) srcmat[i * dim + j]; @@ -836,6 +836,10 @@ namespace diskann { // NOTE: Implementation in utils.cpp. void block_convert(std::ofstream& writr, std::ifstream& readr, float* read_buf, _u64 npts, _u64 ndims); + + template + void convert_types_in_file(const std::string& inFileName, + const std::string& outFileName); void normalize_data_file(const std::string& inFileName, const std::string& outFileName); diff --git a/thirdparty/DiskANN/src/utils.cpp b/thirdparty/DiskANN/src/utils.cpp index 989b8ca83..0f50ae8bb 100644 --- a/thirdparty/DiskANN/src/utils.cpp +++ b/thirdparty/DiskANN/src/utils.cpp @@ -58,4 +58,39 @@ namespace diskann { LOG_KNOWHERE_DEBUG_ << "Wrote normalized points to file: " << outFileName; } + + template + void convert_types_in_file(const std::string& inFileName, + const std::string& outFileName) { + std::ifstream readr(inFileName, std::ios::binary); + std::ofstream writr(outFileName, std::ios::binary); + + int npts_s32, ndims_s32; + readr.read((char*) &npts_s32, sizeof(_s32)); + readr.read((char*) &ndims_s32, sizeof(_s32)); + + writr.write((char*) &npts_s32, sizeof(_s32)); + writr.write((char*) &ndims_s32, sizeof(_s32)); + + _u64 npts = (_u64) npts_s32, ndims = (_u64) ndims_s32; + + _u64 blk_size = 131072; + _u64 nblks = ROUND_UP(npts, blk_size) / blk_size; + LOG_KNOWHERE_DEBUG_ << "# blks: " << nblks; + + IN* read_buf = new IN[npts * ndims]; + OUT* write_buf = new OUT[npts * ndims]; + for (_u64 i = 0; i < nblks; i++) { + _u64 cblk_size = std::min(npts - i * blk_size, blk_size); + readr.read((char*) read_buf, npts * ndims * sizeof(IN)); + convert_types(read_buf, write_buf, cblk_size, ndims); + writr.write((char*) write_buf, npts * ndims * sizeof(OUT)); + } + + delete[] read_buf; + delete[] write_buf; + } + + template void convert_types_in_file(const std::string& , const std::string&); + template void convert_types_in_file(const std::string& , const std::string&); } // namespace diskann