Skip to content

Commit

Permalink
Instead of having many APIs for different PARAM_TYPE, Knowhere's Co…
Browse files Browse the repository at this point in the history
…nfig operations(Load/Check) are using a single API with `PARAM_TYPE` as a parameter. However, the post loading APIs (CheckAndAdjustForXXX) didn't follow this rule. This PR synced this style.

Signed-off-by: Li Liu <[email protected]>
  • Loading branch information
liliu-z committed Nov 21, 2023
1 parent 66b1656 commit 928bbdb
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 117 deletions.
32 changes: 11 additions & 21 deletions include/knowhere/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,11 @@ class Config {
}
}

return Status::success;
if (!err_msg) {
std::string tem_msg;
return cfg.CheckAndAdjust(type, &tem_msg);
}
return cfg.CheckAndAdjust(type, err_msg);
}

virtual ~Config() {
Expand All @@ -485,6 +489,12 @@ class Config {
using VarEntry =
std::variant<Entry<CFG_STRING>, Entry<CFG_FLOAT>, Entry<CFG_INT>, Entry<CFG_LIST>, Entry<CFG_BOOL>>;
std::unordered_map<std::string, VarEntry> __DICT__;

protected:
inline virtual Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* const err_msg) {
return Status::success;
}
};

#define KNOHWERE_DECLARE_CONFIG(CONFIG) CONFIG()
Expand Down Expand Up @@ -557,26 +567,6 @@ class BaseConfig : public Config {
.for_deserialize_from_file();
KNOWHERE_CONFIG_DECLARE_FIELD(for_tuning).set_default(false).description("for tuning").for_search();
}

virtual Status
CheckAndAdjustForSearch(std::string* err_msg) {
return Status::success;
}

virtual Status
CheckAndAdjustForRangeSearch(std::string* err_msg) {
return Status::success;
}

virtual Status
CheckAndAdjustForIterator() {
return Status::success;
}

virtual inline Status
CheckAndAdjustForBuild() {
return Status::success;
}
};
} // namespace knowhere

