From 7bcfa6a286968bebd0cdbd5772c211945972310a Mon Sep 17 00:00:00 2001 From: chasingegg Date: Fri, 3 Jan 2025 12:04:00 +0800 Subject: [PATCH] Support MV only for HNSW Signed-off-by: chasingegg --- include/knowhere/bitsetview.h | 32 + include/knowhere/bitsetview_idselector.h | 11 +- include/knowhere/comp/index_param.h | 2 + src/index/hnsw/faiss_hnsw.cc | 1302 +++++++++--- .../hnsw/impl/IndexConditionalWrapper.cc | 8 +- tests/ut/test_faiss_hnsw.cc | 1839 ++++++++++------- tests/ut/test_index_check.cc | 14 +- tests/ut/utils.h | 44 + thirdparty/faiss/faiss/impl/index_read.cpp | 23 + thirdparty/faiss/faiss/impl/index_write.cpp | 14 + thirdparty/faiss/faiss/index_io.h | 9 + 11 files changed, 2171 insertions(+), 1127 deletions(-) diff --git a/include/knowhere/bitsetview.h b/include/knowhere/bitsetview.h index 464bf774b..0a3e9b979 100644 --- a/include/knowhere/bitsetview.h +++ b/include/knowhere/bitsetview.h @@ -95,6 +95,38 @@ class BitsetView { return ret; } + size_t + get_first_valid_index() const { + size_t ret = 0; + auto len_uint8 = byte_size(); + auto len_uint64 = len_uint8 >> 3; + + uint64_t* p_uint64 = (uint64_t*)bits_; + for (size_t i = 0; i < len_uint64; i++) { + uint64_t value = (~(*p_uint64)); + if (value == 0) { + p_uint64++; + continue; + } + ret = __builtin_ctzll(value); + return i * 64 + ret; + } + + // calculate remainder + uint8_t* p_uint8 = (uint8_t*)bits_ + (len_uint64 << 3); + for (size_t i = (len_uint64 << 3); i < len_uint8; i++) { + uint8_t value = (~(*p_uint8)); + if (value == 0) { + p_uint8++; + continue; + } + ret = __builtin_ctz(value); + return len_uint64 * 64 + i * 8 + ret; + } + + return num_bits_; + } + std::string to_string(size_t from, size_t to) const { if (empty()) { diff --git a/include/knowhere/bitsetview_idselector.h b/include/knowhere/bitsetview_idselector.h index 39f6ff1a8..c9ece2d4a 100644 --- a/include/knowhere/bitsetview_idselector.h +++ b/include/knowhere/bitsetview_idselector.h @@ -20,15 +20,20 @@ namespace knowhere { struct BitsetViewIDSelector final : faiss::IDSelector { const BitsetView bitset_view; const size_t id_offset; + const uint32_t* out_id_mapping; - inline BitsetViewIDSelector(BitsetView bitset_view, const size_t offset = 0) - : bitset_view{bitset_view}, id_offset(offset) { + inline BitsetViewIDSelector(BitsetView bitset_view, const size_t offset = 0, + const uint32_t* out_id_mapping = nullptr) + : bitset_view{bitset_view}, id_offset(offset), out_id_mapping(out_id_mapping) { } inline bool is_member(faiss::idx_t id) const override final { // it is by design that bitset_view.empty() is not tested here - return (!bitset_view.test(id + id_offset)); + if (out_id_mapping == nullptr) { + return (!bitset_view.test(id + id_offset)); + } + return (!bitset_view.test(out_id_mapping[id + id_offset])); } }; diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index d389522fd..4fe8b7cf5 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -95,6 +95,8 @@ constexpr const char* JSON_ID_SET = "json_id_set"; constexpr const char* TRACE_ID = "trace_id"; constexpr const char* SPAN_ID = "span_id"; constexpr const char* TRACE_FLAGS = "trace_flags"; +constexpr const char* SCALAR_INFO = "scalar_info"; +constexpr const char* MV_ONLY_ENABLED = "mv_only_enabled"; constexpr const char* MATERIALIZED_VIEW_SEARCH_INFO = "materialized_view_search_info"; constexpr const char* MATERIALIZED_VIEW_OPT_FIELDS_PATH = "opt_fields_path"; constexpr const char* MAX_EMPTY_RESULT_BUCKETS = "max_empty_result_buckets"; diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 2f933bfdd..4efd7fff6 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -30,6 +30,7 @@ #include "faiss/IndexHNSW.h" #include "faiss/IndexRefine.h" #include "faiss/impl/ScalarQuantizer.h" +#include "faiss/impl/mapped_io.h" #include "faiss/index_io.h" #include "index/hnsw/faiss_hnsw_config.h" #include "index/hnsw/hnsw.h" @@ -66,6 +67,11 @@ class BaseFaissIndexNode : public IndexNode { search_pool = ThreadPool::GetGlobalSearchThreadPool(); } + bool + IsAdditionalScalarSupported() const override { + return true; + } + // Status Train(const DataSetPtr dataset, std::shared_ptr cfg) override { @@ -159,21 +165,36 @@ is_faiss_fourcc_error(const char* what) { class BaseFaissRegularIndexNode : public BaseFaissIndexNode { public: BaseFaissRegularIndexNode(const int32_t& version, const Object& object) - : BaseFaissIndexNode(version, object), index{nullptr} { + : BaseFaissIndexNode(version, object), indexes(1, nullptr) { } Status Serialize(BinarySet& binset) const override { - if (index == nullptr) { + if (indexes.empty()) { return Status::empty_index; } + for (auto& index : indexes) { + if (index == nullptr) { + return Status::empty_index; + } + } try { MemoryIOWriter writer; - faiss::write_index(index.get(), &writer); + if (indexes.size() > 1) { + faiss::write_mv(&writer); + writeHeader(&writer); + for (auto& index : indexes) { + faiss::write_index(index.get(), &writer); + } - std::shared_ptr data(writer.data()); - binset.Append(Type(), data, writer.tellg()); + std::shared_ptr data(writer.data()); + binset.Append(Type(), data, writer.tellg()); + } else { + faiss::write_index(indexes[0].get(), &writer); + std::shared_ptr data(writer.data()); + binset.Append(Type(), data, writer.tellg()); + } } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; @@ -192,8 +213,21 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { MemoryIOReader reader(binary->data.get(), binary->size); try { - auto read_index = std::unique_ptr(faiss::read_index(&reader)); - index.reset(read_index.release()); + bool is_mv = faiss::read_is_mv(&reader); + if (is_mv) { + LOG_KNOWHERE_INFO_ << "start to load index by mv"; + uint32_t v = readHeader(&reader); + indexes.resize(v); + LOG_KNOWHERE_INFO_ << "read " << v << " mvs"; + for (auto i = 0; i < v; ++i) { + auto read_index = std::unique_ptr(faiss::read_index(&reader)); + indexes[i].reset(read_index.release()); + } + } else { + reader.reset(); + auto read_index = std::unique_ptr(faiss::read_index(&reader)); + indexes[0].reset(read_index.release()); + } } catch (const std::exception& e) { if (is_faiss_fourcc_error(e.what())) { LOG_KNOWHERE_WARNING_ << "faiss does not recognize the input index: " << e.what(); @@ -217,8 +251,32 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { } try { - auto read_index = std::unique_ptr(faiss::read_index(filename.data(), io_flags)); - index.reset(read_index.release()); + bool is_mv = faiss::read_is_mv(filename.data()); + if (is_mv) { + auto read_index = [&](faiss::IOReader* r) { + LOG_KNOWHERE_INFO_ << "start to load index by mv"; + read_is_mv(r); + uint32_t v = readHeader(r); + LOG_KNOWHERE_INFO_ << "read " << v << " mvs"; + indexes.resize(v); + for (auto i = 0; i < v; ++i) { + auto read_index = std::unique_ptr(faiss::read_index(r, io_flags)); + indexes[i].reset(read_index.release()); + } + }; + if ((io_flags & faiss::IO_FLAG_MMAP_IFC) == faiss::IO_FLAG_MMAP_IFC) { + // enable mmap-supporting IOReader + auto owner = std::make_shared(filename.data()); + faiss::MappedFileIOReader reader(owner); + read_index(&reader); + } else { + faiss::FileIOReader reader(filename.data()); + read_index(&reader); + } + } else { + auto read_index = std::unique_ptr(faiss::read_index(filename.data(), io_flags)); + indexes[0].reset(read_index.release()); + } } catch (const std::exception& e) { if (is_faiss_fourcc_error(e.what())) { LOG_KNOWHERE_WARNING_ << "faiss does not recognize the input index: " << e.what(); @@ -235,32 +293,44 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { // int64_t Dim() const override { - if (index == nullptr) { + if (indexes.empty()) { + return -1; + } + if (indexes[0] == nullptr) { return -1; } - return index->d; + return indexes[0]->d; } int64_t Count() const override { - if (index == nullptr) { + if (indexes.empty()) { return -1; } + int64_t count = 0; + for (auto& index : indexes) { + count += index->ntotal; + } // total number of indexed vectors - return index->ntotal; + return count; } int64_t Size() const override { - if (index == nullptr) { - return 0; + if (indexes.empty()) { + return -1; + } + if (indexes[0] == nullptr) { + return -1; } // a temporary yet expensive workaround faiss::cppcontrib::knowhere::CountSizeIOWriter writer; - faiss::write_index(index.get(), &writer); + for (auto& index : indexes) { + faiss::write_index(index.get(), &writer); + } // todo return writer.total_size; @@ -269,25 +339,64 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { protected: // it is std::shared_ptr, not std::unique_ptr, because it can be // shared with FaissHnswIterator - std::shared_ptr index; - - Status - AddInternal(const DataSetPtr dataset, const Config&) override { - if (this->index == nullptr) { - LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; - return Status::empty_index; + // std::shared_ptr index; + std::vector> indexes; + // each index's out ids(label), can be shared with FaissHnswIterator + std::vector>> labels; + + // index rows, help to locate index id by offset + std::vector index_rows_sum; + // label to locate internal offset + std::vector label_to_internal_offset; + + int + getIndexToSearchByScalarInfo(const FaissHnswConfig& config, const BitsetView& bitset) const { + if (indexes.size() == 1) { + return 0; } - - auto data = dataset->GetTensor(); - auto rows = dataset->GetRows(); - try { - this->index->add(rows, reinterpret_cast(data)); - } catch (const std::exception& e) { - LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); - return Status::faiss_inner_error; + if (bitset.empty()) { + LOG_KNOWHERE_WARNING_ << "partition key value not correctly set"; + return -1; + } + // all data is filtered, just pick the first one + // this will not happen combined with milvus, which will not call knowhere and just return + if (bitset.count() == bitset.size()) { + return 0; } + size_t first_valid_index = bitset.get_first_valid_index(); + auto it = std::lower_bound(index_rows_sum.begin(), index_rows_sum.end(), + label_to_internal_offset[first_valid_index] + 1); + return std::distance(index_rows_sum.begin(), it) - 1; + } - return Status::success; + void + writeHeader(faiss::IOWriter* f) const { + uint32_t version = 0; + faiss::write_value(version, f); + uint32_t size = indexes.size(); + faiss::write_value(size, f); + uint32_t cluster_size = labels.size(); + faiss::write_value(cluster_size, f); + for (const auto& label : labels) { + faiss::write_vector(*label, f); + } + faiss::write_vector(index_rows_sum, f); + faiss::write_vector(label_to_internal_offset, f); + } + + uint32_t + readHeader(faiss::IOReader* f) { + [[maybe_unused]] uint32_t version = faiss::read_value(f); + uint32_t size = faiss::read_value(f); + uint32_t cluster_size = faiss::read_value(f); + labels.resize(cluster_size); + for (auto j = 0; j < cluster_size; ++j) { + labels[j] = std::make_shared>(); + faiss::read_vector(*labels[j], f); + } + faiss::read_vector(index_rows_sum, f); + faiss::read_vector(label_to_internal_offset, f); + return size; } }; @@ -319,6 +428,48 @@ static constexpr DataFormatEnum datatype_v = DataType2EnumHelper::value; namespace { +bool +convert_rows_to_fp32(const void* const __restrict src_in, float* const __restrict dst, + const DataFormatEnum src_data_format, const uint32_t* offsets, const size_t nrows, + const size_t dim) { + if (src_data_format == DataFormatEnum::fp16) { + const knowhere::fp16* const src = reinterpret_cast(src_in); + for (size_t i = 0; i < nrows; i++) { + for (size_t j = 0; j < dim; ++j) { + dst[i * dim + j] = (float)(src[offsets[i] * dim + j]); + } + } + return true; + } else if (src_data_format == DataFormatEnum::bf16) { + const knowhere::bf16* const src = reinterpret_cast(src_in); + for (size_t i = 0; i < nrows; i++) { + for (size_t j = 0; j < dim; ++j) { + dst[i * dim + j] = (float)(src[offsets[i] * dim + j]); + } + } + return true; + } else if (src_data_format == DataFormatEnum::fp32) { + const knowhere::fp32* const src = reinterpret_cast(src_in); + for (size_t i = 0; i < nrows; i++) { + for (size_t j = 0; j < dim; ++j) { + dst[i * dim + j] = (float)(src[offsets[i] * dim + j]); + } + } + return true; + } else if (src_data_format == DataFormatEnum::int8) { + const knowhere::int8* const src = reinterpret_cast(src_in); + for (size_t i = 0; i < nrows; i++) { + for (size_t j = 0; j < dim; ++j) { + dst[i * dim + j] = (float)(src[offsets[i] * dim + j]); + } + } + return true; + } else { + // unknown + return false; + } +} + bool convert_rows_to_fp32(const void* const __restrict src_in, float* const __restrict dst, const DataFormatEnum src_data_format, const size_t start_row, const size_t nrows, @@ -437,6 +588,39 @@ add_to_index(faiss::Index* const __restrict index, const DataSetPtr& dataset, co return Status::success; } +Status +add_partial_dataset_to_index(faiss::Index* const __restrict index, const DataSetPtr& dataset, + const DataFormatEnum data_format, const std::vector& ids) { + const auto* data = dataset->GetTensor(); + + if (ids.size() > dataset->GetRows()) { + LOG_KNOWHERE_ERROR_ << "partial ids size larger than whole dataset size"; + return Status::invalid_args; + } + const int64_t rows = ids.size(); + const auto dim = dataset->GetDim(); + + // convert data into float in pieces and add to the index + constexpr int64_t n_tmp_rows = 4096; + std::unique_ptr tmp = std::make_unique(n_tmp_rows * dim); + + for (int64_t irow = 0; irow < rows; irow += n_tmp_rows) { + const int64_t start_row = irow; + const int64_t end_row = std::min(rows, start_row + n_tmp_rows); + const int64_t count_rows = end_row - start_row; + + if (!convert_rows_to_fp32(data, tmp.get(), data_format, ids.data() + start_row, count_rows, dim)) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + + // add + index->add(count_rows, tmp.get()); + } + + return Status::success; +} + // IndexFlat and IndexFlatCosine contain raw fp32 data // IndexScalarQuantizer and IndexScalarQuantizerCosine may contain rar bf16 and fp16 data // @@ -484,6 +668,44 @@ storage_distance_computer(const faiss::Index* storage) { } } +// there are chances that each partition split by scalar distribution is too small that we could not evn train pq on it +// bcz 256 points are needed for a 8-bit pq training in faiss +// combine some small partitions to get a bigger one +std::vector> +combine_partitions(const std::vector>& scalar_info, const int64_t base_rows) { + auto scalar_size = scalar_info.size(); + std::vector indices(scalar_size); + std::iota(indices.begin(), indices.end(), 0); + std::vector sizes; + sizes.reserve(scalar_size); + for (auto& id_list : scalar_info) { + sizes.emplace_back(id_list.size()); + } + std::sort(indices.begin(), indices.end(), [&sizes](size_t i1, size_t i2) { return sizes[i1] < sizes[i2]; }); + std::vector> res; + std::vector cur; + int64_t cur_size = 0; + for (auto i = 0; i < sizes.size(); ++i) { + cur_size += sizes[i]; + cur.push_back(indices[i]); + if (cur_size >= base_rows) { + res.push_back(cur); + cur.clear(); + cur_size = 0; + } + } + // tail + if (!cur.empty()) { + if (res.empty()) { + res.push_back(cur); + return res; + } else { + res[res.size() - 1].insert(res[res.size() - 1].end(), cur.begin(), cur.end()); + } + } + return res; +} + } // namespace // Contains an iterator state @@ -532,10 +754,11 @@ struct FaissHnswIteratorWorkspace { // Contains an iterator logic class FaissHnswIterator : public IndexIterator { public: - FaissHnswIterator(const std::shared_ptr& index_in, std::unique_ptr&& query_in, + FaissHnswIterator(const std::shared_ptr& index_in, + const std::shared_ptr>& labels_in, std::unique_ptr&& query_in, const BitsetView& bitset_in, const int32_t ef_in, bool larger_is_closer, const float refine_ratio = 0.5f, bool use_knowhere_search_pool = true) - : IndexIterator(larger_is_closer, use_knowhere_search_pool, refine_ratio), index{index_in} { + : IndexIterator(larger_is_closer, use_knowhere_search_pool, refine_ratio), index{index_in}, labels{labels_in} { workspace.accumulated_alpha = (bitset_in.count() >= (index->ntotal * HnswSearchThresholds::kHnswSearchKnnBFFilterThreshold)) ? std::numeric_limits::max() @@ -740,6 +963,12 @@ class FaissHnswIterator : public IndexIterator { } } + if (labels != nullptr) { + for (auto& p : workspace.dists) { + p.id = labels->operator[](p.id); + } + } + // pass back to the handler batch_handler(workspace.dists); @@ -756,7 +985,7 @@ class FaissHnswIterator : public IndexIterator { next_batch(batch_handler, sel); } else { using filter_type = knowhere::BitsetViewIDSelector; - filter_type sel(workspace.bitset); + filter_type sel(workspace.bitset, 0, labels == nullptr ? nullptr : labels->data()); next_batch(batch_handler, sel); } @@ -770,6 +999,7 @@ class FaissHnswIterator : public IndexIterator { private: std::shared_ptr index; + std::shared_ptr> labels; FaissHnswIteratorWorkspace workspace; }; @@ -783,31 +1013,42 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { bool HasRawData(const std::string& metric_type) const override { - if (this->index == nullptr) { + if (indexes.empty()) { return false; } // check whether there is an index to reconstruct a raw data from - return (GetIndexToReconstructRawDataFrom() != nullptr); + // only check one is enough + return (GetIndexToReconstructRawDataFrom(0) != nullptr); } expected GetVectorByIds(const DataSetPtr dataset) const override { - if (index == nullptr) { + if (indexes.empty()) { return expected::Err(Status::empty_index, "index not loaded"); } - if (!index->is_trained) { - return expected::Err(Status::index_not_trained, "index not trained"); + for (auto& index : indexes) { + if (index == nullptr) { + return expected::Err(Status::empty_index, "index not loaded"); + } + if (!index->is_trained) { + return expected::Err(Status::index_not_trained, "index not trained"); + } } // an index that is used for reconstruction - const faiss::Index* index_to_reconstruct_from = GetIndexToReconstructRawDataFrom(); - - // check whether raw data is available - if (index_to_reconstruct_from == nullptr) { - return expected::Err( - Status::invalid_index_error, - "The index does not contain a raw data, cannot proceed with GetVectorByIds"); + // const faiss::Index* index_to_reconstruct_from = GetIndexToReconstructRawDataFrom(); + std::vector indexes_to_reconstruct_from(indexes.size()); + for (auto i = 0; i < indexes.size(); ++i) { + const faiss::Index* index_to_reconstruct_from = GetIndexToReconstructRawDataFrom(i); + + // check whether raw data is available + if (index_to_reconstruct_from == nullptr) { + return expected::Err( + Status::invalid_index_error, + "The index does not contain a raw data, cannot proceed with GetVectorByIds"); + } + indexes_to_reconstruct_from[i] = index_to_reconstruct_from; } // perform reconstruction @@ -815,6 +1056,18 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { auto rows = dataset->GetRows(); auto ids = dataset->GetIds(); + auto get_vector = [&](int64_t id, float* result) { + if (indexes.size() == 1) { + indexes_to_reconstruct_from[0]->reconstruct(id, result); + } else { + auto it = + std::lower_bound(index_rows_sum.begin(), index_rows_sum.end(), label_to_internal_offset[id] + 1); + auto index_id = std::distance(index_rows_sum.begin(), it) - 1; + indexes_to_reconstruct_from[index_id]->reconstruct( + label_to_internal_offset[id] - index_rows_sum[index_id], result); + } + }; + try { if (data_format == DataFormatEnum::fp32) { // perform a direct reconstruction for fp32 data @@ -822,7 +1075,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { for (int64_t i = 0; i < rows; i++) { const int64_t id = ids[i]; assert(id >= 0 && id < index->ntotal); - index_to_reconstruct_from->reconstruct(id, data.get() + i * dim); + get_vector(id, data.get() + i * dim); } return GenResultDataSet(rows, dim, std::move(data)); } else if (data_format == DataFormatEnum::fp16) { @@ -832,8 +1085,8 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { auto tmp = std::make_unique(dim); for (int64_t i = 0; i < rows; i++) { const int64_t id = ids[i]; - assert(id >= 0 && id < index->ntotal); - index_to_reconstruct_from->reconstruct(id, tmp.get()); + assert(id >= 0 && id < Count()); + get_vector(id, tmp.get()); if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) { return expected::Err(Status::invalid_args, "Unsupported data format"); } @@ -846,8 +1099,8 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { auto tmp = std::make_unique(dim); for (int64_t i = 0; i < rows; i++) { const int64_t id = ids[i]; - assert(id >= 0 && id < index->ntotal); - index_to_reconstruct_from->reconstruct(id, tmp.get()); + assert(id >= 0 && id < Count()); + get_vector(id, tmp.get()); if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) { return expected::Err(Status::invalid_args, "Unsupported data format"); } @@ -860,8 +1113,8 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { auto tmp = std::make_unique(dim); for (int64_t i = 0; i < rows; i++) { const int64_t id = ids[i]; - assert(id >= 0 && id < index->ntotal); - index_to_reconstruct_from->reconstruct(id, tmp.get()); + assert(id >= 0 && id < Count()); + get_vector(id, tmp.get()); if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) { return expected::Err(Status::invalid_args, "Unsupported data format"); } @@ -878,11 +1131,16 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { expected Search(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override { - if (this->index == nullptr) { + if (this->indexes.empty()) { return expected::Err(Status::empty_index, "index not loaded"); } - if (!this->index->is_trained) { - return expected::Err(Status::index_not_trained, "index not trained"); + for (auto& index : indexes) { + if (index == nullptr) { + return expected::Err(Status::empty_index, "index not loaded"); + } + if (!index->is_trained) { + return expected::Err(Status::index_not_trained, "index not trained"); + } } const auto dim = dataset->GetDim(); @@ -891,7 +1149,10 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { const auto hnsw_cfg = static_cast(*cfg); const auto k = hnsw_cfg.k.value(); - + auto index_id = getIndexToSearchByScalarInfo(hnsw_cfg, bitset); + if (index_id < 0) { + return expected::Err(Status::invalid_args, "partition key value not correctly set"); + } feder::hnsw::FederResultUniq feder_result; if (hnsw_cfg.trace_visit.value()) { if (rows != 1) { @@ -901,7 +1162,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { } // check for brute-force search - auto whether_bf_search = WhetherPerformBruteForceSearch(index.get(), hnsw_cfg, bitset); + auto whether_bf_search = WhetherPerformBruteForceSearch(indexes[index_id].get(), hnsw_cfg, bitset); if (!whether_bf_search.has_value()) { return expected::Err(Status::invalid_args, "k parameter is missing"); @@ -912,7 +1173,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { // set up an index wrapper auto [index_wrapper, is_refined] = create_conditional_hnsw_wrapper( - index.get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine); + indexes[index_id].get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine); if (index_wrapper == nullptr) { return expected::Err(Status::invalid_args, "an input index seems to be unrelated to HNSW"); @@ -934,7 +1195,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { hnsw_search_params.kAlpha = bitset.filter_ratio() * 0.7f; // set up a selector - BitsetViewIDSelector bw_idselector(bitset); + BitsetViewIDSelector bw_idselector(bitset, 0, labels.empty() ? nullptr : labels[index_id].get()->data()); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; hnsw_search_params.sel = id_selector; @@ -948,39 +1209,45 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { futs.reserve(rows); for (int64_t i = 0; i < rows; ++i) { - futs.emplace_back( - search_pool->push([&, idx = i, is_refined = is_refined, index_wrapper_ptr = index_wrapper_ptr] { - // 1 thread per element - ThreadPool::ScopedSearchOmpSetter setter(1); - - // set up a query - const float* cur_query = nullptr; - - std::vector cur_query_tmp(dim); - if (data_format == DataFormatEnum::fp32) { - cur_query = (const float*)data + idx * dim; - } else { - convert_rows_to_fp32(data, cur_query_tmp.data(), data_format, idx, 1, dim); - cur_query = cur_query_tmp.data(); - } + futs.emplace_back(search_pool->push([&, idx = i, is_refined = is_refined, + index_wrapper_ptr = index_wrapper_ptr] { + // 1 thread per element + ThreadPool::ScopedSearchOmpSetter setter(1); + + // set up a query + const float* cur_query = nullptr; + + std::vector cur_query_tmp(dim); + if (data_format == DataFormatEnum::fp32) { + cur_query = (const float*)data + idx * dim; + } else { + convert_rows_to_fp32(data, cur_query_tmp.data(), data_format, idx, 1, dim); + cur_query = cur_query_tmp.data(); + } + + // set up local results + faiss::idx_t* const __restrict local_ids = ids.get() + k * idx; + float* const __restrict local_distances = distances.get() + k * idx; - // set up local results - faiss::idx_t* const __restrict local_ids = ids.get() + k * idx; - float* const __restrict local_distances = distances.get() + k * idx; - - // perform the search - if (is_refined) { - faiss::IndexRefineSearchParameters refine_params; - refine_params.k_factor = hnsw_cfg.refine_k.value_or(1); - // a refine procedure itself does not need to care about filtering - refine_params.sel = nullptr; - refine_params.base_index_params = &hnsw_search_params; - - index_wrapper_ptr->search(1, cur_query, k, local_distances, local_ids, &refine_params); - } else { - index_wrapper_ptr->search(1, cur_query, k, local_distances, local_ids, &hnsw_search_params); + // perform the search + if (is_refined) { + faiss::IndexRefineSearchParameters refine_params; + refine_params.k_factor = hnsw_cfg.refine_k.value_or(1); + // a refine procedure itself does not need to care about filtering + refine_params.sel = nullptr; + refine_params.base_index_params = &hnsw_search_params; + + index_wrapper_ptr->search(1, cur_query, k, local_distances, local_ids, &refine_params); + } else { + index_wrapper_ptr->search(1, cur_query, k, local_distances, local_ids, &hnsw_search_params); + } + + if (!labels.empty()) { + for (auto i = 0; i < k; ++i) { + local_ids[i] = local_ids[i] < 0 ? local_ids[i] : labels[index_id]->operator[](local_ids[i]); } - })); + } + })); } // wait for the completion @@ -1006,11 +1273,16 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { expected RangeSearch(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override { - if (this->index == nullptr) { + if (this->indexes.empty()) { return expected::Err(Status::empty_index, "index not loaded"); } - if (!this->index->is_trained) { - return expected::Err(Status::index_not_trained, "index not trained"); + for (auto& index : indexes) { + if (index == nullptr) { + return expected::Err(Status::empty_index, "index not loaded"); + } + if (!index->is_trained) { + return expected::Err(Status::index_not_trained, "index not trained"); + } } const auto dim = dataset->GetDim(); @@ -1018,8 +1290,12 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { const auto* data = dataset->GetTensor(); const auto hnsw_cfg = static_cast(*cfg); + auto index_id = getIndexToSearchByScalarInfo(hnsw_cfg, bitset); + if (index_id < 0) { + return expected::Err(Status::invalid_args, "partition key value not correctly set"); + } - const bool is_similarity_metric = faiss::is_similarity_metric(index->metric_type); + const bool is_similarity_metric = faiss::is_similarity_metric(indexes[index_id]->metric_type); const float radius = hnsw_cfg.radius.value(); const float range_filter = hnsw_cfg.range_filter.value(); @@ -1033,7 +1309,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { } // check for brute-force search - auto whether_bf_search = WhetherPerformBruteForceRangeSearch(index.get(), hnsw_cfg, bitset); + auto whether_bf_search = WhetherPerformBruteForceRangeSearch(indexes[index_id].get(), hnsw_cfg, bitset); if (!whether_bf_search.has_value()) { return expected::Err(Status::invalid_args, "ef parameter is missing"); @@ -1044,7 +1320,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { // set up an index wrapper auto [index_wrapper, is_refined] = create_conditional_hnsw_wrapper( - index.get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine); + indexes[index_id].get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine); if (index_wrapper == nullptr) { return expected::Err(Status::invalid_args, "an input index seems to be unrelated to HNSW"); @@ -1067,7 +1343,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { hnsw_search_params.kAlpha = bitset.filter_ratio() * 0.7f; // set up a selector - BitsetViewIDSelector bw_idselector(bitset); + BitsetViewIDSelector bw_idselector(bitset, 0, labels.empty() ? nullptr : labels[index_id].get()->data()); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; hnsw_search_params.sel = id_selector; @@ -1122,9 +1398,16 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { result_dist_array[idx].resize(elem_cnt); result_id_array[idx].resize(elem_cnt); - for (size_t j = 0; j < elem_cnt; j++) { - result_dist_array[idx][j] = res.distances[j]; - result_id_array[idx][j] = res.labels[j]; + if (labels.empty()) { + for (size_t j = 0; j < elem_cnt; j++) { + result_dist_array[idx][j] = res.distances[j]; + result_id_array[idx][j] = res.labels[j]; + } + } else { + for (size_t j = 0; j < elem_cnt; j++) { + result_dist_array[idx][j] = res.distances[j]; + result_id_array[idx][j] = labels[index_id]->operator[](res.labels[j]); + } } if (hnsw_cfg.range_filter.value() != defaultRangeFilter) { @@ -1147,22 +1430,51 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { protected: DataFormatEnum data_format; + std::vector> tmp_combined_scalar_ids; + Status AddInternal(const DataSetPtr dataset, const Config&) override { - if (index == nullptr) { + if (indexes.empty()) { LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; return Status::empty_index; } + for (auto& index : indexes) { + if (index == nullptr) { + LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; + return Status::empty_index; + } + } auto rows = dataset->GetRows(); - try { - LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; - auto status = add_to_index(index.get(), dataset, data_format); - if (status != Status::success) { + const std::vector>& scalar_info = + dataset->Get>>(meta::SCALAR_INFO); + if (scalar_info.empty()) { + try { + LOG_KNOWHERE_INFO_ << "Adding " << rows << " rows to HNSW Index"; + + auto status = add_to_index(indexes[0].get(), dataset, data_format); return status; + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); + return Status::faiss_inner_error; } + } + LOG_KNOWHERE_INFO_ << "add to hnsw index with scalar info"; + try { + for (auto i = 0; i < tmp_combined_scalar_ids.size(); ++i) { + for (auto j = 0; j < tmp_combined_scalar_ids[i].size(); ++j) { + auto id = tmp_combined_scalar_ids[i][j]; + // hnsw + LOG_KNOWHERE_INFO_ << "Adding " << scalar_info[id].size() << " to HNSW Index"; + + auto status = add_partial_dataset_to_index(indexes[i].get(), dataset, data_format, scalar_info[id]); + if (status != Status::success) { + return status; + } + } + } } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; @@ -1172,8 +1484,11 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { } const faiss::Index* - GetIndexToReconstructRawDataFrom() const { - if (index == nullptr) { + GetIndexToReconstructRawDataFrom(int i) const { + if (indexes.size() <= i) { + return nullptr; + } + if (indexes[i] == nullptr) { return nullptr; } @@ -1181,12 +1496,12 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { const faiss::Index* index_to_reconstruct_from = nullptr; // check whether our index uses refine - auto index_refine = dynamic_cast(index.get()); + auto index_refine = dynamic_cast(indexes[i].get()); if (index_refine == nullptr) { // non-refined index // cast as IndexHNSW - auto index_hnsw = dynamic_cast(index.get()); + auto index_hnsw = dynamic_cast(indexes[i].get()); if (index_hnsw == nullptr) { // this is unexpected, we expect IndexHNSW return nullptr; @@ -1220,10 +1535,16 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { expected> AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, bool use_knowhere_search_pool) const override { - if (index == nullptr) { - LOG_KNOWHERE_WARNING_ << "creating iterator on empty index"; + if (this->indexes.empty()) { + LOG_KNOWHERE_ERROR_ << "creating iterator on empty index"; return expected>::Err(Status::empty_index, "index not loaded"); } + for (auto& index : indexes) { + if (index == nullptr) { + LOG_KNOWHERE_ERROR_ << "creating iterator on empty index"; + return expected>::Err(Status::empty_index, "index not loaded"); + } + } if (data_format != DataFormatEnum::fp32 && data_format != DataFormatEnum::fp16 && data_format != DataFormatEnum::bf16) { @@ -1239,6 +1560,11 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { auto vec = std::vector(n_queries, nullptr); const FaissHnswConfig& hnsw_cfg = static_cast(*cfg); + int index_id = getIndexToSearchByScalarInfo(hnsw_cfg, bitset); + if (index_id < 0) { + return expected>::Err(Status::invalid_args, + "partition key value not correctly set"); + } const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), knowhere::metric::COSINE); const bool larger_is_closer = (IsMetricType(hnsw_cfg.metric_type.value(), knowhere::metric::IP) || is_cosine); @@ -1263,7 +1589,8 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { throw; } - const bool should_use_refine = (dynamic_cast(index.get()) != nullptr); + const bool should_use_refine = + (dynamic_cast(indexes[index_id].get()) != nullptr); const float iterator_refine_ratio = should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0; @@ -1271,8 +1598,9 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { // create an iterator and initialize it // refine is not needed for flat // hnsw_cfg.iterator_refine_ratio.value_or(0.5f) - auto it = std::make_shared(index, std::move(cur_query), bitset, ef, larger_is_closer, - iterator_refine_ratio, use_knowhere_search_pool); + auto it = std::make_shared( + indexes[index_id], labels.empty() ? nullptr : labels[index_id], std::move(cur_query), bitset, ef, + larger_is_closer, iterator_refine_ratio, use_knowhere_search_pool); // store vec[i] = it; } @@ -1329,54 +1657,100 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode { const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE); std::unique_ptr hnsw_index; - if (is_cosine) { - if (data_format == DataFormatEnum::fp32) { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); - } else if (data_format == DataFormatEnum::fp16) { - hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_fp16, - hnsw_cfg.M.value()); - } else if (data_format == DataFormatEnum::bf16) { - hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_bf16, - hnsw_cfg.M.value()); - } else if (data_format == DataFormatEnum::int8) { - hnsw_index = std::make_unique( - dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value()); + auto create_index = [&](const float* data, const int i, const int64_t rows) { + if (is_cosine) { + if (data_format == DataFormatEnum::fp32) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); + } else if (data_format == DataFormatEnum::fp16) { + hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_fp16, + hnsw_cfg.M.value()); + } else if (data_format == DataFormatEnum::bf16) { + hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_bf16, + hnsw_cfg.M.value()); + } else if (data_format == DataFormatEnum::int8) { + hnsw_index = std::make_unique( + dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value()); + } else { + LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); + return Status::invalid_metric_type; + } } else { - LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); - return Status::invalid_metric_type; + if (data_format == DataFormatEnum::fp32) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); + } else if (data_format == DataFormatEnum::fp16) { + hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_fp16, + hnsw_cfg.M.value(), metric.value()); + } else if (data_format == DataFormatEnum::bf16) { + hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_bf16, + hnsw_cfg.M.value(), metric.value()); + } else if (data_format == DataFormatEnum::int8) { + hnsw_index = std::make_unique( + dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value(), metric.value()); + } else { + LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); + return Status::invalid_metric_type; + } } - } else { - if (data_format == DataFormatEnum::fp32) { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); - } else if (data_format == DataFormatEnum::fp16) { - hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_fp16, - hnsw_cfg.M.value(), metric.value()); - } else if (data_format == DataFormatEnum::bf16) { - hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_bf16, - hnsw_cfg.M.value(), metric.value()); - } else if (data_format == DataFormatEnum::int8) { - hnsw_index = std::make_unique(dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, - hnsw_cfg.M.value(), metric.value()); - } else { - LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value(); - return Status::invalid_metric_type; + hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); + // train + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + // this function does nothing for the given parameters and indices. + // as a result, I'm just keeping it to have is_trained set to true. + // WARNING: this may cause problems if ->train() performs some action + // based on the data in the future. Otherwise, data needs to be + // converted into float*. + hnsw_index->train(rows, data); + + // done + indexes[i] = std::move(hnsw_index); + return Status::success; + }; + + const std::vector>& scalar_info = + dataset->Get>>(meta::SCALAR_INFO); + tmp_combined_scalar_ids = + scalar_info.size() > 1 ? combine_partitions(scalar_info, 128) : std::vector>(); + // no scalar info or just one partition(after possible combination), build index on whole data + if (scalar_info.size() <= 1 || tmp_combined_scalar_ids.size() <= 1) { + return create_index((const float*)(data), 0, rows); + } + + LOG_KNOWHERE_INFO_ << "train hnsw index with scalar info"; + + label_to_internal_offset.resize(rows); + index_rows_sum.resize(tmp_combined_scalar_ids.size() + 1); + labels.resize(tmp_combined_scalar_ids.size()); + indexes.resize(tmp_combined_scalar_ids.size()); + for (auto i = 0; i < tmp_combined_scalar_ids.size(); ++i) { + size_t partition_size = 0; + for (int j : tmp_combined_scalar_ids[i]) { + partition_size += scalar_info[j].size(); + } + std::unique_ptr tmp_data = std::make_unique(dim * partition_size); + labels[i] = std::make_shared>(partition_size); + index_rows_sum[i + 1] = index_rows_sum[i] + partition_size; + size_t cur_size = 0; + + for (size_t j = 0; j < tmp_combined_scalar_ids[i].size(); ++j) { + auto scalar_id = tmp_combined_scalar_ids[i][j]; + if (!convert_rows_to_fp32(data, tmp_data.get() + dim * cur_size, data_format, + scalar_info[scalar_id].data(), scalar_info[scalar_id].size(), dim)) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + for (size_t m = 0; m < scalar_info[scalar_id].size(); ++m) { + labels[i]->operator[](cur_size + m) = scalar_info[scalar_id][m]; + label_to_internal_offset[scalar_info[scalar_id][m]] = index_rows_sum[i] + cur_size + m; + } + cur_size += scalar_info[scalar_id].size(); } - } - - hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); - - // train - LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - // this function does nothing for the given parameters and indices. - // as a result, I'm just keeping it to have is_trained set to true. - // WARNING: this may cause problems if ->train() performs some action - // based on the data in the future. Otherwise, data needs to be - // converted into float*. - hnsw_index->train(rows, (const float*)data); + Status s = create_index((const float*)(tmp_data.get()), i, partition_size); + if (s != Status::success) { + return s; + } + } - // done - index = std::move(hnsw_index); return Status::success; } }; @@ -1409,6 +1783,15 @@ class HNSWIndexNodeWithFallback : public IndexNode { } } + bool + IsAdditionalScalarSupported() const override { + if (use_base_index) { + return base_index->IsAdditionalScalarSupported(); + } else { + return fallback_search_index->IsAdditionalScalarSupported(); + } + } + Status Train(const DataSetPtr dataset, std::shared_ptr cfg) override { if (use_base_index) { @@ -1825,47 +2208,96 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode { // create an index const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), metric::COSINE); - std::unique_ptr hnsw_index; - if (is_cosine) { - hnsw_index = std::make_unique(dim, sq_type.value(), hnsw_cfg.M.value()); - } else { - hnsw_index = std::make_unique(dim, sq_type.value(), hnsw_cfg.M.value(), metric.value()); - } - - hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); - // should refine be used? std::unique_ptr final_index; - if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { - // yes - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); - if (!final_index_cnd.has_value()) { - return Status::invalid_args; + + auto create_index = [&](const float* data, const int i, const int64_t rows) { + std::unique_ptr hnsw_index; + if (is_cosine) { + hnsw_index = std::make_unique(dim, sq_type.value(), hnsw_cfg.M.value()); + } else { + hnsw_index = + std::make_unique(dim, sq_type.value(), hnsw_cfg.M.value(), metric.value()); } - // assign - final_index = std::move(final_index_cnd.value()); - } else { - // no refine + hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); - // assign - final_index = std::move(hnsw_index); - } + if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { + // yes + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); + if (!final_index_cnd.has_value()) { + return Status::invalid_args; + } - // we have to convert the data to float, unfortunately, which costs extra RAM - auto float_ds_ptr = convert_ds_to_float(dataset, data_format); - if (float_ds_ptr == nullptr) { - LOG_KNOWHERE_ERROR_ << "Unsupported data format"; - return Status::invalid_args; - } + // assign + final_index = std::move(final_index_cnd.value()); + } else { + // no refine - // train - LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + // assign + final_index = std::move(hnsw_index); + } - final_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); + // train + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + + final_index->train(rows, data); + + // done + indexes[i] = std::move(final_index); + return Status::success; + }; + + const std::vector>& scalar_info = + dataset->Get>>(meta::SCALAR_INFO); + tmp_combined_scalar_ids = + scalar_info.size() > 1 ? combine_partitions(scalar_info, 128) : std::vector>(); + // no scalar info or just one partition(after possible combination), build index on whole data + if (scalar_info.size() <= 1 || tmp_combined_scalar_ids.size() <= 1) { + // we have to convert the data to float, unfortunately, which costs extra RAM + auto float_ds_ptr = convert_ds_to_float(dataset, data_format); + if (float_ds_ptr == nullptr) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + return create_index(reinterpret_cast(float_ds_ptr->GetTensor()), 0, rows); + } + LOG_KNOWHERE_INFO_ << "train hnsw index with scalar info"; + + label_to_internal_offset.resize(rows); + index_rows_sum.resize(tmp_combined_scalar_ids.size() + 1); + labels.resize(tmp_combined_scalar_ids.size()); + indexes.resize(tmp_combined_scalar_ids.size()); + const void* data = dataset->GetTensor(); + for (auto i = 0; i < tmp_combined_scalar_ids.size(); ++i) { + size_t partition_size = 0; + for (int j : tmp_combined_scalar_ids[i]) { + partition_size += scalar_info[j].size(); + } + std::unique_ptr tmp_data = std::make_unique(dim * partition_size); + labels[i] = std::make_shared>(partition_size); + index_rows_sum[i + 1] = index_rows_sum[i] + partition_size; + size_t cur_size = 0; + + for (size_t j = 0; j < tmp_combined_scalar_ids[i].size(); ++j) { + auto scalar_id = tmp_combined_scalar_ids[i][j]; + if (!convert_rows_to_fp32(data, tmp_data.get() + dim * cur_size, data_format, + scalar_info[scalar_id].data(), scalar_info[scalar_id].size(), dim)) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + for (size_t m = 0; m < scalar_info[scalar_id].size(); ++m) { + labels[i]->operator[](cur_size + m) = scalar_info[scalar_id][m]; + label_to_internal_offset[scalar_info[scalar_id][m]] = index_rows_sum[i] + cur_size + m; + } + cur_size += scalar_info[scalar_id].size(); + } - // done - index = std::move(final_index); + Status s = create_index((const float*)(tmp_data.get()), i, partition_size); + if (s != Status::success) { + return s; + } + } return Status::success; } @@ -1914,7 +2346,7 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { } protected: - std::unique_ptr tmp_index_pq; + std::vector> tmp_index_pq; Status TrainInternal(const DataSetPtr dataset, const Config& cfg) override { @@ -1926,6 +2358,12 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { // config auto hnsw_cfg = static_cast(cfg); + if (rows < (1 << hnsw_cfg.nbits.value())) { + LOG_KNOWHERE_ERROR_ << rows << " rows not enough, needs at least " << (1 << hnsw_cfg.nbits.value()) + << " rows"; + return Status::faiss_inner_error; + } + auto metric = Str2FaissMetricType(hnsw_cfg.metric_type.value()); if (!metric.has_value()) { LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << hnsw_cfg.metric_type.value(); @@ -1937,105 +2375,149 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { // HNSW + PQ index yields BAD recall somewhy. // Let's build HNSW+FLAT index, then replace FLAT with PQ + auto create_index = [&](const float* data, const int i, const int64_t rows) { + std::unique_ptr hnsw_index; + if (is_cosine) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); + } else { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); + } - std::unique_ptr hnsw_index; - if (is_cosine) { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); - } else { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); - } + hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); - hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); + // pq + std::unique_ptr pq_index; + if (is_cosine) { + pq_index = std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value()); + } else { + pq_index = + std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value(), metric.value()); + } - // pq - std::unique_ptr pq_index; - if (is_cosine) { - pq_index = std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value()); - } else { - pq_index = - std::make_unique(dim, hnsw_cfg.m.value(), hnsw_cfg.nbits.value(), metric.value()); - } + // should refine be used? + std::unique_ptr final_index; + if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { + // yes + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); + if (!final_index_cnd.has_value()) { + return Status::invalid_args; + } - // should refine be used? - std::unique_ptr final_index; - if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { - // yes - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); - if (!final_index_cnd.has_value()) { - return Status::invalid_args; + // assign + final_index = std::move(final_index_cnd.value()); + } else { + // no refine + + // assign + final_index = std::move(hnsw_index); } - // assign - final_index = std::move(final_index_cnd.value()); - } else { - // no refine + // train hnswflat + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - // assign - final_index = std::move(hnsw_index); - } + final_index->train(rows, data); - // we have to convert the data to float, unfortunately, which costs extra RAM - auto float_ds_ptr = convert_ds_to_float(dataset, data_format); - if (float_ds_ptr == nullptr) { - LOG_KNOWHERE_ERROR_ << "Unsupported data format"; - return Status::invalid_args; + // train pq + LOG_KNOWHERE_INFO_ << "Training PQ Index"; + + pq_index->train(rows, data); + pq_index->pq.compute_sdc_table(); + + // done + indexes[i] = std::move(final_index); + tmp_index_pq[i] = std::move(pq_index); + return Status::success; + }; + + const std::vector>& scalar_info = + dataset->Get>>(meta::SCALAR_INFO); + tmp_combined_scalar_ids = scalar_info.size() > 1 + ? combine_partitions(scalar_info, (1 << hnsw_cfg.nbits.value())) + : std::vector>(); + + // no scalar info or just one partition(after possible combination), build index on whole data + if (scalar_info.size() <= 1 || tmp_combined_scalar_ids.size() <= 1) { + tmp_index_pq.resize(1); + // we have to convert the data to float, unfortunately, which costs extra RAM + auto float_ds_ptr = convert_ds_to_float(dataset, data_format); + if (float_ds_ptr == nullptr) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + return create_index((const float*)(float_ds_ptr->GetTensor()), 0, rows); } - // train hnswflat - LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + LOG_KNOWHERE_INFO_ << "train hnsw index with scalar info"; - final_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); + label_to_internal_offset.resize(rows); + index_rows_sum.resize(tmp_combined_scalar_ids.size() + 1); - // train pq - LOG_KNOWHERE_INFO_ << "Training PQ Index"; + labels.resize(tmp_combined_scalar_ids.size()); + tmp_index_pq.resize(tmp_combined_scalar_ids.size()); - pq_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); - pq_index->pq.compute_sdc_table(); + const void* data = dataset->GetTensor(); + indexes.resize(tmp_combined_scalar_ids.size()); + for (auto i = 0; i < tmp_combined_scalar_ids.size(); ++i) { + size_t partition_size = 0; + for (int j : tmp_combined_scalar_ids[i]) { + partition_size += scalar_info[j].size(); + } + std::unique_ptr tmp_data = std::make_unique(dim * partition_size); + labels[i] = std::make_shared>(partition_size); + index_rows_sum[i + 1] = index_rows_sum[i] + partition_size; + size_t cur_size = 0; + + for (size_t j = 0; j < tmp_combined_scalar_ids[i].size(); ++j) { + auto scalar_id = tmp_combined_scalar_ids[i][j]; + if (!convert_rows_to_fp32(data, tmp_data.get() + dim * cur_size, data_format, + scalar_info[scalar_id].data(), scalar_info[scalar_id].size(), dim)) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + for (size_t m = 0; m < scalar_info[scalar_id].size(); ++m) { + labels[i]->operator[](cur_size + m) = scalar_info[scalar_id][m]; + label_to_internal_offset[scalar_info[scalar_id][m]] = index_rows_sum[i] + cur_size + m; + } + cur_size += scalar_info[scalar_id].size(); + } - // done - index = std::move(final_index); - tmp_index_pq = std::move(pq_index); + Status s = create_index((const float*)(tmp_data.get()), i, partition_size); + if (s != Status::success) { + return s; + } + } return Status::success; } Status AddInternal(const DataSetPtr dataset, const Config&) override { - if (this->index == nullptr) { + if (this->indexes.empty()) { LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; return Status::empty_index; } - - auto rows = dataset->GetRows(); - try { - // hnsw - LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; - - auto status_reg = add_to_index(index.get(), dataset, data_format); - if (status_reg != Status::success) { - return status_reg; + for (auto& index : indexes) { + if (index == nullptr) { + LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; + return Status::empty_index; } + } - // pq - LOG_KNOWHERE_INFO_ << "Adding " << rows << " to PQ Index"; - - auto status_pq = add_to_index(tmp_index_pq.get(), dataset, data_format); - if (status_pq != Status::success) { - return status_pq; - } + auto rows = dataset->GetRows(); + auto finalize_index = [&](int i) { // we're done. // throw away flat and replace it with pq // check if we have a refine available. faiss::IndexHNSW* index_hnsw = nullptr; - faiss::IndexRefine* const index_refine = dynamic_cast(index.get()); + faiss::IndexRefine* const index_refine = dynamic_cast(indexes[i].get()); if (index_refine != nullptr) { index_hnsw = dynamic_cast(index_refine->base_index); } else { - index_hnsw = dynamic_cast(index.get()); + index_hnsw = dynamic_cast(indexes[i].get()); } // recreate hnswpq @@ -2057,14 +2539,64 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode { index_hnsw_pq->storage = nullptr; // replace storage - index_hnsw_pq->storage = tmp_index_pq.release(); + index_hnsw_pq->storage = tmp_index_pq[i].release(); // replace if refine if (index_refine != nullptr) { delete index_refine->base_index; index_refine->base_index = index_hnsw_pq.release(); } else { - index = std::move(index_hnsw_pq); + indexes[i] = std::move(index_hnsw_pq); + } + return Status::success; + }; + try { + const std::vector>& scalar_info = + dataset->Get>>(meta::SCALAR_INFO); + + if (scalar_info.size() <= 1 || tmp_combined_scalar_ids.size() <= 1) { + // hnsw + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; + + auto status_reg = add_to_index(indexes[0].get(), dataset, data_format); + if (status_reg != Status::success) { + return status_reg; + } + + // pq + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to PQ Index"; + + auto status_pq = add_to_index(tmp_index_pq[0].get(), dataset, data_format); + if (status_pq != Status::success) { + return status_pq; + } + return finalize_index(0); + } + LOG_KNOWHERE_INFO_ << "add to hnsw index with scalar info"; + + for (auto i = 0; i < tmp_combined_scalar_ids.size(); ++i) { + for (auto j = 0; j < tmp_combined_scalar_ids[i].size(); ++j) { + auto id = tmp_combined_scalar_ids[i][j]; + // hnsw + LOG_KNOWHERE_INFO_ << "Adding " << scalar_info[id].size() << " to HNSW Index"; + + auto status_reg = + add_partial_dataset_to_index(indexes[i].get(), dataset, data_format, scalar_info[id]); + if (status_reg != Status::success) { + return status_reg; + } + + // pq + LOG_KNOWHERE_INFO_ << "Adding " << scalar_info[id].size() << " to PQ Index"; + + auto status_pq = + add_partial_dataset_to_index(tmp_index_pq[i].get(), dataset, data_format, scalar_info[id]); + + if (status_pq != Status::success) { + return status_pq; + } + } + finalize_index(i); } } catch (const std::exception& e) { @@ -2113,7 +2645,7 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { } protected: - std::unique_ptr tmp_index_prq; + std::vector> tmp_index_prq; Status TrainInternal(const DataSetPtr dataset, const Config& cfg) override { @@ -2125,6 +2657,12 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { // config auto hnsw_cfg = static_cast(cfg); + if (rows < (1 << hnsw_cfg.nbits.value())) { + LOG_KNOWHERE_ERROR_ << rows << " rows not enough, needs at least " << (1 << hnsw_cfg.nbits.value()) + << " rows"; + return Status::faiss_inner_error; + } + auto metric = Str2FaissMetricType(hnsw_cfg.metric_type.value()); if (!metric.has_value()) { LOG_KNOWHERE_ERROR_ << "Invalid metric type: " << hnsw_cfg.metric_type.value(); @@ -2136,110 +2674,156 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { // HNSW + PRQ index yields BAD recall somewhy. // Let's build HNSW+FLAT index, then replace FLAT with PRQ + auto create_index = [&](const float* data, const int i, const int64_t rows) { + std::unique_ptr hnsw_index; + if (is_cosine) { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); + } else { + hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); + } - std::unique_ptr hnsw_index; - if (is_cosine) { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value()); - } else { - hnsw_index = std::make_unique(dim, hnsw_cfg.M.value(), metric.value()); - } + hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); - hnsw_index->hnsw.efConstruction = hnsw_cfg.efConstruction.value(); + // prq + faiss::AdditiveQuantizer::Search_type_t prq_search_type = + (metric.value() == faiss::MetricType::METRIC_INNER_PRODUCT) + ? faiss::AdditiveQuantizer::Search_type_t::ST_LUT_nonorm + : faiss::AdditiveQuantizer::Search_type_t::ST_norm_float; + + std::unique_ptr prq_index; + if (is_cosine) { + prq_index = std::make_unique( + dim, hnsw_cfg.m.value(), hnsw_cfg.nrq.value(), hnsw_cfg.nbits.value(), prq_search_type); + } else { + prq_index = std::make_unique( + dim, hnsw_cfg.m.value(), hnsw_cfg.nrq.value(), hnsw_cfg.nbits.value(), metric.value(), + prq_search_type); + } - // prq - faiss::AdditiveQuantizer::Search_type_t prq_search_type = - (metric.value() == faiss::MetricType::METRIC_INNER_PRODUCT) - ? faiss::AdditiveQuantizer::Search_type_t::ST_LUT_nonorm - : faiss::AdditiveQuantizer::Search_type_t::ST_norm_float; + // should refine be used? + std::unique_ptr final_index; + if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { + // yes + auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); + if (!final_index_cnd.has_value()) { + return Status::invalid_args; + } - std::unique_ptr prq_index; - if (is_cosine) { - prq_index = std::make_unique( - dim, hnsw_cfg.m.value(), hnsw_cfg.nrq.value(), hnsw_cfg.nbits.value(), prq_search_type); - } else { - prq_index = std::make_unique( - dim, hnsw_cfg.m.value(), hnsw_cfg.nrq.value(), hnsw_cfg.nbits.value(), metric.value(), prq_search_type); - } + // assign + final_index = std::move(final_index_cnd.value()); + } else { + // no refine - // should refine be used? - std::unique_ptr final_index; - if (hnsw_cfg.refine.value_or(false) && hnsw_cfg.refine_type.has_value()) { - // yes - auto final_index_cnd = pick_refine_index(data_format, hnsw_cfg.refine_type, std::move(hnsw_index)); - if (!final_index_cnd.has_value()) { - return Status::invalid_args; + // assign + final_index = std::move(hnsw_index); } - // assign - final_index = std::move(final_index_cnd.value()); - } else { - // no refine + // train hnswflat + LOG_KNOWHERE_INFO_ << "Training HNSW Index"; - // assign - final_index = std::move(hnsw_index); - } + final_index->train(rows, data); - // we have to convert the data to float, unfortunately, which costs extra RAM - auto float_ds_ptr = convert_ds_to_float(dataset, data_format); - if (float_ds_ptr == nullptr) { - LOG_KNOWHERE_ERROR_ << "Unsupported data format"; - return Status::invalid_args; + // train prq + LOG_KNOWHERE_INFO_ << "Training ProductResidualQuantizer Index"; + + prq_index->train(rows, data); + + // done + indexes[i] = std::move(final_index); + tmp_index_prq[i] = std::move(prq_index); + + return Status::success; + }; + const std::vector>& scalar_info = + dataset->Get>>(meta::SCALAR_INFO); + tmp_combined_scalar_ids = scalar_info.size() > 1 + ? combine_partitions(scalar_info, (1 << hnsw_cfg.nbits.value())) + : std::vector>(); + + // no scalar info or just one partition(after possible combination), build index on whole data + if (scalar_info.size() <= 1 || tmp_combined_scalar_ids.size() <= 1) { + tmp_index_prq.resize(1); + // we have to convert the data to float, unfortunately, which costs extra RAM + auto float_ds_ptr = convert_ds_to_float(dataset, data_format); + + if (float_ds_ptr == nullptr) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + return create_index((const float*)(float_ds_ptr->GetTensor()), 0, rows); } - // train hnswflat - LOG_KNOWHERE_INFO_ << "Training HNSW Index"; + LOG_KNOWHERE_INFO_ << "train hnsw index with scalar info"; - final_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); + label_to_internal_offset.resize(rows); + index_rows_sum.resize(tmp_combined_scalar_ids.size() + 1); - // train prq - LOG_KNOWHERE_INFO_ << "Training ProductResidualQuantizer Index"; + labels.resize(tmp_combined_scalar_ids.size()); + tmp_index_prq.resize(tmp_combined_scalar_ids.size()); - prq_index->train(rows, reinterpret_cast(float_ds_ptr->GetTensor())); + const void* data = dataset->GetTensor(); + indexes.resize(tmp_combined_scalar_ids.size()); + for (auto i = 0; i < tmp_combined_scalar_ids.size(); ++i) { + size_t partition_size = 0; + for (int j : tmp_combined_scalar_ids[i]) { + partition_size += scalar_info[j].size(); + } + std::unique_ptr tmp_data = std::make_unique(dim * partition_size); + labels[i] = std::make_shared>(partition_size); + index_rows_sum[i + 1] = index_rows_sum[i] + partition_size; + size_t cur_size = 0; + + for (size_t j = 0; j < tmp_combined_scalar_ids[i].size(); ++j) { + auto scalar_id = tmp_combined_scalar_ids[i][j]; + if (!convert_rows_to_fp32(data, tmp_data.get() + dim * cur_size, data_format, + scalar_info[scalar_id].data(), scalar_info[scalar_id].size(), dim)) { + LOG_KNOWHERE_ERROR_ << "Unsupported data format"; + return Status::invalid_args; + } + for (size_t m = 0; m < scalar_info[scalar_id].size(); ++m) { + labels[i]->operator[](cur_size + m) = scalar_info[scalar_id][m]; + label_to_internal_offset[scalar_info[scalar_id][m]] = index_rows_sum[i] + cur_size + m; + } + cur_size += scalar_info[scalar_id].size(); + } - // done - index = std::move(final_index); - tmp_index_prq = std::move(prq_index); + Status s = create_index((const float*)(tmp_data.get()), i, partition_size); + if (s != Status::success) { + return s; + } + } return Status::success; } Status AddInternal(const DataSetPtr dataset, const Config&) override { - if (this->index == nullptr) { + if (indexes.empty()) { LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; return Status::empty_index; } - - auto rows = dataset->GetRows(); - try { - // hnsw - LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; - - auto status_reg = add_to_index(index.get(), dataset, data_format); - if (status_reg != Status::success) { - return status_reg; + for (auto& index : indexes) { + if (index == nullptr) { + LOG_KNOWHERE_ERROR_ << "Can not add data to an empty index."; + return Status::empty_index; } + } - // prq - LOG_KNOWHERE_INFO_ << "Adding " << rows << " to ProductResidualQuantizer Index"; - - auto status_prq = add_to_index(tmp_index_prq.get(), dataset, data_format); - if (status_prq != Status::success) { - return status_prq; - } + auto rows = dataset->GetRows(); + auto finalize_index = [&](int i) { // we're done. // throw away flat and replace it with prq // check if we have a refine available. faiss::IndexHNSW* index_hnsw = nullptr; - faiss::IndexRefine* const index_refine = dynamic_cast(index.get()); + faiss::IndexRefine* const index_refine = dynamic_cast(indexes[i].get()); if (index_refine != nullptr) { index_hnsw = dynamic_cast(index_refine->base_index); } else { - index_hnsw = dynamic_cast(index.get()); + index_hnsw = dynamic_cast(indexes[i].get()); } // recreate hnswprq @@ -2261,14 +2845,62 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode { index_hnsw_prq->storage = nullptr; // replace storage - index_hnsw_prq->storage = tmp_index_prq.release(); + index_hnsw_prq->storage = tmp_index_prq[i].release(); // replace if refine if (index_refine != nullptr) { delete index_refine->base_index; index_refine->base_index = index_hnsw_prq.release(); } else { - index = std::move(index_hnsw_prq); + indexes[i] = std::move(index_hnsw_prq); + } + return Status::success; + }; + try { + const std::vector>& scalar_info = + dataset->Get>>(meta::SCALAR_INFO); + if (scalar_info.size() <= 1 || tmp_combined_scalar_ids.size() <= 1) { + // hnsw + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to HNSW Index"; + + auto status_reg = add_to_index(indexes[0].get(), dataset, data_format); + if (status_reg != Status::success) { + return status_reg; + } + + // prq + LOG_KNOWHERE_INFO_ << "Adding " << rows << " to ProductResidualQuantizer Index"; + + auto status_prq = add_to_index(tmp_index_prq[0].get(), dataset, data_format); + if (status_prq != Status::success) { + return status_prq; + } + return finalize_index(0); + } + LOG_KNOWHERE_INFO_ << "add to hnsw index with scalar info"; + for (auto i = 0; i < tmp_combined_scalar_ids.size(); ++i) { + for (auto j = 0; j < tmp_combined_scalar_ids[i].size(); ++j) { + auto id = tmp_combined_scalar_ids[i][j]; + // hnsw + LOG_KNOWHERE_INFO_ << "Adding " << scalar_info[id].size() << " to HNSW Index"; + + auto status_reg = + add_partial_dataset_to_index(indexes[i].get(), dataset, data_format, scalar_info[id]); + if (status_reg != Status::success) { + return status_reg; + } + + // prq + LOG_KNOWHERE_INFO_ << "Adding " << scalar_info[id].size() << " to PQ Index"; + + auto status_prq = + add_partial_dataset_to_index(tmp_index_prq[i].get(), dataset, data_format, scalar_info[id]); + + if (status_prq != Status::success) { + return status_prq; + } + } + finalize_index(i); } } catch (const std::exception& e) { @@ -2294,7 +2926,6 @@ class BaseFaissRegularIndexHNSWPRQNodeTemplate : public BaseFaissRegularIndexHNS } }; -// MV is only for compatibility #ifdef KNOWHERE_WITH_CARDINAL KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_DEPRECATED, BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback, @@ -2307,13 +2938,16 @@ KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNo #endif KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, - knowhere::feature::MMAP) -KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, knowhere::feature::MMAP) + knowhere::feature::MMAP | knowhere::feature::MV) +KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, + knowhere::feature::MMAP | knowhere::feature::MV) KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, - knowhere::feature::MMAP) -KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, knowhere::feature::MMAP) + knowhere::feature::MMAP | knowhere::feature::MV) +KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, + knowhere::feature::MMAP | knowhere::feature::MV) KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, - knowhere::feature::MMAP) -KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, knowhere::feature::MMAP) + knowhere::feature::MMAP | knowhere::feature::MV) +KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, + knowhere::feature::MMAP | knowhere::feature::MV) } // namespace knowhere diff --git a/src/index/hnsw/impl/IndexConditionalWrapper.cc b/src/index/hnsw/impl/IndexConditionalWrapper.cc index e3ab59e82..69410cbdf 100644 --- a/src/index/hnsw/impl/IndexConditionalWrapper.cc +++ b/src/index/hnsw/impl/IndexConditionalWrapper.cc @@ -51,8 +51,8 @@ WhetherPerformBruteForceSearch(const faiss::Index* index, const BaseConfig& cfg, double ratio = ((double)filtered_out_num) / bitset.size(); knowhere::knowhere_hnsw_bitset_ratio.Observe(ratio); #endif - if (filtered_out_num >= (index->ntotal * HnswSearchThresholds::kHnswSearchKnnBFFilterThreshold) || - k >= (index->ntotal - filtered_out_num) * HnswSearchThresholds::kHnswSearchBFTopkThreshold) { + if (filtered_out_num >= (bitset.size() * HnswSearchThresholds::kHnswSearchKnnBFFilterThreshold) || + k >= (bitset.size() - filtered_out_num) * HnswSearchThresholds::kHnswSearchBFTopkThreshold) { return true; } } @@ -84,8 +84,8 @@ WhetherPerformBruteForceRangeSearch(const faiss::Index* index, const FaissHnswCo double ratio = ((double)filtered_out_num) / bitset.size(); knowhere::knowhere_hnsw_bitset_ratio.Observe(ratio); #endif - if (filtered_out_num >= (index->ntotal * HnswSearchThresholds::kHnswSearchRangeBFFilterThreshold) || - ef >= (index->ntotal - filtered_out_num) * HnswSearchThresholds::kHnswSearchRangeBFFilterThreshold) { + if (filtered_out_num >= (bitset.size() * HnswSearchThresholds::kHnswSearchRangeBFFilterThreshold) || + ef >= (bitset.size() - filtered_out_num) * HnswSearchThresholds::kHnswSearchRangeBFFilterThreshold) { return true; } } diff --git a/tests/ut/test_faiss_hnsw.cc b/tests/ut/test_faiss_hnsw.cc index 33c2d9474..565579aa8 100644 --- a/tests/ut/test_faiss_hnsw.cc +++ b/tests/ut/test_faiss_hnsw.cc @@ -208,6 +208,11 @@ create_index(const std::string& index_type, const std::string& index_file_name, auto base = knowhere::ConvertToDataTypeIfNeeded(default_ds_ptr); + if (conf[knowhere::meta::MV_ONLY_ENABLED]) { + base->Set(knowhere::meta::SCALAR_INFO, + default_ds_ptr->Get>>(knowhere::meta::SCALAR_INFO)); + } + StopWatch sw; index.value().Build(base, conf); double elapsed = sw.elapsed(); @@ -252,7 +257,7 @@ index_support_int8(const knowhere::Json& conf) { // template -void +std::string test_hnsw(const knowhere::DataSetPtr& default_ds_ptr, const knowhere::DataSetPtr& query_ds_ptr, const knowhere::DataSetPtr& golden_result, const std::vector& index_params, const knowhere::Json& conf, const knowhere::BitsetView bitset_view) { @@ -313,11 +318,12 @@ test_hnsw(const knowhere::DataSetPtr& default_ds_ptr, const knowhere::DataSetPtr match_datasets(default_t_ds_ptr, vectors.value(), ids); } + return index_file_name; } // template -void +std::string test_hnsw_range(const knowhere::DataSetPtr& default_ds_ptr, const knowhere::DataSetPtr& query_ds_ptr, const knowhere::DataSetPtr& golden_result, const std::vector& index_params, const knowhere::Json& conf, const knowhere::BitsetView bitset_view) { @@ -381,6 +387,7 @@ test_hnsw_range(const knowhere::DataSetPtr& default_ds_ptr, const knowhere::Data match_datasets(default_t_ds_ptr, vectors.value(), ids); } + return index_file_name; } } // namespace @@ -403,6 +410,8 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") { const int32_t NQ = 16; const int32_t TOPK = 16; + const std::vector MV_ONLYs = {false, true}; + const std::vector SQ_TYPES = {"SQ6", "SQ8", "BF16", "FP16"}; // random bitset rates @@ -444,6 +453,7 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") { default_conf[knowhere::indexparam::EFCONSTRUCTION] = 96; default_conf[knowhere::indexparam::EF] = 64; default_conf[knowhere::meta::TOPK] = TOPK; + default_conf[knowhere::meta::MV_ONLY_ENABLED] = false; // create golden indices for search { @@ -506,53 +516,93 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") { auto golden_index = create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + std::vector> scalar_info = GenerateScalarInfo(nb); + auto partition_size = scalar_info[0].size(); // will be masked by partition key value - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + for (const bool mv_only_enable : MV_ONLYs) { +#ifdef KNOWHERE_WITH_CARDINAL + if (mv_only_enable) { + continue; + } +#endif + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); } - // get a golden result - auto golden_result = golden_index.Search(query_ds_ptr, conf, bitset_view); + std::vector index_files; + std::string index_file; + + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - // fp32 candidate - printf( - "\nProcessing HNSW,Flat fp32 for %s distance, dim=%d, nrows=%d, %d%% points filtered out\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + // get a golden result + auto golden_result = golden_index.Search(query_ds_ptr, conf, bitset_view); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + // fp32 candidate + printf( + "\nProcessing HNSW,Flat fp32 for %s distance, dim=%d, nrows=%d, %d%% points filtered " + "out\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - // fp16 candidate - printf( - "\nProcessing HNSW,Flat fp16 for %s distance, dim=%d, nrows=%d, %d%% points filtered out\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, bitset_view); + index_files.emplace_back(index_file); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + // fp16 candidate + printf( + "\nProcessing HNSW,Flat fp16 for %s distance, dim=%d, nrows=%d, %d%% points filtered " + "out\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - // bf16 candidate - printf( - "\nProcessing HNSW,Flat bf16 for %s distance, dim=%d, nrows=%d, %d%% points filtered out\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, bitset_view); + index_files.emplace_back(index_file); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); - if (index_support_int8(conf)) { - // int8 candidate + // bf16 candidate printf( - "\nProcessing HNSW,Flat int8 for %s distance, dim=%d, nrows=%d, %d%% points filtered " + "\nProcessing HNSW,Flat bf16 for %s distance, dim=%d, nrows=%d, %d%% points filtered " "out\n", DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params, conf, bitset_view); + index_files.emplace_back(index_file); + if (index_support_int8(conf)) { + // int8 candidate + printf( + "\nProcessing HNSW,Flat int8 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered " + "out\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); + } + + std::remove(get_index_name(ann_test_name_, index_type, params).c_str()); + std::remove(get_index_name(ann_test_name_, index_type, params).c_str()); + std::remove(get_index_name(ann_test_name_, index_type, params).c_str()); + if (index_support_int8(conf)) { + std::remove(get_index_name(ann_test_name_, index_type, params).c_str()); + } + } + for (auto index : index_files) { + std::remove(index.c_str()); } } } @@ -590,153 +640,186 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") { auto golden_index = create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf_golden, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + std::vector> scalar_info = GenerateScalarInfo(nb); + auto partition_size = scalar_info[0].size(); // will be masked by partition key value - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + for (const bool mv_only_enable : MV_ONLYs) { + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); } + std::vector index_files; + std::string index_file; - // get a golden result - auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, bitset_view); - - // go SQ - for (size_t i_sq_type = 0; i_sq_type < SQ_TYPES.size(); i_sq_type++) { - knowhere::Json conf = conf_golden; - conf[knowhere::meta::INDEX_TYPE] = index_type; + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); - const std::string sq_type = SQ_TYPES[i_sq_type]; - conf[knowhere::indexparam::SQ_TYPE] = sq_type; + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - std::vector params = {(int)distance_type, dim, nb, (int)i_sq_type}; + // get a golden result + auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, bitset_view); - // fp32 candidate - printf( - "\nProcessing HNSW,SQ(%s) fp32 for %s distance, dim=%d, nrows=%d, %d%% points filtered " - "out\n", - sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + // go SQ + for (size_t i_sq_type = 0; i_sq_type < SQ_TYPES.size(); i_sq_type++) { + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + const std::string sq_type = SQ_TYPES[i_sq_type]; + conf[knowhere::indexparam::SQ_TYPE] = sq_type; - // fp16 candidate - printf( - "\nProcessing HNSW,SQ(%s) fp16 for %s distance, dim=%d, nrows=%d, %d%% points filtered " - "out\n", - sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + std::vector params = {(int)distance_type, dim, nb, (int)i_sq_type}; - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + // fp32 candidate + printf( + "\nProcessing HNSW,SQ(%s) fp32 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered " + "out\n", + sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - // bf16 candidate - printf( - "\nProcessing HNSW,SQ(%s) bf16 for %s distance, dim=%d, nrows=%d, %d%% points filtered " - "out\n", - sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); - if (index_support_int8(conf)) { - // int8 candidate + // fp16 candidate printf( - "\nProcessing HNSW,SQ(%s) int8 for %s distance, dim=%d, nrows=%d, %d%% points " + "\nProcessing HNSW,SQ(%s) fp16 for %s distance, dim=%d, nrows=%d, %d%% points " "filtered " "out\n", sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, - conf, bitset_view); - } - // test refines for FP32 - { - const auto& allowed_refs = SQ_ALLOWED_REFINES_FP32[sq_type]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - const std::string allowed_ref = allowed_refs[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // bf16 candidate + printf( + "\nProcessing HNSW,SQ(%s) bf16 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered " + "out\n", + sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - std::vector params_refine = {(int)distance_type, dim, nb, (int)i_sq_type, - (int)allowed_ref_idx}; + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - // fp32 candidate + if (index_support_int8(conf)) { + // int8 candidate printf( - "\nProcessing HNSW,SQ(%s) with %s refine, fp32 for %s distance, dim=%d, " - "nrows=%d, %d%% points filtered out\n", - sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, int(bitset_rate * 100)); + "\nProcessing HNSW,SQ(%s) int8 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered " + "out\n", + sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); } - } + // test refines for FP32 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_FP32[sq_type]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - // test refines for FP16 - { - const auto& allowed_refs = SQ_ALLOWED_REFINES_FP16[sq_type]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + const std::string allowed_ref = allowed_refs[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - const std::string allowed_ref = allowed_refs[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + std::vector params_refine = {(int)distance_type, dim, nb, + (int)i_sq_type, (int)allowed_ref_idx}; + + // fp32 candidate + printf( + "\nProcessing HNSW,SQ(%s) with %s refine, fp32 for %s distance, dim=%d, " + "nrows=%d, %d%% points filtered out\n", + sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), + dim, nb, int(bitset_rate * 100)); - std::vector params_refine = {(int)distance_type, dim, nb, (int)i_sq_type, - (int)allowed_ref_idx}; + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, + golden_result.value(), params_refine, + conf_refine, bitset_view); + index_files.emplace_back(index_file); + } + } - // fp16 candidate - printf( - "\nProcessing HNSW,SQ(%s) with %s refine, fp16 for %s distance, dim=%d, " - "nrows=%d, %d%% points filtered out\n", - sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, int(bitset_rate * 100)); + // test refines for FP16 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_FP16[sq_type]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + const std::string allowed_ref = allowed_refs[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; + + std::vector params_refine = {(int)distance_type, dim, nb, + (int)i_sq_type, (int)allowed_ref_idx}; + + // fp16 candidate + printf( + "\nProcessing HNSW,SQ(%s) with %s refine, fp16 for %s distance, dim=%d, " + "nrows=%d, %d%% points filtered out\n", + sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), + dim, nb, int(bitset_rate * 100)); + + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, + golden_result.value(), params_refine, + conf_refine, bitset_view); + index_files.emplace_back(index_file); + } } - } - // test refines for BF16 - { - const auto& allowed_refs = SQ_ALLOWED_REFINES_BF16[sq_type]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // test refines for BF16 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_BF16[sq_type]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - const std::string allowed_ref = allowed_refs[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + const std::string allowed_ref = allowed_refs[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - std::vector params_refine = {(int)distance_type, dim, nb, (int)i_sq_type, - (int)allowed_ref_idx}; + std::vector params_refine = {(int)distance_type, dim, nb, + (int)i_sq_type, (int)allowed_ref_idx}; - // bf16 candidate - printf( - "\nProcessing HNSW,SQ(%s) with %s refine, bf16 for %s distance, dim=%d, " - "nrows=%d, %d%% points filtered out\n", - sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, int(bitset_rate * 100)); + // bf16 candidate + printf( + "\nProcessing HNSW,SQ(%s) with %s refine, bf16 for %s distance, dim=%d, " + "nrows=%d, %d%% points filtered out\n", + sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), + dim, nb, int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + index_file = test_hnsw(default_ds_ptr, query_ds_ptr, + golden_result.value(), params_refine, + conf_refine, bitset_view); + index_files.emplace_back(index_file); + } } } } + for (auto index : index_files) { + std::remove(index.c_str()); + } } } } @@ -773,144 +856,176 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") { auto golden_index = create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf_golden, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + std::vector> scalar_info = GenerateScalarInfo(nb); + auto partition_size = scalar_info[0].size(); // will be masked by partition key value - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + for (const bool mv_only_enable : MV_ONLYs) { + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); } + std::vector index_files; + std::string index_file; - // get a golden result - auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, bitset_view); + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); - // go PQ - for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { - const int pq_m = 8; + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - knowhere::Json conf = conf_golden; - conf[knowhere::meta::INDEX_TYPE] = index_type; - conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; - conf[knowhere::indexparam::M] = pq_m; + // get a golden result + auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, bitset_view); - std::vector params = {(int)distance_type, dim, nb, pq_m, (int)nbits_type}; + // go PQ + for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { + const int pq_m = 8; - // test fp32 candidate - printf( - "\nProcessing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d, %d%% points " - "filtered out\n", - pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; + conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; + conf[knowhere::indexparam::M] = pq_m; - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + std::vector params = {(int)distance_type, dim, nb, pq_m, (int)nbits_type}; - // test fp16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d fp16 for %s distance, dim=%d, nrows=%d, %d%% points " - "filtered out\n", - pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + // test fp32 candidate + printf( + "\nProcessing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered out\n", + pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - // test bf16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d bf16 for %s distance, dim=%d, nrows=%d, %d%% points " - "filtered out\n", - pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + // test fp16 candidate + printf( + "\nProcessing HNSW,PQ%dx%d fp16 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered out\n", + pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); - if (index_support_int8(conf)) { - // test int8 candidate + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); + + // test bf16 candidate printf( - "\nProcessing HNSW,PQ%dx%d int8 for %s distance, dim=%d, nrows=%d, %d%% points " + "\nProcessing HNSW,PQ%dx%d bf16 for %s distance, dim=%d, nrows=%d, %d%% points " "filtered out\n", pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, - conf, bitset_view); - } - // test refines for fp32 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); + if (index_support_int8(conf)) { + // test int8 candidate + printf( + "\nProcessing HNSW,PQ%dx%d int8 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered out\n", + pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); + } + // test refines for fp32 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - std::vector params_refine = { - (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test fp32 candidate - printf( - "\nProcessing HNSW,PQ%dx%d with %s refine, fp32 for %s distance, dim=%d, nrows=%d, " - "%d%% points filtered out\n", - pq_m, NBITS[nbits_type], allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, int(bitset_rate * 100)); + std::vector params_refine = { + (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + // test fp32 candidate + printf( + "\nProcessing HNSW,PQ%dx%d with %s refine, fp32 for %s distance, dim=%d, " + "nrows=%d, " + "%d%% points filtered out\n", + pq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - // test refines for fp16 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params_refine, conf_refine, bitset_view); + index_files.emplace_back(index_file); + } - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // test refines for fp16 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - std::vector params_refine = { - (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test fp16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d with %s refine, fp16 for %s distance, dim=%d, nrows=%d, " - "%d%% points filtered out\n", - pq_m, NBITS[nbits_type], allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, int(bitset_rate * 100)); + std::vector params_refine = { + (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + // test fp16 candidate + printf( + "\nProcessing HNSW,PQ%dx%d with %s refine, fp16 for %s distance, dim=%d, " + "nrows=%d, " + "%d%% points filtered out\n", + pq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - // test refines for bf16 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params_refine, conf_refine, bitset_view); + index_files.emplace_back(index_file); + } - const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // test refines for bf16 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - std::vector params_refine = { - (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test bf16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d with %s refine, bf16 for %s distance, dim=%d, nrows=%d, " - "%d%% points filtered out\n", - pq_m, NBITS[nbits_type], allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, int(bitset_rate * 100)); + std::vector params_refine = { + (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } - } + // test bf16 candidate + printf( + "\nProcessing HNSW,PQ%dx%d with %s refine, bf16 for %s distance, dim=%d, " + "nrows=%d, " + "%d%% points filtered out\n", + pq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params_refine, conf_refine, bitset_view); + index_files.emplace_back(index_file); + } + } + } + for (auto index : index_files) { + std::remove(index.c_str()); + } } } } @@ -946,151 +1061,184 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") { auto golden_index = create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf_golden, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); - - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + std::vector> scalar_info = GenerateScalarInfo(nb); + auto partition_size = scalar_info[0].size(); // will be masked by partition key value + + for (const bool mv_only_enable : MV_ONLYs) { + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); } + std::vector index_files; + std::string index_file; - // get a golden result - auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, bitset_view); + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); - // go PRQ - for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { - const int prq_m = 4; - const int prq_num = 2; + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - knowhere::Json conf = conf_golden; - conf[knowhere::meta::INDEX_TYPE] = index_type; - conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; - conf[knowhere::indexparam::M] = prq_m; - conf[knowhere::indexparam::PRQ_NUM] = prq_num; + // get a golden result + auto golden_result = golden_index.Search(query_ds_ptr, conf_golden, bitset_view); - std::vector params = {(int)distance_type, dim, nb, prq_m, prq_num, - (int)nbits_type}; + // go PRQ + for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { + const int prq_m = 4; + const int prq_num = 2; - // test fp32 candidate - printf( - "\nProcessing HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d, %d%% points " - "filtered out\n", - prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; + conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; + conf[knowhere::indexparam::M] = prq_m; + conf[knowhere::indexparam::PRQ_NUM] = prq_num; - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + std::vector params = {(int)distance_type, dim, nb, prq_m, prq_num, + (int)nbits_type}; - // test fp16 candidate - printf( - "\nProcessing HNSW,PRQ%dx%dx%d fp16 for %s distance, dim=%d, nrows=%d, %d%% points " - "filtered out\n", - prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + // test fp32 candidate + printf( + "\nProcessing HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered out\n", + prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - // test bf16 candidate - printf( - "\nProcessing HNSW,PRQ%dx%dx%d bf16 for %s distance, dim=%d, nrows=%d, %d%% points " - "filtered out\n", - prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - int(bitset_rate * 100)); + // test fp16 candidate + printf( + "\nProcessing HNSW,PRQ%dx%dx%d fp16 for %s distance, dim=%d, nrows=%d, %d%% points " + "filtered out\n", + prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, + int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, - bitset_view); + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - if (index_support_int8(conf)) { - // test int8 candidate + // test bf16 candidate printf( - "\nProcessing HNSW,PRQ%dx%dx%d int8 for %s distance, dim=%d, nrows=%d, %d%% points " + "\nProcessing HNSW,PRQ%dx%dx%d bf16 for %s distance, dim=%d, nrows=%d, %d%% points " "filtered out\n", prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), params, - conf, bitset_view); - } - // test fp32 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + if (index_support_int8(conf)) { + // test int8 candidate + printf( + "\nProcessing HNSW,PRQ%dx%dx%d int8 for %s distance, dim=%d, nrows=%d, %d%% " + "points " + "filtered out\n", + prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, + nb, int(bitset_rate * 100)); - std::vector params_refine = { - (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; + index_file = test_hnsw( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); + } + // test fp32 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - // - printf( - "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp32 for %s distance, dim=%d, " - "nrows=%d, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test a candidate - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, + (int)allowed_ref_idx}; - // test fp16 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // + printf( + "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp32 for %s distance, dim=%d, " + "nrows=%d, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // test a candidate + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params_refine, conf_refine, bitset_view); + index_files.emplace_back(index_file); + } - std::vector params_refine = { - (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; + // test fp16 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - // - printf( - "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp16 for %s distance, dim=%d, " - "nrows=%d, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test a candidate - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, + (int)allowed_ref_idx}; + + // + printf( + "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp16 for %s distance, dim=%d, " + "nrows=%d, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + + // test a candidate + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params_refine, conf_refine, bitset_view); + index_files.emplace_back(index_file); + } - // test bf16 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // test bf16 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - std::vector params_refine = { - (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, (int)allowed_ref_idx}; + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, + (int)allowed_ref_idx}; - // - printf( - "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, bf16 for %s distance, dim=%d, " - "nrows=%d, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); + // + printf( + "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, bf16 for %s distance, dim=%d, " + "nrows=%d, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, int(bitset_rate * 100)); - // test a candidate - test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + // test a candidate + index_file = + test_hnsw(default_ds_ptr, query_ds_ptr, golden_result.value(), + params_refine, conf_refine, bitset_view); + index_files.emplace_back(index_file); + } } } + for (auto index : index_files) { + std::remove(index.c_str()); + } } } } @@ -1123,6 +1271,8 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra const std::vector NBS = {256}; const int32_t NQ = 16; + const std::vector MV_ONLYs = {false, true}; + const std::vector SQ_TYPES = {"SQ6", "SQ8", "BF16", "FP16"}; // random bitset rates @@ -1163,6 +1313,7 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra default_conf[knowhere::indexparam::HNSW_M] = 16; default_conf[knowhere::indexparam::EFCONSTRUCTION] = 96; default_conf[knowhere::indexparam::EF] = 64; + default_conf[knowhere::meta::MV_ONLY_ENABLED] = false; // create golden indices for search { @@ -1192,11 +1343,10 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra } // I'd like to have a sequential process here, because every item in the loop - // is parallelized on its own + // is parallelized on its own SECTION("FLAT") { const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW; - // const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW; const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { @@ -1239,60 +1389,88 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra auto golden_index = create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + std::vector> scalar_info = GenerateScalarInfo(nb); + auto partition_size = scalar_info[0].size(); // will be masked by partition key value - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + for (const bool mv_only_enable : MV_ONLYs) { +#ifdef KNOWHERE_WITH_CARDINAL + if (mv_only_enable) { + continue; } +#endif + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); + } + std::vector index_files; + std::string index_file; + + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - // get a golden result - auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf, bitset_view); - - // fp32 candidate - printf( - "\nProcessing HNSW,Flat fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + // get a golden result + auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf, bitset_view); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), params, - conf, bitset_view); + // fp32 candidate + printf( + "\nProcessing HNSW,Flat fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); - // fp16 candidate - printf( - "\nProcessing HNSW,Flat fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), params, - conf, bitset_view); + // fp16 candidate + printf( + "\nProcessing HNSW,Flat fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); - // bf16 candidate - printf( - "\nProcessing HNSW,Flat bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), params, - conf, bitset_view); - if (index_support_int8(conf)) { - // int8 candidate + // bf16 candidate printf( - "\nProcessing HNSW,Flat int8 for %s distance, dim=%d, nrows=%d, radius=%f, " + "\nProcessing HNSW,Flat bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " "range_filter=%f, %d%% points filtered out\n", DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); + if (index_support_int8(conf)) { + // int8 candidate + printf( + "\nProcessing HNSW,Flat int8 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); + } + } + for (auto index : index_files) { + std::remove(index.c_str()); } } } @@ -1344,175 +1522,210 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra auto golden_index = create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf_golden, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + std::vector> scalar_info = GenerateScalarInfo(nb); + auto partition_size = scalar_info[0].size(); // will be masked by partition key value - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + for (const bool mv_only_enable : MV_ONLYs) { + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); } + std::vector index_files; + std::string index_file; + + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - // get a golden result - auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf_golden, bitset_view); - - // go SQ - for (size_t i_sq_type = 0; i_sq_type < SQ_TYPES.size(); i_sq_type++) { - knowhere::Json conf = conf_golden; - conf[knowhere::meta::INDEX_TYPE] = index_type; - - const std::string sq_type = SQ_TYPES[i_sq_type]; - conf[knowhere::indexparam::SQ_TYPE] = sq_type; + // get a golden result + auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf_golden, bitset_view); - std::vector params = {(int)distance_type, dim, nb, (int)i_sq_type}; + // go SQ + for (size_t i_sq_type = 0; i_sq_type < SQ_TYPES.size(); i_sq_type++) { + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; - // fp32 candidate - printf( - "\nProcessing HNSW,SQ(%s) fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, - range_filter, int(bitset_rate * 100)); + const std::string sq_type = SQ_TYPES[i_sq_type]; + conf[knowhere::indexparam::SQ_TYPE] = sq_type; - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + std::vector params = {(int)distance_type, dim, nb, (int)i_sq_type}; - // fp16 candidate - printf( - "\nProcessing HNSW,SQ(%s) fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, - range_filter, int(bitset_rate * 100)); + // fp32 candidate + printf( + "\nProcessing HNSW,SQ(%s) fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, + range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - // bf16 candidate - printf( - "\nProcessing HNSW,SQ(%s) bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, - range_filter, int(bitset_rate * 100)); + // fp16 candidate + printf( + "\nProcessing HNSW,SQ(%s) fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, + range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - if (index_support_int8(conf)) { - // int8 candidate + // bf16 candidate printf( - "\nProcessing HNSW,SQ(%s) int8 for %s distance, dim=%d, nrows=%d, radius=%f, " + "\nProcessing HNSW,SQ(%s) bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " "range_filter=%f, %d%% points filtered out\n", sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); - } - - // test refines for FP32 - { - const auto& allowed_refs = SQ_ALLOWED_REFINES_FP32[sq_type]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - const std::string allowed_ref = allowed_refs[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + if (index_support_int8(conf)) { + // int8 candidate + printf( + "\nProcessing HNSW,SQ(%s) int8 for %s distance, dim=%d, nrows=%d, " + "radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + sq_type.c_str(), DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, + range_filter, int(bitset_rate * 100)); + + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, + conf, bitset_view); + index_files.emplace_back(index_file); + } - std::vector params_refine = {(int)distance_type, dim, nb, - (int)i_sq_type, (int)allowed_ref_idx}; + // test refines for FP32 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_FP32[sq_type]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + + const std::string allowed_ref = allowed_refs[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; + + std::vector params_refine = {(int)distance_type, dim, nb, + (int)i_sq_type, (int)allowed_ref_idx}; + + // fp32 candidate + printf( + "\nProcessing HNSW,SQ(%s) with %s refine, fp32 for %s distance, " + "dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + sq_type.c_str(), allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, bitset_view); + index_files.emplace_back(index_file); + } + } - // fp32 candidate - printf( - "\nProcessing HNSW,SQ(%s) with %s refine, fp32 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, radius, range_filter, int(bitset_rate * 100)); + // test refines for FP16 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_FP16[sq_type]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + + const std::string allowed_ref = allowed_refs[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; + + std::vector params_refine = {(int)distance_type, dim, nb, + (int)i_sq_type, (int)allowed_ref_idx}; + + // fp16 candidate + printf( + "\nProcessing HNSW,SQ(%s) with %s refine, fp16 for %s distance, " + "dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + sq_type.c_str(), allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, bitset_view); + index_files.emplace_back(index_file); + } + } - test_hnsw_range(default_ds_ptr, query_ds_ptr, - golden_result.value(), params_refine, - conf_refine, bitset_view); + // test refines for BF16 + { + const auto& allowed_refs = SQ_ALLOWED_REFINES_BF16[sq_type]; + for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; + + const std::string allowed_ref = allowed_refs[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; + + std::vector params_refine = {(int)distance_type, dim, nb, + (int)i_sq_type, (int)allowed_ref_idx}; + + // bf16 candidate + printf( + "\nProcessing HNSW,SQ(%s) with %s refine, bf16 for %s distance, " + "dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + sq_type.c_str(), allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, bitset_view); + index_files.emplace_back(index_file); + } } } + } - // test refines for FP16 - { - const auto& allowed_refs = SQ_ALLOWED_REFINES_FP16[sq_type]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; - - const std::string allowed_ref = allowed_refs[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + for (auto index : index_files) { + std::remove(index.c_str()); + } + } + } + } + } + } + } - std::vector params_refine = {(int)distance_type, dim, nb, - (int)i_sq_type, (int)allowed_ref_idx}; + SECTION("PQ") { + const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW_PQ; + const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; - // fp16 candidate - printf( - "\nProcessing HNSW,SQ(%s) with %s refine, fp16 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, radius, range_filter, int(bitset_rate * 100)); + for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { + const ranges_vec_type& ranges_vec = ranges_dict[DISTANCE_TYPES[distance_type]]; - test_hnsw_range(default_ds_ptr, query_ds_ptr, - golden_result.value(), params_refine, - conf_refine, bitset_view); - } - } - - // test refines for BF16 - { - const auto& allowed_refs = SQ_ALLOWED_REFINES_BF16[sq_type]; - for (size_t allowed_ref_idx = 0; allowed_ref_idx < allowed_refs.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; - - const std::string allowed_ref = allowed_refs[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; - - std::vector params_refine = {(int)distance_type, dim, nb, - (int)i_sq_type, (int)allowed_ref_idx}; - - // bf16 candidate - printf( - "\nProcessing HNSW,SQ(%s) with %s refine, bf16 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - sq_type.c_str(), allowed_ref.c_str(), DISTANCE_TYPES[distance_type].c_str(), - dim, nb, radius, range_filter, int(bitset_rate * 100)); - - test_hnsw_range(default_ds_ptr, query_ds_ptr, - golden_result.value(), params_refine, - conf_refine, bitset_view); - } - } - } - } - } - } - } - } - } - - SECTION("PQ") { - const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW_PQ; - const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP; - - for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) { - const ranges_vec_type& ranges_vec = ranges_dict[DISTANCE_TYPES[distance_type]]; - - for (const int32_t dim : {16}) { - // generate a query - const uint64_t query_rng_seed = get_params_hash({(int)distance_type, dim}); - auto query_ds_ptr = GenDataSet(NQ, dim, query_rng_seed); + for (const int32_t dim : {16}) { + // generate a query + const uint64_t query_rng_seed = get_params_hash({(int)distance_type, dim}); + auto query_ds_ptr = GenDataSet(NQ, dim, query_rng_seed); for (const int32_t nb : NBS) { for (const auto& [radius_in, range_filter_in] : ranges_vec) { @@ -1545,149 +1758,179 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra auto golden_index = create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf_golden, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + std::vector> scalar_info = GenerateScalarInfo(nb); + auto partition_size = scalar_info[0].size(); // will be masked by partition key value - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + for (const bool mv_only_enable : MV_ONLYs) { + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); } + std::vector index_files; + std::string index_file; + + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - // get a golden result - auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf_golden, bitset_view); - - // go PQ - for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { - const int pq_m = 8; - - knowhere::Json conf = conf_golden; - conf[knowhere::meta::INDEX_TYPE] = index_type; - conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; - conf[knowhere::indexparam::M] = pq_m; - - std::vector params = {(int)distance_type, dim, nb, pq_m, (int)nbits_type}; - - // test fp32 candidate - printf( - "\nProcessing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, - range_filter, int(bitset_rate * 100)); + // get a golden result + auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf_golden, bitset_view); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + // go PQ + for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { + const int pq_m = 8; - // test fp16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, - range_filter, int(bitset_rate * 100)); + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; + conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; + conf[knowhere::indexparam::M] = pq_m; - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + std::vector params = {(int)distance_type, dim, nb, pq_m, (int)nbits_type}; - // test bf16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, - range_filter, int(bitset_rate * 100)); + // test fp32 candidate + printf( + "\nProcessing HNSW,PQ%dx%d fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, + range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - if (index_support_int8(conf)) { - // test int8 candidate + // test fp16 candidate printf( - "\nProcessing HNSW,PQ%dx%d int8 for %s distance, dim=%d, nrows=%d, radius=%f, " + "\nProcessing HNSW,PQ%dx%d fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " "range_filter=%f, %d%% points filtered out\n", pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); - } + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - // test refines for fp32 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // test bf16 candidate + printf( + "\nProcessing HNSW,PQ%dx%d bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, + range_filter, int(bitset_rate * 100)); - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - std::vector params_refine = { - (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + if (index_support_int8(conf)) { + // test int8 candidate + printf( + "\nProcessing HNSW,PQ%dx%d int8 for %s distance, dim=%d, nrows=%d, " + "radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + pq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, + radius, range_filter, int(bitset_rate * 100)); + + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, + conf, bitset_view); + index_files.emplace_back(index_file); + } - // test fp32 candidate - printf( - "\nProcessing HNSW,PQ%dx%d with %s refine, fp32 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - pq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + // test refines for fp32 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test refines for fp16 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + std::vector params_refine = { + (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // test fp32 candidate + printf( + "\nProcessing HNSW,PQ%dx%d with %s refine, fp32 for %s distance, dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + pq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, bitset_view); + index_files.emplace_back(index_file); + } - std::vector params_refine = { - (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + // test refines for fp16 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - // test fp16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d with %s refine, fp16 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - pq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + std::vector params_refine = { + (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; - // test refines for bf16 - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // test fp16 candidate + printf( + "\nProcessing HNSW,PQ%dx%d with %s refine, fp16 for %s distance, dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + pq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, bitset_view); + index_files.emplace_back(index_file); + } - const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // test refines for bf16 + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - std::vector params_refine = { - (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; + const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test bf16 candidate - printf( - "\nProcessing HNSW,PQ%dx%d with %s refine, bf16 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - pq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + std::vector params_refine = { + (int)distance_type, dim, nb, pq_m, (int)nbits_type, (int)allowed_ref_idx}; - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + // test bf16 candidate + printf( + "\nProcessing HNSW,PQ%dx%d with %s refine, bf16 for %s distance, dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + pq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, bitset_view); + index_files.emplace_back(index_file); + } } } + for (auto index : index_files) { + std::remove(index.c_str()); + } } } } @@ -1738,156 +1981,193 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra auto golden_index = create_index(golden_index_type, golden_index_file_name, default_ds_ptr, conf_golden, "golden "); - // test various bitset rates - for (const float bitset_rate : BITSET_RATES) { - const int32_t nbits_set = nb * bitset_rate; - const std::vector bitset_data = GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + std::vector> scalar_info = GenerateScalarInfo(nb); + auto partition_size = scalar_info[0].size(); // will be masked by partition key value - // initialize bitset_view. - // provide a default one if nbits_set == 0 - knowhere::BitsetView bitset_view = nullptr; - if (nbits_set != 0) { - bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + for (const bool mv_only_enable : MV_ONLYs) { + printf("with mv only enabled : %d\n", mv_only_enable); + if (mv_only_enable) { + default_ds_ptr->Set(knowhere::meta::SCALAR_INFO, scalar_info); } - // get a golden result - auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf_golden, bitset_view); - - // go PRQ - for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { - const int prq_m = 4; - const int prq_num = 2; + std::vector index_files; + std::string index_file; + + // test various bitset rates + for (const float bitset_rate : BITSET_RATES) { + const int32_t nbits_set = mv_only_enable + ? partition_size + (nb - partition_size) * bitset_rate + : nb * bitset_rate; + const std::vector bitset_data = + mv_only_enable ? GenerateBitsetByScalarInfoAndFirstTBits(scalar_info[0], nb, 0) + : GenerateBitsetWithRandomTbitsSet(nb, nbits_set); + // initialize bitset_view. + // provide a default one if nbits_set == 0 + knowhere::BitsetView bitset_view = nullptr; + if (nbits_set != 0) { + bitset_view = knowhere::BitsetView(bitset_data.data(), nb, nb - nbits_set); + } - knowhere::Json conf = conf_golden; - conf[knowhere::meta::INDEX_TYPE] = index_type; - conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; - conf[knowhere::indexparam::M] = prq_m; - conf[knowhere::indexparam::PRQ_NUM] = prq_num; + // get a golden result + auto golden_result = golden_index.RangeSearch(query_ds_ptr, conf_golden, bitset_view); - std::vector params = {(int)distance_type, dim, nb, prq_m, prq_num, - (int)nbits_type}; + // go PRQ + for (size_t nbits_type = 0; nbits_type < NBITS.size(); nbits_type++) { + const int prq_m = 4; + const int prq_num = 2; - // test fp32 candidate - printf( - "\nProcessing HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - radius, range_filter, int(bitset_rate * 100)); + knowhere::Json conf = conf_golden; + conf[knowhere::meta::INDEX_TYPE] = index_type; + conf[knowhere::indexparam::NBITS] = NBITS[nbits_type]; + conf[knowhere::indexparam::M] = prq_m; + conf[knowhere::indexparam::PRQ_NUM] = prq_num; - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + std::vector params = {(int)distance_type, dim, nb, prq_m, prq_num, + (int)nbits_type}; - // test fp16 candidate - printf( - "\nProcessing HNSW,PRQ%dx%dx%d fp16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - radius, range_filter, int(bitset_rate * 100)); + // test fp32 candidate + printf( + "\nProcessing HNSW,PRQ%dx%dx%d fp32 for %s distance, dim=%d, nrows=%d, " + "radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, + nb, radius, range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - // test bf16 candidate - printf( - "\nProcessing HNSW,PRQ%dx%dx%d bf16 for %s distance, dim=%d, nrows=%d, radius=%f, " - "range_filter=%f, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, - radius, range_filter, int(bitset_rate * 100)); + // test fp16 candidate + printf( + "\nProcessing HNSW,PRQ%dx%dx%d fp16 for %s distance, dim=%d, nrows=%d, " + "radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, + nb, radius, range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - if (index_support_int8(conf)) { - // test int8 candidate + // test bf16 candidate printf( - "\nProcessing HNSW,PRQ%dx%dx%d int8 for %s distance, dim=%d, nrows=%d, " + "\nProcessing HNSW,PRQ%dx%dx%d bf16 for %s distance, dim=%d, nrows=%d, " "radius=%f, " "range_filter=%f, %d%% points filtered out\n", prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, int(bitset_rate * 100)); - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params, conf, bitset_view); - } + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params, conf, bitset_view); + index_files.emplace_back(index_file); - // test fp32 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + if (index_support_int8(conf)) { + // test int8 candidate + printf( + "\nProcessing HNSW,PRQ%dx%dx%d int8 for %s distance, dim=%d, nrows=%d, " + "radius=%f, " + "range_filter=%f, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], DISTANCE_TYPES[distance_type].c_str(), + dim, nb, radius, range_filter, int(bitset_rate * 100)); - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + index_file = test_hnsw_range(default_ds_ptr, query_ds_ptr, + golden_result.value(), params, + conf, bitset_view); + index_files.emplace_back(index_file); + } - std::vector params_refine = { - (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, - (int)allowed_ref_idx}; + // test fp32 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP32.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - printf( - "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp32 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP32[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test a candidate - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, + (int)allowed_ref_idx}; - // test fp16 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + printf( + "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp32 for %s distance, " + "dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + // test a candidate + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, bitset_view); + index_files.emplace_back(index_file); + } - const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + // test fp16 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_FP16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - std::vector params_refine = { - (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, - (int)allowed_ref_idx}; + const std::string allowed_ref = PQ_ALLOWED_REFINES_FP16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - printf( - "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp16 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, + (int)allowed_ref_idx}; - // test a candidate - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); - } + printf( + "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, fp16 for %s distance, " + "dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); - // test bf16 refines - for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); - allowed_ref_idx++) { - auto conf_refine = conf; - conf_refine["refine"] = true; - conf_refine["refine_k"] = 1.5; + // test a candidate + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, bitset_view); - const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; - conf_refine["refine_type"] = allowed_ref; + index_files.emplace_back(index_file); + } - std::vector params_refine = { - (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, - (int)allowed_ref_idx}; + // test bf16 refines + for (size_t allowed_ref_idx = 0; allowed_ref_idx < PQ_ALLOWED_REFINES_BF16.size(); + allowed_ref_idx++) { + auto conf_refine = conf; + conf_refine["refine"] = true; + conf_refine["refine_k"] = 1.5; - printf( - "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, bf16 for %s distance, dim=%d, " - "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", - prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), - DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, - int(bitset_rate * 100)); + const std::string allowed_ref = PQ_ALLOWED_REFINES_BF16[allowed_ref_idx]; + conf_refine["refine_type"] = allowed_ref; - // test a candidate - test_hnsw_range(default_ds_ptr, query_ds_ptr, golden_result.value(), - params_refine, conf_refine, bitset_view); + std::vector params_refine = { + (int)distance_type, dim, nb, prq_m, prq_num, (int)nbits_type, + (int)allowed_ref_idx}; + + printf( + "\nProcessing HNSW,PRQ%dx%dx%d with %s refine, bf16 for %s distance, " + "dim=%d, " + "nrows=%d, radius=%f, range_filter=%f, %d%% points filtered out\n", + prq_num, prq_m, NBITS[nbits_type], allowed_ref.c_str(), + DISTANCE_TYPES[distance_type].c_str(), dim, nb, radius, range_filter, + int(bitset_rate * 100)); + + // test a candidate + index_file = test_hnsw_range( + default_ds_ptr, query_ds_ptr, golden_result.value(), params_refine, + conf_refine, bitset_view); + index_files.emplace_back(index_file); + } } } + for (auto index : index_files) { + std::remove(index.c_str()); + } } } } @@ -1915,6 +2195,7 @@ TEST_CASE("hnswlib to FAISS HNSW for HNSW_FLAT", "Check search fallback") { default_conf[knowhere::indexparam::EFCONSTRUCTION] = 96; default_conf[knowhere::indexparam::EF] = 64; default_conf[knowhere::meta::TOPK] = TOPK; + default_conf[knowhere::meta::MV_ONLY_ENABLED] = false; // const std::string hnswlib_index_type = knowhere::IndexEnum::INDEX_HNSW; @@ -1947,7 +2228,7 @@ TEST_CASE("hnswlib to FAISS HNSW for HNSW_FLAT", "Check search fallback") { get_index_name(ann_test_name_, hnswlib_index_type, hnswlib_params); auto hnswlib_index = - create_index(hnswlib_index_type, hnswlib_index_file_name, default_ds_ptr, conf, "hnswlib "); + create_index(hnswlib_index_type, hnswlib_index_file_name, default_ds_ptr, conf, "hnswlib"); // perform an hnswlib search auto hnswlib_result = hnswlib_index.Search(query_ds_ptr, conf, nullptr); diff --git a/tests/ut/test_index_check.cc b/tests/ut/test_index_check.cc index 276051b88..5a7971676 100644 --- a/tests/ut/test_index_check.cc +++ b/tests/ut/test_index_check.cc @@ -472,8 +472,13 @@ TEST_CASE("Test index feature check", "[IndexFeatureCheck]") { } SECTION("Check MV") { - // Only HNSW supports Materialized View + // Only HNSW family supports Materialized View REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::MV)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::MV)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::MV)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::MV)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::MV)); + REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::MV)); // All other indexes do not support MV REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_FAISS_IDMAP, knowhere::feature::MV)); @@ -483,12 +488,7 @@ TEST_CASE("Test index feature check", "[IndexFeatureCheck]") { REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_FAISS_SCANN, knowhere::feature::MV)); REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_FAISS_BIN_IDMAP, knowhere::feature::MV)); REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_FAISS_BIN_IVFFLAT, knowhere::feature::MV)); - REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::MV)); - REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::MV)); - REQUIRE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW, knowhere::feature::MV)); - REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_SQ, knowhere::feature::MV)); - REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PQ, knowhere::feature::MV)); - REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_HNSW_PRQ, knowhere::feature::MV)); + REQUIRE_FALSE( IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_SPARSE_INVERTED_INDEX, knowhere::feature::MV)); REQUIRE_FALSE(IndexFactory::Instance().FeatureCheck(IndexEnum::INDEX_SPARSE_WAND, knowhere::feature::MV)); diff --git a/tests/ut/utils.h b/tests/ut/utils.h index 3b62ba7ef..c3979f95d 100644 --- a/tests/ut/utils.h +++ b/tests/ut/utils.h @@ -283,6 +283,50 @@ CheckDistanceInScope(const knowhere::DataSet& result, float low_bound, float hig return true; } +inline std::vector> +GenerateScalarInfo(size_t n) { + std::vector> scalar_info; + scalar_info.reserve(2); + std::vector scalar1; + scalar1.reserve(n / 2); + std::vector scalar2; + scalar2.reserve(n - n / 2); + for (size_t i = 0; i < n; ++i) { + if (i % 2 == 0) { + scalar2.emplace_back(i); + } else { + scalar1.emplace_back(i); + } + } + scalar_info.emplace_back(std::move(scalar1)); + scalar_info.emplace_back(std::move(scalar2)); + return scalar_info; +} + +inline std::vector +GenerateBitsetByScalarInfoAndFirstTBits(const std::vector& scalar, size_t n, size_t t) { + assert(scalar.size() <= n); + assert(t >= 0 && t <= n - scalar.size()); + std::vector data((n + 8 - 1) / 8, 0); + // set bits by scalar info + for (size_t i = 0; i < scalar.size(); ++i) { + data[scalar[i] >> 3] |= (0x1 << (scalar[i] & 0x7)); + } + size_t count = 0; + for (size_t i = 0; i < n; ++i) { + if (count == t) { + break; + } + // already set, skip + if (data[i >> 3] & (0x1 << (i & 0x7))) { + continue; + } + data[i >> 3] |= (0x1 << (i & 0x7)); + ++count; + } + return data; +} + // Return a n-bits bitset data with first t bits set to true inline std::vector GenerateBitsetWithFirstTbitsSet(size_t n, size_t t) { diff --git a/thirdparty/faiss/faiss/impl/index_read.cpp b/thirdparty/faiss/faiss/impl/index_read.cpp index c2bf84b0c..7504114e8 100644 --- a/thirdparty/faiss/faiss/impl/index_read.cpp +++ b/thirdparty/faiss/faiss/impl/index_read.cpp @@ -65,6 +65,29 @@ namespace faiss { +uint32_t read_value(IOReader* f) { + uint32_t h; + READ1(h) + return h; +} + +void read_vector(std::vector& v, IOReader* f) { + READVECTOR(v); +} + +// "IHMV" is a special header for faiss hnsw to indicate whether mv or not +bool read_is_mv(IOReader* f) { + uint32_t h; + READ1(h); + return h == fourcc("IHMV"); +} + +bool read_is_mv(const char* fname) { + FileIOReader f(fname); + return read_is_mv(&f); +} + + template void read_vector(VectorT& target, IOReader* f) { // is it a mmap-enabled reader? diff --git a/thirdparty/faiss/faiss/impl/index_write.cpp b/thirdparty/faiss/faiss/impl/index_write.cpp index 6a73729ed..c67a5674a 100644 --- a/thirdparty/faiss/faiss/impl/index_write.cpp +++ b/thirdparty/faiss/faiss/impl/index_write.cpp @@ -1102,6 +1102,20 @@ void write_index(const Index* idx, const char* fname, int io_flags) { write_index(idx, &writer, io_flags); } +void write_value(int v, IOWriter* f) { + WRITE1(v); +} + +void write_vector(const std::vector& v, IOWriter* f) { + WRITEVECTOR(v); +} + +// "IHMV" is a special header for faiss hnsw to indicate whether mv or not +void write_mv(IOWriter* f) { + uint32_t h = fourcc("IHMV"); + WRITE1(h); +} + // write index for offset-only index void write_index_nm(const Index *idx, IOWriter *f) { if(const IndexIVFFlat * ivfl = diff --git a/thirdparty/faiss/faiss/index_io.h b/thirdparty/faiss/faiss/index_io.h index 8e08746ca..166efdff0 100644 --- a/thirdparty/faiss/faiss/index_io.h +++ b/thirdparty/faiss/faiss/index_io.h @@ -97,6 +97,15 @@ InvertedLists* read_InvertedLists(IOReader* reader, int io_flags = 0); // for backward compatibility Index *read_index_nm(IOReader *f, int io_flags = 0); void write_index_nm(const Index* idx, IOWriter* writer); + +// additional helper function +bool read_is_mv(IOReader* reader); +bool read_is_mv(const char* fname); +void write_vector(const std::vector& v, IOWriter* writer); +void read_vector(std::vector& v, IOReader* f); +uint32_t read_value(IOReader *f); +void write_value(int v, IOWriter* writer); +void write_mv(IOWriter* writer); } // namespace faiss #endif