diff --git a/python/knowhere/__init__.py b/python/knowhere/__init__.py index 9d3c77d3b..815db8e31 100644 --- a/python/knowhere/__init__.py +++ b/python/knowhere/__init__.py @@ -1,6 +1,7 @@ from . import swigknowhere from .swigknowhere import Status from .swigknowhere import GetBinarySet, GetNullDataSet, GetNullBitSetView +from .swigknowhere import BruteForceSearch, BruteForceRangeSearch import numpy as np @@ -95,3 +96,6 @@ def GetBinaryVectorDataSetToArray(ans): data = np.zeros([rows, dim]).astype(np.int32) swigknowhere.BinaryDataSetTensor2Array(ans, data) return data + +def SetSimdType(type): + swigknowhere.SetSimdType(type) diff --git a/python/knowhere/knowhere.i b/python/knowhere/knowhere.i index 14af92a3c..18a435d9d 100644 --- a/python/knowhere/knowhere.i +++ b/python/knowhere/knowhere.i @@ -30,6 +30,8 @@ typedef uint64_t size_t; #include #include #include +#include +#include #include #include #include @@ -288,19 +290,23 @@ Array2DataSetIds(int* ids, int len){ return ds; }; -int64_t DataSet_Rows(knowhere::DataSetPtr results){ +int64_t +DataSet_Rows(knowhere::DataSetPtr results){ return results->GetRows(); } -int64_t DataSet_Dim(knowhere::DataSetPtr results){ +int64_t +DataSet_Dim(knowhere::DataSetPtr results){ return results->GetDim(); } -knowhere::BinarySetPtr GetBinarySet() { +knowhere::BinarySetPtr +GetBinarySet() { return std::make_shared(); } -knowhere::DataSetPtr GetNullDataSet() { +knowhere::DataSetPtr +GetNullDataSet() { return nullptr; } @@ -418,4 +424,47 @@ Load(knowhere::BinarySetPtr binset, const std::string& file_name) { } } +knowhere::DataSetPtr +BruteForceSearch(knowhere::DataSetPtr base_dataset, knowhere::DataSetPtr query_dataset, const std::string& json, + const knowhere::BitsetView& bitset, knowhere::Status& status) { + GILReleaser rel; + auto res = knowhere::BruteForce::Search(base_dataset, query_dataset, knowhere::Json::parse(json), bitset); + if (res.has_value()) { + status = knowhere::Status::success; + return res.value(); + } else { + status = res.error(); + return nullptr; + } +} + +knowhere::DataSetPtr +BruteForceRangeSearch(knowhere::DataSetPtr base_dataset, knowhere::DataSetPtr query_dataset, const std::string& json, + const knowhere::BitsetView& bitset, knowhere::Status& status) { + GILReleaser rel; + auto res = knowhere::BruteForce::RangeSearch(base_dataset, query_dataset, knowhere::Json::parse(json), bitset); + if (res.has_value()) { + status = knowhere::Status::success; + return res.value(); + } else { + status = res.error(); + return nullptr; + } +} + +void +SetSimdType(const std::string type) { + if (type == "auto") { + knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AUTO); + } else if (type == "avx512") { + knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX512); + } else if (type == "avx2") { + knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); + } else if (type == "avx" || type == "sse4_2") { + knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::SSE4_2); + } else { + knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::GENERIC); + } +} + %}