Skip to content

Commit

Permalink
introduce new train method for sq
Browse files Browse the repository at this point in the history
Signed-off-by: LHT129 <[email protected]>
  • Loading branch information
LHT129 committed Dec 18, 2024
1 parent ccc0108 commit b94ca10
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 13 deletions.
46 changes: 41 additions & 5 deletions src/quantization/scalar_quantization_trainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

#include "scalar_quantization_trainer.h"

#include <queue>
#include <random>
#include <unordered_set>

#include "simd/normalize.h"

Expand All @@ -32,11 +32,13 @@ ScalarQuantizationTrainer::Train(const float* data,
float* upper_bound,
float* lower_bound,
bool need_normalize,
ScalarQuantizationTrainer::SQTrainMode mode) {
SQTrainMode mode) {
std::vector<float> sample_datas;
auto sample_count = this->sample_train_data(data, count, sample_datas, need_normalize);
if (mode == CLASSIC) {
this->classic_train(sample_datas.data(), sample_count, upper_bound, lower_bound);
} else if (mode == TRUNC_BOUND) {
this->trunc_bound_train(sample_datas.data(), sample_count, upper_bound, lower_bound);
}
}

Expand All @@ -46,7 +48,7 @@ ScalarQuantizationTrainer::TrainUniform(const float* data,
float& upper_bound,
float& lower_bound,
bool need_normalize,
ScalarQuantizationTrainer::SQTrainMode mode) {
SQTrainMode mode) {
std::vector<float> sample_datas;
auto sample_count = this->sample_train_data(data, count, sample_datas, need_normalize);
std::vector<float> upper(dim_);
Expand All @@ -62,7 +64,7 @@ void
ScalarQuantizationTrainer::classic_train(const float* data,
uint64_t count,
float* upper_bound,
float* lower_bound) {
float* lower_bound) const {
for (uint64_t i = 0; i < dim_; ++i) {
upper_bound[i] = std::numeric_limits<float>::lowest();
lower_bound[i] = std::numeric_limits<float>::max();
Expand All @@ -74,11 +76,45 @@ ScalarQuantizationTrainer::classic_train(const float* data,
}
}

void
ScalarQuantizationTrainer::trunc_bound_train(const float* data,
uint64_t count,
float* upper_bound,
float* lower_bound) const {
auto ignore_count = static_cast<uint64_t>(static_cast<float>(count - 1) * 0.001);

for (uint64_t i = 0; i < dim_; ++i) {
upper_bound[i] = std::numeric_limits<float>::lowest();
lower_bound[i] = std::numeric_limits<float>::max();
std::priority_queue<float, std::vector<float>, std::greater<>> heap_max;
std::priority_queue<float, std::vector<float>, std::less<>> heap_min;
heap_max.emplace(upper_bound[i]);
heap_min.emplace(lower_bound[i]);
for (uint64_t j = 0; j < count; ++j) {
auto value = data[j * dim_ + i];
if (value > heap_max.top() || heap_max.size() < ignore_count) {
heap_max.emplace(value);
}
if (heap_max.size() > ignore_count) {
heap_max.pop();
}
if (value < heap_min.top() || heap_min.size() < ignore_count) {
heap_min.emplace(value);
}
if (heap_min.size() > ignore_count) {
heap_min.pop();
}
}
upper_bound[i] = heap_max.top();
lower_bound[i] = heap_min.top();
}
}

uint64_t
ScalarQuantizationTrainer::sample_train_data(const float* data,
uint64_t count,
std::vector<float>& sample_datas,
bool need_normalize) {
bool need_normalize) const {
uint64_t step = 2147483647UL % count;
auto sample_count = max_sample_count_;
if (count <= max_sample_count_) {
Expand Down
22 changes: 14 additions & 8 deletions src/quantization/scalar_quantization_trainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@

namespace vsag {

enum SQTrainMode {
CLASSIC = 1,
K_MEANS = 2,
TRUNC_BOUND = 3,
};

class ScalarQuantizationTrainer {
public:
explicit ScalarQuantizationTrainer(int32_t dim, int bits = 8);

enum SQTrainMode {
CLASSIC = 1,
K_MEANS = 2,
TRUNC_BOUND = 3,
};

void
Train(const float* data,
uint64_t count,
Expand All @@ -57,13 +57,19 @@ class ScalarQuantizationTrainer {

private:
void
classic_train(const float* data, uint64_t count, float* upper_bound, float* lower_bound);
classic_train(const float* data, uint64_t count, float* upper_bound, float* lower_bound) const;

void
trunc_bound_train(const float* data,
uint64_t count,
float* upper_bound,
float* lower_bound) const;

uint64_t
sample_train_data(const float* data,
uint64_t count,
std::vector<float>& sample_datas,
bool need_normalize = false);
bool need_normalize = false) const;

private:
int dim_{0};
Expand Down

0 comments on commit b94ca10

Please sign in to comment.