Expand Down
13 changes: 0 additions & 13 deletions src/common/index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ inline Status
Index<T>::Build(const DataSet& dataset, const Json& json) {
auto cfg = this->node->CreateConfig();
RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Build"));
RETURN_IF_ERROR(cfg->CheckAndAdjustForBuild());

#ifdef NOT_COMPILE_FOR_SWIG
TimeRecorder rc("Build index", 2);
Expand Down Expand Up @@ -77,10 +76,6 @@ Index<T>::Search(const DataSet& dataset, const Json& json, const BitsetView& bit
if (load_status != Status::success) {
return expected<DataSetPtr>::Err(load_status, msg);
}
const Status search_status = cfg->CheckAndAdjustForSearch(&msg);
if (search_status != Status::success) {
return expected<DataSetPtr>::Err(search_status, msg);
}

#ifdef NOT_COMPILE_FOR_SWIG
TimeRecorder rc("Search");
Expand All @@ -105,10 +100,6 @@ Index<T>::AnnIterator(const DataSet& dataset, const Json& json, const BitsetView
if (status != Status::success) {
return expected<std::vector<std::shared_ptr<IndexNode::iterator>>>::Err(status, msg);
}
status = cfg->CheckAndAdjustForIterator();
if (status != Status::success) {
return expected<std::vector<std::shared_ptr<IndexNode::iterator>>>::Err(status, "invalid params for iterator");
}

#ifdef NOT_COMPILE_FOR_SWIG
// note that this time includes only the initial search phase of iterator.
Expand All @@ -133,10 +124,6 @@ Index<T>::RangeSearch(const DataSet& dataset, const Json& json, const BitsetView
if (status != Status::success) {
return expected<DataSetPtr>::Err(status, std::move(msg));
}
status = cfg->CheckAndAdjustForRangeSearch(&msg);
if (status != Status::success) {
return expected<DataSetPtr>::Err(status, std::move(msg));
}

#ifdef NOT_COMPILE_FOR_SWIG
TimeRecorder rc("Range Search");
Expand Down
40 changes: 22 additions & 18 deletions src/index/diskann/diskann_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,28 @@ class DiskANNConfig : public BaseConfig {
.for_search();
}

inline Status
CheckAndAdjustForSearch(std::string* err_msg) override {
if (!search_list_size.has_value()) {
search_list_size = std::max(k.value(), kSearchListSizeMinValue);
} else if (k.value() > search_list_size.value()) {
*err_msg = "search_list_size(" + std::to_string(search_list_size.value()) + ") should be larger than k(" +
std::to_string(k.value()) + ")";
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::out_of_range_in_json;
}

return Status::success;
}

inline Status
CheckAndAdjustForBuild() override {
if (!search_list_size.has_value()) {
search_list_size = kDefaultSearchListSizeForBuild;
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
switch (param_type) {
case PARAM_TYPE::TRAIN: {
if (!search_list_size.has_value()) {
search_list_size = kDefaultSearchListSizeForBuild;
}
break;
}
case PARAM_TYPE::SEARCH: {
if (!search_list_size.has_value()) {
search_list_size = std::max(k.value(), kSearchListSizeMinValue);
} else if (k.value() > search_list_size.value()) {
*err_msg = "search_list_size(" + std::to_string(search_list_size.value()) +
") should be larger than k(" + std::to_string(k.value()) + ")";
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::out_of_range_in_json;
}
break;
}
default:
break;
}
return Status::success;
}
Expand Down
42 changes: 23 additions & 19 deletions src/index/hnsw/hnsw_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,29 @@ class HnswConfig : public BaseConfig {
.for_feder();
}

inline Status
CheckAndAdjustForSearch(std::string* err_msg) override {
if (!ef.has_value()) {
ef = std::max(k.value(), kEfMinValue);
} else if (k.value() > ef.value()) {
*err_msg =
"ef(" + std::to_string(ef.value()) + ") should be larger than k(" + std::to_string(k.value()) + ")";
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::out_of_range_in_json;
}

return Status::success;
}

inline Status
CheckAndAdjustForRangeSearch(std::string* err_msg) override {
if (!ef.has_value()) {
// if ef is not set by user, set it to default
ef = kDefaultRangeSearchEf;
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
switch (param_type) {
case PARAM_TYPE::SEARCH: {
if (!ef.has_value()) {
ef = std::max(k.value(), kEfMinValue);
} else if (k.value() > ef.value()) {
*err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" +
std::to_string(k.value()) + ")";
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::out_of_range_in_json;
}
break;
}
case PARAM_TYPE::RANGE_SEARCH: {
if (!ef.has_value()) {
// if ef is not set by user, set it to default
ef = kDefaultRangeSearchEf;
}
break;
}
default:
break;
}
return Status::success;
}
Expand Down
75 changes: 40 additions & 35 deletions src/index/ivf/ivf_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,41 +79,46 @@ class ScannConfig : public IvfFlatConfig {
.for_train();
}

inline Status
CheckAndAdjustForSearch(std::string* err_msg) override {
if (!faiss::support_pq_fast_scan) {
*err_msg = "SCANN index is not supported on the current CPU model, avx2 support is needed for x86 arch.";
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::invalid_instruction_set;
}
if (!reorder_k.has_value()) {
reorder_k = k.value();
} else if (reorder_k.value() < k.value()) {
*err_msg = "reorder_k(" + std::to_string(reorder_k.value()) + ") should be larger than k(" +
std::to_string(k.value()) + ")";
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::out_of_range_in_json;
}

return Status::success;
}

inline Status
CheckAndAdjustForRangeSearch(std::string* err_msg) override {
if (!faiss::support_pq_fast_scan) {
*err_msg = "SCANN index is not supported on the current CPU model, avx2 support is needed for x86 arch.";
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::invalid_instruction_set;
}
return Status::success;
}

inline Status
CheckAndAdjustForBuild() override {
if (!faiss::support_pq_fast_scan) {
LOG_KNOWHERE_ERROR_
<< "SCANN index is not supported on the current CPU model, avx2 support is needed for x86 arch.";
return Status::invalid_instruction_set;
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
switch (param_type) {
case PARAM_TYPE::TRAIN: {
if (!faiss::support_pq_fast_scan) {
LOG_KNOWHERE_ERROR_ << "SCANN index is not supported on the current CPU model, avx2 support is "
"needed for x86 arch.";
return Status::invalid_instruction_set;
}
break;
}
case PARAM_TYPE::SEARCH: {
if (!faiss::support_pq_fast_scan) {
LOG_KNOWHERE_ERROR_ << "SCANN index is not supported on the current CPU model, avx2 support is "
"needed for x86 arch.";
return Status::invalid_instruction_set;
}
if (!reorder_k.has_value()) {
reorder_k = k.value();
} else if (reorder_k.value() < k.value()) {
if (!err_msg) {
err_msg = new std::string();
}
*err_msg = "reorder_k(" + std::to_string(reorder_k.value()) + ") should be larger than k(" +
std::to_string(k.value()) + ")";
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::out_of_range_in_json;
}
break;
}
case PARAM_TYPE::RANGE_SEARCH: {
if (!faiss::support_pq_fast_scan) {
LOG_KNOWHERE_ERROR_ << "SCANN index is not supported on the current CPU model, avx2 support is "
"needed for x86 arch.";
return Status::invalid_instruction_set;
}
break;
}
default:
break;
}
return Status::success;
}
Expand Down
11 changes: 0 additions & 11 deletions tests/ut/test_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ TEST_CASE("Test config json parse", "[config]") {
invalid_value_json = json;
invalid_value_json["ef"] = 99;
s = knowhere::Config::Load(wrong_cfg, invalid_value_json, knowhere::SEARCH);
s = wrong_cfg.CheckAndAdjustForSearch(&err_msg);
CHECK(s == knowhere::Status::out_of_range_in_json);
}

Expand All @@ -189,7 +188,6 @@ TEST_CASE("Test config json parse", "[config]") {
{
knowhere::HnswConfig search_cfg;
s = knowhere::Config::Load(search_cfg, json, knowhere::SEARCH);
s = search_cfg.CheckAndAdjustForSearch(&err_msg);
CHECK(s == knowhere::Status::success);
}

Expand All @@ -198,7 +196,6 @@ TEST_CASE("Test config json parse", "[config]") {
auto search_json = json;
search_json.erase("ef");
s = knowhere::Config::Load(search_cfg, search_json, knowhere::SEARCH);
s = search_cfg.CheckAndAdjustForSearch(&err_msg);
CHECK(s == knowhere::Status::success);
CHECK_EQ(100, search_cfg.ef.value());
}
Expand All @@ -209,7 +206,6 @@ TEST_CASE("Test config json parse", "[config]") {
search_json.erase("ef");
search_json["k"] = 10;
s = knowhere::Config::Load(search_cfg, search_json, knowhere::SEARCH);
s = search_cfg.CheckAndAdjustForSearch(&err_msg);
CHECK(s == knowhere::Status::success);
CHECK_EQ(16, search_cfg.ef.value());
}
Expand Down Expand Up @@ -244,8 +240,6 @@ TEST_CASE("Test config json parse", "[config]") {
knowhere::DiskANNConfig train_cfg;
s = knowhere::Config::Load(train_cfg, json, knowhere::TRAIN);
CHECK(s == knowhere::Status::success);
s = train_cfg.CheckAndAdjustForBuild();
CHECK(s == knowhere::Status::success);
CHECK_EQ(128, train_cfg.search_list_size.value());
CHECK_EQ("L2", train_cfg.metric_type.value());
}
Expand All @@ -254,8 +248,6 @@ TEST_CASE("Test config json parse", "[config]") {
knowhere::DiskANNConfig search_cfg;
s = knowhere::Config::Load(search_cfg, json, knowhere::SEARCH);
CHECK(s == knowhere::Status::success);
s = search_cfg.CheckAndAdjustForSearch(&err_msg);
CHECK(s == knowhere::Status::success);
CHECK_EQ("L2", search_cfg.metric_type.value());
CHECK_EQ(100, search_cfg.k.value());
CHECK_EQ(100, search_cfg.search_list_size.value());
Expand All @@ -267,8 +259,6 @@ TEST_CASE("Test config json parse", "[config]") {
search_json["k"] = 2;
s = knowhere::Config::Load(search_cfg, search_json, knowhere::SEARCH);
CHECK(s == knowhere::Status::success);
s = search_cfg.CheckAndAdjustForSearch(&err_msg);
CHECK(s == knowhere::Status::success);
CHECK_EQ(16, search_cfg.search_list_size.value());
}

Expand All @@ -277,7 +267,6 @@ TEST_CASE("Test config json parse", "[config]") {
auto search_json = json;
search_json["search_list_size"] = 99;
s = knowhere::Config::Load(search_cfg, search_json, knowhere::SEARCH);
s = search_cfg.CheckAndAdjustForSearch(&err_msg);
CHECK(s == knowhere::Status::out_of_range_in_json);
}

Expand Down

0 comments on commit 928bbdb

Please sign in to comment.