From 0548a6174d498448a505c7186a291a8c29331947 Mon Sep 17 00:00:00 2001 From: Gao Date: Thu, 19 Oct 2023 08:19:59 +0800 Subject: [PATCH] Add new ser/deser for scann index (#156) Signed-off-by: chasingegg --- src/index/ivf/ivf.cc | 4 +- thirdparty/faiss/faiss/IndexRefine.cpp | 32 ++--------- thirdparty/faiss/faiss/IndexRefine.h | 2 - thirdparty/faiss/faiss/IndexScaNN.cpp | 43 +++++++++++++- thirdparty/faiss/faiss/IndexScaNN.h | 12 +++- thirdparty/faiss/faiss/impl/index_read.cpp | 64 ++++++++------------- thirdparty/faiss/faiss/impl/index_write.cpp | 31 +++++----- 7 files changed, 97 insertions(+), 91 deletions(-) diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 5aed5a1b2..2c8916e15 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -78,7 +78,7 @@ class IvfIndexNode : public IndexNode { return false; } if constexpr (std::is_same::value) { - return index_->with_raw_data; + return index_->with_raw_data(); } if constexpr (std::is_same::value) { return false; @@ -607,7 +607,7 @@ IvfIndexNode::GetVectorByIds(const DataSet& dataset) const { } } else if constexpr (std::is_same::value) { // we should never go here since we should call HasRawData() first - if (!index_->with_raw_data) { + if (!index_->with_raw_data()) { return expected::Err(Status::not_implemented, "GetVectorByIds not implemented"); } auto dim = Dim(); diff --git a/thirdparty/faiss/faiss/IndexRefine.cpp b/thirdparty/faiss/faiss/IndexRefine.cpp index ec67798b9..412211460 100644 --- a/thirdparty/faiss/faiss/IndexRefine.cpp +++ b/thirdparty/faiss/faiss/IndexRefine.cpp @@ -43,23 +43,20 @@ IndexRefine::IndexRefine() void IndexRefine::train(idx_t n, const float* x) { base_index->train(n, x); - if (refine_index) - refine_index->train(n, x); + refine_index->train(n, x); is_trained = true; } void IndexRefine::add(idx_t n, const float* x) { FAISS_THROW_IF_NOT(is_trained); base_index->add(n, x); - if (refine_index) - refine_index->add(n, x); + refine_index->add(n, x); ntotal = base_index->ntotal; } void IndexRefine::reset() { base_index->reset(); - if (refine_index) - refine_index->reset(); + refine_index->reset(); ntotal = 0; } @@ -100,9 +97,6 @@ void IndexRefine::search( float* distances, idx_t* labels, const BitsetView bitset) const { - FAISS_THROW_IF_NOT(base_index); - FAISS_THROW_IF_NOT(refine_index); - FAISS_THROW_IF_NOT(k > 0); FAISS_THROW_IF_NOT(is_trained); @@ -159,19 +153,14 @@ void IndexRefine::search( } void IndexRefine::reconstruct(idx_t key, float* recons) const { - FAISS_THROW_IF_NOT(refine_index); refine_index->reconstruct(key, recons); } size_t IndexRefine::sa_code_size() const { - FAISS_THROW_IF_NOT(base_index); - FAISS_THROW_IF_NOT(refine_index); return base_index->sa_code_size() + refine_index->sa_code_size(); } void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const { - FAISS_THROW_IF_NOT(base_index); - FAISS_THROW_IF_NOT(refine_index); size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size(); std::unique_ptr tmp1(new uint8_t[n * cs1]); base_index->sa_encode(n, x, tmp1.get()); @@ -185,8 +174,6 @@ void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const { } void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const { - FAISS_THROW_IF_NOT(base_index); - FAISS_THROW_IF_NOT(refine_index); size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size(); std::unique_ptr tmp2( new uint8_t[n * refine_index->sa_code_size()]); @@ -221,14 +208,10 @@ IndexRefineFlat::IndexRefineFlat(Index* base_index) IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb) : IndexRefine(base_index, nullptr) { - is_trained = base_index->is_trained; - if (xb) { - refine_index = new IndexFlat(base_index->d, base_index->metric_type); - with_raw_data = true; - } else { - with_raw_data = false; - } + is_trained = base_index->is_trained; + refine_index = new IndexFlat(base_index->d, base_index->metric_type); own_refine_index = true; + refine_index->add(base_index->ntotal, xb); } IndexRefineFlat::IndexRefineFlat() : IndexRefine() { @@ -242,9 +225,6 @@ void IndexRefineFlat::search( float* distances, idx_t* labels, const BitsetView bitset) const { - FAISS_THROW_IF_NOT(base_index); - FAISS_THROW_IF_NOT(refine_index); - FAISS_THROW_IF_NOT(k > 0); FAISS_THROW_IF_NOT(is_trained); diff --git a/thirdparty/faiss/faiss/IndexRefine.h b/thirdparty/faiss/faiss/IndexRefine.h index 11be3ab9d..218106030 100644 --- a/thirdparty/faiss/faiss/IndexRefine.h +++ b/thirdparty/faiss/faiss/IndexRefine.h @@ -28,8 +28,6 @@ struct IndexRefine : Index { /// the base_index (should be >= 1) float k_factor = 1; - bool with_raw_data; - /// initialize from empty index IndexRefine(Index* base_index, Index* refine_index); diff --git a/thirdparty/faiss/faiss/IndexScaNN.cpp b/thirdparty/faiss/faiss/IndexScaNN.cpp index 5d4ffc54a..f2ec79b5d 100644 --- a/thirdparty/faiss/faiss/IndexScaNN.cpp +++ b/thirdparty/faiss/faiss/IndexScaNN.cpp @@ -14,12 +14,49 @@ namespace faiss { * IndexScaNN ***************************************************/ -IndexScaNN::IndexScaNN(Index* base_index) : IndexRefineFlat(base_index) {} +IndexScaNN::IndexScaNN(Index* base_index) + : IndexRefine( + base_index, + new IndexFlat(base_index->d, base_index->metric_type)) { + is_trained = base_index->is_trained; + own_refine_index = true; + FAISS_THROW_IF_NOT_MSG( + base_index->ntotal == 0, + "base_index should be empty in the beginning"); +} IndexScaNN::IndexScaNN(Index* base_index, const float* xb) - : IndexRefineFlat(base_index, xb) {} + : IndexRefine(base_index, nullptr) { + is_trained = base_index->is_trained; + if (xb) { + refine_index = new IndexFlat(base_index->d, base_index->metric_type); + } + own_refine_index = true; +} + +IndexScaNN::IndexScaNN() : IndexRefine() {} + +void IndexScaNN::train(idx_t n, const float* x) { + base_index->train(n, x); + if (refine_index) + refine_index->train(n, x); + is_trained = true; +} -IndexScaNN::IndexScaNN() : IndexRefineFlat() {} +void IndexScaNN::add(idx_t n, const float* x) { + FAISS_THROW_IF_NOT(is_trained); + base_index->add(n, x); + if (refine_index) + refine_index->add(n, x); + ntotal = base_index->ntotal; +} + +void IndexScaNN::reset() { + base_index->reset(); + if (refine_index) + refine_index->reset(); + ntotal = 0; +} namespace { diff --git a/thirdparty/faiss/faiss/IndexScaNN.h b/thirdparty/faiss/faiss/IndexScaNN.h index 2dc6ef909..3727cba61 100644 --- a/thirdparty/faiss/faiss/IndexScaNN.h +++ b/thirdparty/faiss/faiss/IndexScaNN.h @@ -5,12 +5,22 @@ namespace faiss { -struct IndexScaNN : IndexRefineFlat { +struct IndexScaNN : IndexRefine { explicit IndexScaNN(Index* base_index); IndexScaNN(Index* base_index, const float* xb); IndexScaNN(); + void train(idx_t n, const float* x) override; + + void add(idx_t n, const float* x) override; + + void reset() override; + + inline bool with_raw_data() const { + return (refine_index != nullptr); + } + int64_t size(); void search_thread_safe( diff --git a/thirdparty/faiss/faiss/impl/index_read.cpp b/thirdparty/faiss/faiss/impl/index_read.cpp index 7c55ba751..0afd12ac1 100644 --- a/thirdparty/faiss/faiss/impl/index_read.cpp +++ b/thirdparty/faiss/faiss/impl/index_read.cpp @@ -83,25 +83,6 @@ static void read_index_header(Index* idx, IOReader* f) { idx->verbose = false; } -static void read_scann_header(IndexRefine* idx, IOReader* f) { - READ1(idx->d); - READ1(idx->ntotal); - Index::idx_t dummy; - READ1(dummy); - READ1(dummy); - if (dummy == (1 << 20)) { // for compatibility, old scann binary always contains raw data - idx->with_raw_data = true; - } else { - idx->with_raw_data = (dummy == 1); - } - READ1(idx->is_trained); - READ1(idx->metric_type); - if (idx->metric_type > 1) { - READ1(idx->metric_arg); - } - idx->verbose = false; -} - VectorTransform* read_VectorTransform(IOReader* f) { uint32_t h; READ1(h); @@ -879,31 +860,34 @@ Index* read_index(IOReader* f, int io_flags) { read_index_header(imiq, f); read_ProductQuantizer(&imiq->pq, f); idx = imiq; + } else if (h == fourcc("IxSC")) { + IndexScaNN* idxscann = new IndexScaNN(); + read_index_header(idxscann, f); + idxscann->base_index = read_index(f, io_flags); + bool with_raw_data; + READ1(with_raw_data); + if (with_raw_data) { + idxscann->refine_index = read_index(f, io_flags); + } else { + idxscann->refine_index = nullptr; + } + READ1(idxscann->k_factor); + idxscann->own_fields = true; + idxscann->own_refine_index = true; + idx = idxscann; } else if (h == fourcc("IxRF")) { IndexRefine* idxrf = new IndexRefine(); - read_scann_header(idxrf, f); + read_index_header(idxrf, f); idxrf->base_index = read_index(f, io_flags); - if (idxrf->with_raw_data) { - idxrf->refine_index = read_index(f, io_flags); - if (dynamic_cast(idxrf->refine_index)) { - if (dynamic_cast(idxrf->base_index)) { - // this is IndexScaNN - IndexRefine* idxrf_old = idxrf; - idxrf = new IndexScaNN(); - *idxrf = *idxrf_old; - delete idxrf_old; - } else { - // then make a RefineFlat with it - IndexRefine* idxrf_old = idxrf; - idxrf = new IndexRefineFlat(); - *idxrf = *idxrf_old; - delete idxrf_old; - } - } - } else { - idxrf->refine_index = nullptr; - } + idxrf->refine_index = read_index(f, io_flags); READ1(idxrf->k_factor); + if (dynamic_cast(idxrf->refine_index)) { + // then make a RefineFlat with it + IndexRefine* idxrf_old = idxrf; + idxrf = new IndexRefineFlat(); + *idxrf = *idxrf_old; + delete idxrf_old; + } idxrf->own_fields = true; idxrf->own_refine_index = true; idx = idxrf; diff --git a/thirdparty/faiss/faiss/impl/index_write.cpp b/thirdparty/faiss/faiss/impl/index_write.cpp index 5604ffa7e..fb69bdd41 100644 --- a/thirdparty/faiss/faiss/impl/index_write.cpp +++ b/thirdparty/faiss/faiss/impl/index_write.cpp @@ -45,6 +45,7 @@ #include #include #include +#include #include #include @@ -99,20 +100,6 @@ static void write_index_header(const Index* idx, IOWriter* f) { } } -static void write_scann_header(const IndexRefine* idx, IOWriter* f) { - WRITE1(idx->d); - WRITE1(idx->ntotal); - Index::idx_t dummy = 1 << 20; - WRITE1(dummy); - dummy = static_cast(idx->with_raw_data); - WRITE1(dummy); - WRITE1(idx->is_trained); - WRITE1(idx->metric_type); - if (idx->metric_type > 1) { - WRITE1(idx->metric_arg); - } -} - void write_VectorTransform(const VectorTransform* vt, IOWriter* f) { if (const LinearTransform* lt = dynamic_cast(vt)) { if (dynamic_cast(lt)) { @@ -702,14 +689,24 @@ void write_index(const Index* idx, IOWriter* f) { WRITE1(h); write_index_header(imiq, f); write_ProductQuantizer(&imiq->pq, f); + } else if ( + const IndexScaNN* idxscann = dynamic_cast(idx)) { + uint32_t h = fourcc("IxSC"); + WRITE1(h); + write_index_header(idxscann, f); + write_index(idxscann->base_index, f); + bool with_raw_data = idxscann->with_raw_data(); + WRITE1(with_raw_data); + if (with_raw_data) + write_index(idxscann->refine_index, f); + WRITE1(idxscann->k_factor); } else if ( const IndexRefine* idxrf = dynamic_cast(idx)) { uint32_t h = fourcc("IxRF"); WRITE1(h); - write_scann_header(idxrf, f); + write_index_header(idxrf, f); write_index(idxrf->base_index, f); - if (idxrf->with_raw_data) - write_index(idxrf->refine_index, f); + write_index(idxrf->refine_index, f); WRITE1(idxrf->k_factor); } else if ( const IndexIDMap* idxmap = dynamic_cast(idx)) {