diff --git a/src/index/flat/flat.cc b/src/index/flat/flat.cc index e7ef2af77..3ccb61c47 100644 --- a/src/index/flat/flat.cc +++ b/src/index/flat/flat.cc @@ -304,6 +304,7 @@ class FlatIndexNode : public IndexNode { if constexpr (std::is_same<T, faiss::IndexBinaryFlat>::value) { faiss::write_index_binary(index_.get(), &writer); } + writer.addtailer(Type(), this->version_); std::shared_ptr<uint8_t[]> data(writer.data()); binset.Append(Type(), data, writer.tellg()); return Status::success; diff --git a/src/index/hnsw/hnsw.cc b/src/index/hnsw/hnsw.cc index 9dc1d8255..615b0e038 100644 --- a/src/index/hnsw/hnsw.cc +++ b/src/index/hnsw/hnsw.cc @@ -389,6 +389,7 @@ class HnswIndexNode : public IndexNode { try { MemoryIOWriter writer; index_->saveIndex(writer); + writer.addtailer(Type(), this->version_); std::shared_ptr<uint8_t[]> data(writer.data()); binset.Append(Type(), data, writer.tellg()); } catch (std::exception& e) { diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 53c8276de..a70ec2dd4 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -824,6 +824,7 @@ IvfIndexNode<T>::Serialize(BinarySet& binset) const { } else { faiss::write_index(index_.get(), &writer); } + writer.addtailer(Type(), this->version_); std::shared_ptr<uint8_t[]> data(writer.data()); binset.Append(Type(), data, writer.tellg()); return Status::success; @@ -846,6 +847,7 @@ IvfIndexNode<faiss::IndexIVFFlat>::Serialize(BinarySet& binset) const { faiss::write_index(index_.get(), &writer); LOG_KNOWHERE_INFO_ << "write IVF_FLAT, file size " << writer.tellg(); } + writer.addtailer(Type(), this->version_); std::shared_ptr<uint8_t[]> index_data_ptr(writer.data()); binset.Append(Type(), index_data_ptr, writer.tellg()); diff --git a/src/io/memory_io.cc b/src/io/memory_io.cc index d84acc922..cd65fb377 100644 --- a/src/io/memory_io.cc +++ b/src/io/memory_io.cc @@ -13,6 +13,17 @@ #include <cstring> +namespace { +uint32_t +CalculateCheckSum(const uint8_t* data, int64_t size) { + uint32_t checksum = 0; + for (auto i = 0; i < size; i++) { + checksum ^= data[i]; // xor + } + return checksum; +} +} // namespace + namespace knowhere { // TODO(linxj): Get From Config File @@ -47,6 +58,16 @@ MemoryIOWriter::operator()(const void* ptr, size_t size, size_t nitems) { return nitems; } +void +MemoryIOWriter::addtailer(const std::string& index_name, const Version& version) { + auto tailer_ptr = std::make_unique<Tailer>(); + tailer_ptr->SetIndexBinarySize(rp_); + tailer_ptr->SetCheckSum(CalculateCheckSum(data_, rp_)); + tailer_ptr->SetVersion(version.VersionNumber()); + tailer_ptr->SetIndexName(index_name); + write(tailer_ptr->bytes, KNOWHERE_TAILER_SIZE); +} + size_t MemoryIOReader::operator()(void* ptr, size_t size, size_t nitems) { if (rp_ >= total_) { @@ -61,4 +82,44 @@ MemoryIOReader::operator()(void* ptr, size_t size, size_t nitems) { return nitems; } +bool +MemoryIOReader::isvalid(const std::string& index_name) { + uint64_t bin_size = TAILER_OFFSET(total_); + if (bin_size < 0) { + LOG_KNOWHERE_WARNING_ << "The binary is too small and assume no tailer, pass tailer check."; + return true; + } + + auto tailer_ptr = std::make_unique<Tailer>(); + auto pre_rp = rp_; + rp_ = bin_size; + read(tailer_ptr.get(), KNOWHERE_TAILER_SIZE); + rp_ = pre_rp; + if (!tailer_ptr->TailerValidCheck()) { + LOG_KNOWHERE_WARNING_ << "Tailer not exist in Binary."; + return true; + } + + auto version = Version(tailer_ptr->GetVersion()); + if (!Version::VersionSupport(version)) { + LOG_KNOWHERE_ERROR_ << "Binary version(" << version.VersionNumber() << ") is not supported, pass tailer check."; + return false; + } + + if (tailer_ptr->GetIndexName() != index_name) { + LOG_KNOWHERE_ERROR_ << "Index type or data type is not correct(" << index_name << ")."; + return false; + } + + if (tailer_ptr->GetIndexBinarySize() != bin_size) { + LOG_KNOWHERE_ERROR_ << "The size of index binary is not correct."; + return false; + } + auto bin = this->data(); + if (CalculateCheckSum(bin, bin_size) != tailer_ptr->GetCheckSum()) { + LOG_KNOWHERE_ERROR_ << "Binary checksum check fail."; + return false; + } + return true; +} } // namespace knowhere diff --git a/src/io/memory_io.h b/src/io/memory_io.h index bc74092e7..ab0b07681 100644 --- a/src/io/memory_io.h +++ b/src/io/memory_io.h @@ -10,9 +10,10 @@ // or implied. See the License for the specific language governing permissions and limitations under the License #pragma once - #include <faiss/impl/io.h> +#include "io/tailer.h" + namespace knowhere { #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ @@ -71,6 +72,7 @@ getSwappedBytes(char C) { #endif +// MemoryIOwriter and MemoryIOreader it not thread safe struct MemoryIOWriter : public faiss::IOWriter { uint8_t* data_ = nullptr; size_t total_ = 0; @@ -100,6 +102,9 @@ struct MemoryIOWriter : public faiss::IOWriter { tellg() const { return rp_; } + + void + addtailer(const std::string& index_name, const Version& version); }; struct MemoryIOReader : public faiss::IOReader { @@ -140,6 +145,14 @@ struct MemoryIOReader : public faiss::IOReader { reset() { rp_ = 0; } + + void + seekg(const size_t offset) { + rp_ = offset; + } + + bool + isvalid(const std::string& index_name); }; } // namespace knowhere diff --git a/src/io/tailer.h b/src/io/tailer.h new file mode 100644 index 000000000..aa805b263 --- /dev/null +++ b/src/io/tailer.h @@ -0,0 +1,77 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License + +#pragma once +#include "knowhere/version.h" +#include "memory_io.h" + +#define KNOWHERE_TAILER_SIZE 512 // bytes +#define MAX_INDEX_NAME_SIZE 100 // bytes +#define TAILER_VALID_FLAG UINT32_C(0x0F0F0F0F) +#define TAILER_OFFSET(size) size - KNOWHERE_TAILER_SIZE + +namespace knowhere { +union Tailer { + uint8_t bytes[KNOWHERE_TAILER_SIZE]; + struct Meta { + uint32_t flag; + uint64_t bin_size; + IndexVersion version; + uint32_t checksum; + char index_name[MAX_INDEX_NAME_SIZE + 1]; + } meta; + Tailer() { + meta.flag = TAILER_VALID_FLAG; + }; + bool + TailerValidCheck() { + return meta.flag == TAILER_VALID_FLAG; + } + std::string + GetIndexName() { + return std::string(meta.index_name); + } + int32_t + GetVersion() { + return meta.version; + } + uint64_t + GetIndexBinarySize() { + return meta.bin_size; + } + uint32_t + GetCheckSum() { + return meta.checksum; + } + void + SetIndexName(std::string index_name) { + if (index_name.size() > MAX_INDEX_NAME_SIZE) { + LOG_KNOWHERE_ERROR_ << "the size of index name larger than " << MAX_INDEX_NAME_SIZE; + } else { + memcpy((char*)meta.index_name, index_name.data(), index_name.size()); + meta.index_name[index_name.size()] = '\0'; + } + } + void + SetCheckSum(uint32_t value) { + meta.checksum = value; + } + void + SetIndexBinarySize(uint64_t s) { + meta.bin_size = s; + } + void + SetVersion(int32_t v) { + meta.version = v; + } +}; +using TailerPtr = std::unique_ptr<Tailer>; +} // namespace knowhere