Skip to content

Commit

Permalink
faiss_hnsw support INT8 (#991)
Browse files Browse the repository at this point in the history
Signed-off-by: Cai Yudong <[email protected]>
  • Loading branch information
cydrain authored Dec 19, 2024
1 parent c90443c commit ca4ba32
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 42 deletions.
1 change: 1 addition & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ enum VecType {
VECTOR_FLOAT16 = 102,
VECTOR_BFLOAT16 = 103,
VECTOR_SPARSE_FLOAT = 104,
VECTOR_INT8 = 105,
}; // keep the same value as milvus proto define

} // namespace knowhere
6 changes: 4 additions & 2 deletions include/knowhere/feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ constexpr uint64_t FP16 = 1UL << 2;
constexpr uint64_t BF16 = 1UL << 3;
// vector datatype support : sparse_float32
constexpr uint64_t SPARSE_FLOAT32 = 1UL << 4;
// vector datatype support : int8
constexpr uint64_t INT8 = 1UL << 5;

// This flag indicates that there is no need to create any index structure (build stage can be skipped)
constexpr uint64_t NO_TRAIN = 1UL << 16;
Expand All @@ -45,8 +47,8 @@ constexpr uint64_t DISK = 1UL << 21;

constexpr uint64_t NONE = 0UL;

constexpr uint64_t ALL_TYPE = BINARY | FLOAT32 | FP16 | BF16 | SPARSE_FLOAT32;
constexpr uint64_t ALL_DENSE_TYPE = BINARY | FLOAT32 | FP16 | BF16;
constexpr uint64_t ALL_TYPE = BINARY | FLOAT32 | FP16 | BF16 | SPARSE_FLOAT32 | INT8;
constexpr uint64_t ALL_DENSE_TYPE = BINARY | FLOAT32 | FP16 | BF16 | INT8;
constexpr uint64_t ALL_DENSE_FLOAT_TYPE = FLOAT32 | FP16 | BF16;

