Skip to content

Commit

Permalink
Add new ser/deser for scann index (#156)
Browse files Browse the repository at this point in the history
Signed-off-by: chasingegg <[email protected]>
  • Loading branch information
chasingegg authored Oct 19, 2023
1 parent 84ab6ab commit 0548a61
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 91 deletions.
4 changes: 2 additions & 2 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class IvfIndexNode : public IndexNode {
return false;
}
if constexpr (std::is_same<faiss::IndexScaNN, T>::value) {
return index_->with_raw_data;
return index_->with_raw_data();
}
if constexpr (std::is_same<faiss::IndexIVFScalarQuantizer, T>::value) {
return false;
Expand Down Expand Up @@ -607,7 +607,7 @@ IvfIndexNode<T>::GetVectorByIds(const DataSet& dataset) const {
}
} else if constexpr (std::is_same<T, faiss::IndexScaNN>::value) {
// we should never go here since we should call HasRawData() first
if (!index_->with_raw_data) {
if (!index_->with_raw_data()) {
return expected<DataSetPtr>::Err(Status::not_implemented, "GetVectorByIds not implemented");
}
auto dim = Dim();
Expand Down
32 changes: 6 additions & 26 deletions thirdparty/faiss/faiss/IndexRefine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,20 @@ IndexRefine::IndexRefine()

void IndexRefine::train(idx_t n, const float* x) {
base_index->train(n, x);
if (refine_index)
refine_index->train(n, x);
refine_index->train(n, x);
is_trained = true;
}

void IndexRefine::add(idx_t n, const float* x) {
FAISS_THROW_IF_NOT(is_trained);
base_index->add(n, x);
if (refine_index)
refine_index->add(n, x);
refine_index->add(n, x);
ntotal = base_index->ntotal;
}

void IndexRefine::reset() {
base_index->reset();
if (refine_index)
refine_index->reset();
refine_index->reset();
ntotal = 0;
}

Expand Down Expand Up @@ -100,9 +97,6 @@ void IndexRefine::search(
float* distances,
idx_t* labels,
const BitsetView bitset) const {
FAISS_THROW_IF_NOT(base_index);
FAISS_THROW_IF_NOT(refine_index);

FAISS_THROW_IF_NOT(k > 0);

FAISS_THROW_IF_NOT(is_trained);
Expand Down Expand Up @@ -159,19 +153,14 @@ void IndexRefine::search(
}

void IndexRefine::reconstruct(idx_t key, float* recons) const {
FAISS_THROW_IF_NOT(refine_index);
refine_index->reconstruct(key, recons);
}

size_t IndexRefine::sa_code_size() const {
FAISS_THROW_IF_NOT(base_index);
FAISS_THROW_IF_NOT(refine_index);
return base_index->sa_code_size() + refine_index->sa_code_size();
}

void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
FAISS_THROW_IF_NOT(base_index);
FAISS_THROW_IF_NOT(refine_index);
size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
std::unique_ptr<uint8_t[]> tmp1(new uint8_t[n * cs1]);
base_index->sa_encode(n, x, tmp1.get());
Expand All @@ -185,8 +174,6 @@ void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
}

void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
FAISS_THROW_IF_NOT(base_index);
FAISS_THROW_IF_NOT(refine_index);
size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
std::unique_ptr<uint8_t[]> tmp2(
new uint8_t[n * refine_index->sa_code_size()]);
Expand Down Expand Up @@ -221,14 +208,10 @@ IndexRefineFlat::IndexRefineFlat(Index* base_index)

IndexRefineFlat::IndexRefineFlat(Index* base_index, const float* xb)
: IndexRefine(base_index, nullptr) {
is_trained = base_index->is_trained;
if (xb) {
refine_index = new IndexFlat(base_index->d, base_index->metric_type);
with_raw_data = true;
} else {
with_raw_data = false;
}
is_trained = base_index->is_trained;
refine_index = new IndexFlat(base_index->d, base_index->metric_type);
own_refine_index = true;
refine_index->add(base_index->ntotal, xb);
}

IndexRefineFlat::IndexRefineFlat() : IndexRefine() {
Expand All @@ -242,9 +225,6 @@ void IndexRefineFlat::search(
float* distances,
idx_t* labels,
const BitsetView bitset) const {
FAISS_THROW_IF_NOT(base_index);
FAISS_THROW_IF_NOT(refine_index);

FAISS_THROW_IF_NOT(k > 0);

FAISS_THROW_IF_NOT(is_trained);
Expand Down
2 changes: 0 additions & 2 deletions thirdparty/faiss/faiss/IndexRefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ struct IndexRefine : Index {
/// the base_index (should be >= 1)
float k_factor = 1;

bool with_raw_data;

/// initialize from empty index
IndexRefine(Index* base_index, Index* refine_index);

Expand Down
43 changes: 40 additions & 3 deletions thirdparty/faiss/faiss/IndexScaNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,49 @@ namespace faiss {
* IndexScaNN
***************************************************/

IndexScaNN::IndexScaNN(Index* base_index) : IndexRefineFlat(base_index) {}
IndexScaNN::IndexScaNN(Index* base_index)
: IndexRefine(
base_index,
new IndexFlat(base_index->d, base_index->metric_type)) {
is_trained = base_index->is_trained;
own_refine_index = true;
FAISS_THROW_IF_NOT_MSG(
base_index->ntotal == 0,
"base_index should be empty in the beginning");
}

IndexScaNN::IndexScaNN(Index* base_index, const float* xb)
: IndexRefineFlat(base_index, xb) {}
: IndexRefine(base_index, nullptr) {
is_trained = base_index->is_trained;
if (xb) {
refine_index = new IndexFlat(base_index->d, base_index->metric_type);
}
own_refine_index = true;
}

IndexScaNN::IndexScaNN() : IndexRefine() {}

void IndexScaNN::train(idx_t n, const float* x) {
base_index->train(n, x);
if (refine_index)
refine_index->train(n, x);
is_trained = true;
}

IndexScaNN::IndexScaNN() : IndexRefineFlat() {}
void IndexScaNN::add(idx_t n, const float* x) {
FAISS_THROW_IF_NOT(is_trained);
base_index->add(n, x);
if (refine_index)
refine_index->add(n, x);
ntotal = base_index->ntotal;
}

void IndexScaNN::reset() {
base_index->reset();
if (refine_index)
refine_index->reset();
ntotal = 0;
}

namespace {

Expand Down
12 changes: 11 additions & 1 deletion thirdparty/faiss/faiss/IndexScaNN.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,22 @@

namespace faiss {

struct IndexScaNN : IndexRefineFlat {
struct IndexScaNN : IndexRefine {
explicit IndexScaNN(Index* base_index);
IndexScaNN(Index* base_index, const float* xb);

IndexScaNN();

void train(idx_t n, const float* x) override;

void add(idx_t n, const float* x) override;

void reset() override;

inline bool with_raw_data() const {
return (refine_index != nullptr);
}

int64_t size();

void search_thread_safe(
Expand Down
64 changes: 24 additions & 40 deletions thirdparty/faiss/faiss/impl/index_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,6 @@ static void read_index_header(Index* idx, IOReader* f) {
idx->verbose = false;
}

static void read_scann_header(IndexRefine* idx, IOReader* f) {
READ1(idx->d);
READ1(idx->ntotal);
Index::idx_t dummy;
READ1(dummy);
READ1(dummy);
if (dummy == (1 << 20)) { // for compatibility, old scann binary always contains raw data
idx->with_raw_data = true;
} else {
idx->with_raw_data = (dummy == 1);
}
READ1(idx->is_trained);
READ1(idx->metric_type);
if (idx->metric_type > 1) {
READ1(idx->metric_arg);
}
idx->verbose = false;
}

VectorTransform* read_VectorTransform(IOReader* f) {
uint32_t h;
READ1(h);
Expand Down Expand Up @@ -879,31 +860,34 @@ Index* read_index(IOReader* f, int io_flags) {
read_index_header(imiq, f);
read_ProductQuantizer(&imiq->pq, f);
idx = imiq;
} else if (h == fourcc("IxSC")) {
IndexScaNN* idxscann = new IndexScaNN();
read_index_header(idxscann, f);
idxscann->base_index = read_index(f, io_flags);
bool with_raw_data;
READ1(with_raw_data);
if (with_raw_data) {
idxscann->refine_index = read_index(f, io_flags);
} else {
idxscann->refine_index = nullptr;
}
READ1(idxscann->k_factor);
idxscann->own_fields = true;
idxscann->own_refine_index = true;
idx = idxscann;
} else if (h == fourcc("IxRF")) {
IndexRefine* idxrf = new IndexRefine();
read_scann_header(idxrf, f);
read_index_header(idxrf, f);
idxrf->base_index = read_index(f, io_flags);
if (idxrf->with_raw_data) {
idxrf->refine_index = read_index(f, io_flags);
if (dynamic_cast<IndexFlat*>(idxrf->refine_index)) {
if (dynamic_cast<IndexIVFPQFastScan*>(idxrf->base_index)) {
// this is IndexScaNN
IndexRefine* idxrf_old = idxrf;
idxrf = new IndexScaNN();
*idxrf = *idxrf_old;
delete idxrf_old;
} else {
// then make a RefineFlat with it
IndexRefine* idxrf_old = idxrf;
idxrf = new IndexRefineFlat();
*idxrf = *idxrf_old;
delete idxrf_old;
}
}
} else {
idxrf->refine_index = nullptr;
}
idxrf->refine_index = read_index(f, io_flags);
READ1(idxrf->k_factor);
if (dynamic_cast<IndexFlat*>(idxrf->refine_index)) {
// then make a RefineFlat with it
IndexRefine* idxrf_old = idxrf;
idxrf = new IndexRefineFlat();
*idxrf = *idxrf_old;
delete idxrf_old;
}
idxrf->own_fields = true;
idxrf->own_refine_index = true;
idx = idxrf;
Expand Down
31 changes: 14 additions & 17 deletions thirdparty/faiss/faiss/impl/index_write.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include <faiss/IndexPreTransform.h>
#include <faiss/IndexRefine.h>
#include <faiss/IndexScalarQuantizer.h>
#include <faiss/IndexScaNN.h>
#include <faiss/MetaIndexes.h>
#include <faiss/VectorTransform.h>

Expand Down Expand Up @@ -99,20 +100,6 @@ static void write_index_header(const Index* idx, IOWriter* f) {
}
}

static void write_scann_header(const IndexRefine* idx, IOWriter* f) {
WRITE1(idx->d);
WRITE1(idx->ntotal);
Index::idx_t dummy = 1 << 20;
WRITE1(dummy);
dummy = static_cast<Index::idx_t>(idx->with_raw_data);
WRITE1(dummy);
WRITE1(idx->is_trained);
WRITE1(idx->metric_type);
if (idx->metric_type > 1) {
WRITE1(idx->metric_arg);
}
}

void write_VectorTransform(const VectorTransform* vt, IOWriter* f) {
if (const LinearTransform* lt = dynamic_cast<const LinearTransform*>(vt)) {
if (dynamic_cast<const RandomRotationMatrix*>(lt)) {
Expand Down Expand Up @@ -702,14 +689,24 @@ void write_index(const Index* idx, IOWriter* f) {
WRITE1(h);
write_index_header(imiq, f);
write_ProductQuantizer(&imiq->pq, f);
} else if (
const IndexScaNN* idxscann = dynamic_cast<const IndexScaNN*>(idx)) {
uint32_t h = fourcc("IxSC");
WRITE1(h);
write_index_header(idxscann, f);
write_index(idxscann->base_index, f);
bool with_raw_data = idxscann->with_raw_data();
WRITE1(with_raw_data);
if (with_raw_data)
write_index(idxscann->refine_index, f);
WRITE1(idxscann->k_factor);
} else if (
const IndexRefine* idxrf = dynamic_cast<const IndexRefine*>(idx)) {
uint32_t h = fourcc("IxRF");
WRITE1(h);
write_scann_header(idxrf, f);
write_index_header(idxrf, f);
write_index(idxrf->base_index, f);
if (idxrf->with_raw_data)
write_index(idxrf->refine_index, f);
write_index(idxrf->refine_index, f);
WRITE1(idxrf->k_factor);
} else if (
const IndexIDMap* idxmap = dynamic_cast<const IndexIDMap*>(idx)) {
Expand Down

0 comments on commit 0548a61

Please sign in to comment.