Skip to content

Commit

Permalink
fix bug on generate vector (#241)
Browse files Browse the repository at this point in the history
fix bug on generate vector 

Signed-off-by: zhongxiaoyao.zxy <[email protected]>
  • Loading branch information
ShawnShawnYou authored Dec 24, 2024
1 parent f691e4c commit 2adb25d
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/quantization/quantizer_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ TestQuantizerEncodeDecodeSame(Quantizer<T>& quant,
int code_max = 15,
float error = 1e-5,
bool retrain = true) {
auto data_uint8 = fixtures::GenerateVectors<uint8_t>(count, dim, 0, 16);
int seed = 47;
auto data_uint8 = fixtures::GenerateVectors<uint8_t>(count, dim, seed, 0, 16);
std::vector<float> data(dim * count);
for (uint64_t i = 0; i < dim * count; ++i) {
data[i] = static_cast<float>(data_uint8[i]);
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/fixtures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ get_common_used_dims(uint64_t count, int seed) {

std::vector<float>
generate_vectors(uint64_t count, uint32_t dim, bool need_normalize, int seed) {
return std::move(GenerateVectors<float>(count, dim, need_normalize, seed));
return std::move(GenerateVectors<float>(count, dim, seed, need_normalize));
}

std::vector<int8_t>
Expand Down
6 changes: 3 additions & 3 deletions tests/fixtures/fixtures.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ template <typename T, typename RT = typename std::enable_if<std::is_integral_v<T
std::vector<RT>
GenerateVectors(uint64_t count,
uint32_t dim,
int seed = 47,
T min = std::numeric_limits<T>::lowest(),
T max = std::numeric_limits<T>::max(),
int seed = 47) {
T max = std::numeric_limits<T>::max()) {
std::mt19937 rng(seed);
std::uniform_int_distribution<T> distrib_real(min, max);
std::vector<T> vectors(dim * count);
Expand All @@ -47,7 +47,7 @@ GenerateVectors(uint64_t count,

template <typename T, typename RT = typename std::enable_if<std::is_floating_point_v<T>, T>::type>
std::vector<RT>
GenerateVectors(uint64_t count, uint32_t dim, bool need_normalize = true, int seed = 47) {
GenerateVectors(uint64_t count, uint32_t dim, int seed = 47, bool need_normalize = true) {
std::mt19937 rng(seed);
std::uniform_real_distribution<T> distrib_real;
std::vector<T> vectors(dim * count);
Expand Down

0 comments on commit 2adb25d

Please sign in to comment.