constexpr uint64_t NO_TRAIN_INDEX = NO_TRAIN;
Expand Down
3 changes: 3 additions & 0 deletions include/knowhere/index/index_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ class IndexFactory {
// register vector index supporting binary data type
#define KNOWHERE_SIMPLE_REGISTER_DENSE_BIN_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, bin1, (features | knowhere::feature::BINARY), ##__VA_ARGS__);
// register vector index supporting int8 data type
#define KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(name, index_node, features, ...) \
KNOWHERE_SIMPLE_REGISTER_GLOBAL(name, index_node, int8, (features | knowhere::feature::INT8), ##__VA_ARGS__);

// register vector index supporting ALL_DENSE_FLOAT_TYPE(float32, bf16, fp16) data types
#define KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(name, index_node, features, ...) \
Expand Down
4 changes: 4 additions & 0 deletions include/knowhere/index/index_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,22 @@ static std::set<std::pair<std::string, VecType>> legal_knowhere_index = {
{IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_HNSW, VecType::VECTOR_INT8},

{IndexEnum::INDEX_HNSW_SQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_SQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_SQ, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_HNSW_SQ, VecType::VECTOR_INT8},

{IndexEnum::INDEX_HNSW_PQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_PQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_PQ, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_HNSW_PQ, VecType::VECTOR_INT8},

{IndexEnum::INDEX_HNSW_PRQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_PRQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_PRQ, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_HNSW_PRQ, VecType::VECTOR_INT8},

// diskann
{IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT},
Expand Down
11 changes: 10 additions & 1 deletion include/knowhere/operands.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ fp32_to_bits(const float& f) {

namespace knowhere {
using fp32 = float;
using int8 = int8_t;
using bin1 = uint8_t;

struct fp16 {
Expand Down Expand Up @@ -161,13 +162,16 @@ typeCheck(uint64_t features) {
if constexpr (std::is_same_v<T, fp32>) {
return (features & knowhere::feature::FLOAT32) || (features & knowhere::feature::SPARSE_FLOAT32);
}
if constexpr (std::is_same_v<T, int8>) {
return features & knowhere::feature::INT8;
}
return false;
}

template <typename InType, typename... Types>
using TypeMatch = std::bool_constant<(... | std::is_same_v<InType, Types>)>;
template <typename InType>
using KnowhereDataTypeCheck = TypeMatch<InType, bin1, fp16, fp32, bf16>;
using KnowhereDataTypeCheck = TypeMatch<InType, bin1, fp16, fp32, bf16, int8>;
template <typename InType>
using KnowhereFloatTypeCheck = TypeMatch<InType, fp16, fp32, bf16>;
template <typename InType>
Expand All @@ -187,5 +191,10 @@ template <>
struct MockData<knowhere::bf16> {
using type = knowhere::fp32;
};

template <>
struct MockData<knowhere::int8> {
using type = knowhere::fp32;
};
} // namespace knowhere
#endif /* OPERANDS_H */
2 changes: 2 additions & 0 deletions include/knowhere/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ GetKey(const std::string& name) {
return name + std::string("_bf16");
} else if (std::is_same_v<DataType, bin1>) {
return name + std::string("_bin1");
} else if (std::is_same_v<DataType, int8>) {
return name + std::string("_int8");
}
}

Expand Down
100 changes: 69 additions & 31 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode {
};

//
enum class DataFormatEnum { fp32, fp16, bf16 };
enum class DataFormatEnum { fp32, fp16, bf16, int8 };

template <typename T>
struct DataType2EnumHelper {};
Expand All @@ -309,14 +309,16 @@ template <>
struct DataType2EnumHelper<knowhere::bf16> {
static constexpr DataFormatEnum value = DataFormatEnum::bf16;
};
template <>
struct DataType2EnumHelper<knowhere::int8> {
static constexpr DataFormatEnum value = DataFormatEnum::int8;
};

template <typename T>
static constexpr DataFormatEnum datatype_v = DataType2EnumHelper<T>::value;

//
namespace {

//
bool
convert_rows_to_fp32(const void* const __restrict src_in, float* const __restrict dst,
const DataFormatEnum src_data_format, const size_t start_row, const size_t nrows,
Expand All @@ -326,21 +328,24 @@ convert_rows_to_fp32(const void* const __restrict src_in, float* const __restric
for (size_t i = 0; i < nrows * dim; i++) {
dst[i] = (float)(src[i + start_row * dim]);
}

return true;
} else if (src_data_format == DataFormatEnum::bf16) {
const knowhere::bf16* const src = reinterpret_cast<const knowhere::bf16*>(src_in);
for (size_t i = 0; i < nrows * dim; i++) {
dst[i] = (float)(src[i + start_row * dim]);
}

return true;
} else if (src_data_format == DataFormatEnum::fp32) {
const knowhere::fp32* const src = reinterpret_cast<const knowhere::fp32*>(src_in);
for (size_t i = 0; i < nrows * dim; i++) {
dst[i] = src[i + start_row * dim];
}

return true;
} else if (src_data_format == DataFormatEnum::int8) {
const knowhere::int8* const src = reinterpret_cast<const knowhere::int8*>(src_in);
for (size_t i = 0; i < nrows * dim; i++) {
dst[i] = (float)(src[i + start_row * dim]);
}
return true;
} else {
// unknown
Expand All @@ -357,21 +362,27 @@ convert_rows_from_fp32(const float* const __restrict src, void* const __restrict
for (size_t i = 0; i < nrows * dim; i++) {
dst[i + start_row * dim] = (knowhere::fp16)src[i];
}

return true;
} else if (dst_data_format == DataFormatEnum::bf16) {
knowhere::bf16* const dst = reinterpret_cast<knowhere::bf16*>(dst_in);
for (size_t i = 0; i < nrows * dim; i++) {
dst[i + start_row * dim] = (knowhere::bf16)src[i];
}

return true;
} else if (dst_data_format == DataFormatEnum::fp32) {
knowhere::fp32* const dst = reinterpret_cast<knowhere::fp32*>(dst_in);
for (size_t i = 0; i < nrows * dim; i++) {
dst[i + start_row * dim] = src[i];
}

return true;
} else if (dst_data_format == DataFormatEnum::int8) {
knowhere::int8* const dst = reinterpret_cast<knowhere::int8*>(dst_in);
for (size_t i = 0; i < nrows * dim; i++) {
KNOWHERE_THROW_IF_NOT_MSG(src[i] >= std::numeric_limits<knowhere::int8>::min() &&
src[i] <= std::numeric_limits<knowhere::int8>::max(),
"convert float to int8_t overflow");
dst[i + start_row * dim] = (knowhere::int8)src[i];
}
return true;
} else {
// unknown
Expand All @@ -388,8 +399,9 @@ convert_ds_to_float(const DataSetPtr& src, DataFormatEnum data_format) {
return ConvertFromDataTypeIfNeeded<knowhere::fp16>(src);
} else if (data_format == DataFormatEnum::bf16) {
return ConvertFromDataTypeIfNeeded<knowhere::bf16>(src);
} else if (data_format == DataFormatEnum::int8) {
return ConvertFromDataTypeIfNeeded<knowhere::int8>(src);
}

return nullptr;
}

Expand Down Expand Up @@ -451,6 +463,8 @@ get_index_data_format(const faiss::Index* index) {
return DataFormatEnum::bf16;
} else if (index_sq->sq.qtype == faiss::ScalarQuantizer::QT_fp16) {
return DataFormatEnum::fp16;
} else if (index_sq->sq.qtype == faiss::ScalarQuantizer::QT_8bit_direct_signed) {
return DataFormatEnum::int8;
} else {
return std::nullopt;
}
Expand Down Expand Up @@ -806,49 +820,53 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
if (data_format == DataFormatEnum::fp32) {
// perform a direct reconstruction for fp32 data
auto data = std::make_unique<float[]>(dim * rows);

for (int64_t i = 0; i < rows; i++) {
const int64_t id = ids[i];
assert(id >= 0 && id < index->ntotal);
index_to_reconstruct_from->reconstruct(id, data.get() + i * dim);
}

return GenResultDataSet(rows, dim, std::move(data));
} else if (data_format == DataFormatEnum::fp16) {
auto data = std::make_unique<knowhere::fp16[]>(dim * rows);

// faiss produces fp32 data format, we need some other format.
// Let's create a temporary fp32 buffer for this.
auto tmp = std::make_unique<float[]>(dim);

for (int64_t i = 0; i < rows; i++) {
const int64_t id = ids[i];
assert(id >= 0 && id < index->ntotal);
index_to_reconstruct_from->reconstruct(id, tmp.get());

if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) {
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format");
}
}

return GenResultDataSet(rows, dim, std::move(data));
} else if (data_format == DataFormatEnum::bf16) {
auto data = std::make_unique<knowhere::bf16[]>(dim * rows);

// faiss produces fp32 data format, we need some other format.
// Let's create a temporary fp32 buffer for this.
auto tmp = std::make_unique<float[]>(dim);

for (int64_t i = 0; i < rows; i++) {
const int64_t id = ids[i];
assert(id >= 0 && id < index->ntotal);
index_to_reconstruct_from->reconstruct(id, tmp.get());

if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) {
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format");
}
}

return GenResultDataSet(rows, dim, std::move(data));
} else if (data_format == DataFormatEnum::int8) {
auto data = std::make_unique<knowhere::int8[]>(dim * rows);
// faiss produces fp32 data format, we need some other format.
// Let's create a temporary fp32 buffer for this.
auto tmp = std::make_unique<float[]>(dim);
for (int64_t i = 0; i < rows; i++) {
const int64_t id = ids[i];
assert(id >= 0 && id < index->ntotal);
index_to_reconstruct_from->reconstruct(id, tmp.get());
if (!convert_rows_from_fp32(tmp.get(), data.get(), data_format, i, 1, dim)) {
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format");
}
}
return GenResultDataSet(rows, dim, std::move(data));
} else {
return expected<DataSetPtr>::Err(Status::invalid_args, "Unsupported data format");
Expand Down Expand Up @@ -1234,13 +1252,18 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
// The query data is always cloned
std::unique_ptr<float[]> cur_query = std::make_unique<float[]>(dim);

if (data_format == DataFormatEnum::fp32) {
std::copy_n(reinterpret_cast<const float*>(data) + idx * dim, dim, cur_query.get());
} else if (data_format == DataFormatEnum::fp16 || data_format == DataFormatEnum::bf16) {
convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim);
} else {
// invalid one. Should not be triggered, bcz input parameters are validated
throw;
switch (data_format) {
case DataFormatEnum::fp32:
std::copy_n(reinterpret_cast<const float*>(data) + idx * dim, dim, cur_query.get());
break;
case DataFormatEnum::fp16:
case DataFormatEnum::bf16:
case DataFormatEnum::int8:
convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim);
break;
default:
// invalid one. Should not be triggered, bcz input parameters are validated
throw;
}

const bool should_use_refine = (dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr);
Expand Down Expand Up @@ -1327,6 +1350,9 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {
} else if (data_format == DataFormatEnum::bf16) {
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(dim, faiss::ScalarQuantizer::QT_bf16,
hnsw_cfg.M.value());
} else if (data_format == DataFormatEnum::int8) {
hnsw_index = std::make_unique<faiss::IndexHNSWSQCosine>(
dim, faiss::ScalarQuantizer::QT_8bit_direct_signed, hnsw_cfg.M.value());
} else {
LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value();
return Status::invalid_metric_type;
Expand All @@ -1340,6 +1366,9 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {
} else if (data_format == DataFormatEnum::bf16) {
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_bf16,
hnsw_cfg.M.value(), metric.value());
} else if (data_format == DataFormatEnum::int8) {
hnsw_index = std::make_unique<faiss::IndexHNSWSQ>(dim, faiss::ScalarQuantizer::QT_8bit_direct_signed,
hnsw_cfg.M.value(), metric.value());
} else {
LOG_KNOWHERE_ERROR_ << "Unsupported metric type: " << hnsw_cfg.metric_type.value();
return Status::invalid_metric_type;
Expand Down Expand Up @@ -1564,10 +1593,12 @@ namespace {
// a supporting function
expected<faiss::ScalarQuantizer::QuantizerType>
get_sq_quantizer_type(const std::string& sq_type) {
std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = {{"sq6", faiss::ScalarQuantizer::QT_6bit},
{"sq8", faiss::ScalarQuantizer::QT_8bit},
{"fp16", faiss::ScalarQuantizer::QT_fp16},
{"bf16", faiss::ScalarQuantizer::QT_bf16}};
std::map<std::string, faiss::ScalarQuantizer::QuantizerType> sq_types = {
{"sq6", faiss::ScalarQuantizer::QT_6bit},
{"sq8", faiss::ScalarQuantizer::QT_8bit},
{"fp16", faiss::ScalarQuantizer::QT_fp16},
{"bf16", faiss::ScalarQuantizer::QT_bf16},
{"int8", faiss::ScalarQuantizer::QT_8bit_direct_signed}};

// todo: tolower
auto sq_type_tolower = str_to_lower(sq_type);
Expand Down Expand Up @@ -1653,6 +1684,8 @@ has_lossless_quant(const expected<faiss::ScalarQuantizer::QuantizerType>& quant_
return quant == faiss::ScalarQuantizer::QuantizerType::QT_fp16;
case DataFormatEnum::bf16:
return quant == faiss::ScalarQuantizer::QuantizerType::QT_bf16;
case DataFormatEnum::int8:
return quant == faiss::ScalarQuantizer::QuantizerType::QT_8bit_direct_signed;
default:
return false;
}
Expand Down Expand Up @@ -2280,13 +2313,18 @@ KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_DEPRECATED,
#else
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback,
knowhere::feature::MMAP | knowhere::feature::MV)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplate,
knowhere::feature::MMAP | knowhere::feature::MV)
#endif

KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate,
knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate, knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate,
knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate, knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate,
knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_INT8_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate, knowhere::feature::MMAP)

} // namespace knowhere
6 changes: 6 additions & 0 deletions src/index/index_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ template knowhere::expected<knowhere::Index<knowhere::IndexNode>>
knowhere::IndexFactory::Create<knowhere::fp16>(const std::string&, const int32_t&, const Object&);
template knowhere::expected<knowhere::Index<knowhere::IndexNode>>
knowhere::IndexFactory::Create<knowhere::bf16>(const std::string&, const int32_t&, const Object&);
template knowhere::expected<knowhere::Index<knowhere::IndexNode>>
knowhere::IndexFactory::Create<knowhere::int8>(const std::string&, const int32_t&, const Object&);
template const knowhere::IndexFactory&
knowhere::IndexFactory::Register<knowhere::fp32>(
const std::string&, std::function<knowhere::Index<knowhere::IndexNode>(const int32_t&, const Object&)>,
Expand All @@ -153,3 +155,7 @@ template const knowhere::IndexFactory&
knowhere::IndexFactory::Register<knowhere::bf16>(
const std::string&, std::function<knowhere::Index<knowhere::IndexNode>(const int32_t&, const Object&)>,
const uint64_t);
template const knowhere::IndexFactory&
knowhere::IndexFactory::Register<knowhere::int8>(
const std::string&, std::function<knowhere::Index<knowhere::IndexNode>(const int32_t&, const Object&)>,
const uint64_t);
1 change: 1 addition & 0 deletions src/index/index_static.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,5 +152,6 @@ template class IndexStaticFaced<knowhere::fp32>;
template class IndexStaticFaced<knowhere::fp16>;
template class IndexStaticFaced<knowhere::bf16>;
template class IndexStaticFaced<knowhere::bin1>;
template class IndexStaticFaced<knowhere::int8>;

} // namespace knowhere
Loading

0 comments on commit ca4ba32

Please sign in to comment.