Skip to content

Commit

Permalink
Fix issue #146 - Check Bounds for avx512 reduction mod_2
Browse files Browse the repository at this point in the history
  • Loading branch information
joserochh committed Oct 26, 2024
1 parent 8463983 commit b21c380
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 3 deletions.
4 changes: 2 additions & 2 deletions hexl/eltwise/eltwise-reduce-mod-avx512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand,
__m512i v_op = _mm512_loadu_si512(v_operand);
v_op = _mm512_hexl_barrett_reduce64<BitShift, 2>(
v_op, v_modulus, v_bf, v_bf_52, prod_right_shift, v_neg_mod);
HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus,
"v_op exceeds bound " << modulus);
HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, twice_mod,
"v_op exceeds bound " << twice_mod);
_mm512_storeu_si512(v_result, v_op);
++v_operand;
++v_result;
Expand Down
2 changes: 1 addition & 1 deletion hexl/eltwise/eltwise-reduce-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void EltwiseReduceModNative(uint64_t* result, const uint64_t* operand,
if (input_mod_factor == modulus) {
if (output_mod_factor == 2) {
for (size_t i = 0; i < n; ++i) {
if (operand[i] >= modulus) {
if (operand[i] >= twice_modulus) {
result[i] = BarrettReduce64<2>(operand[i], modulus, barrett_factor);
} else {
result[i] = operand[i];
Expand Down
47 changes: 47 additions & 0 deletions test/test-eltwise-reduce-mod-avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,30 @@ TEST(EltwiseReduceMod, avx512_64_mod_1) {
CheckEqual(result, exp_out);
}


TEST(EltwiseReduceMod, avx512_64_mod_2) {
if (!has_avx512dq) {
GTEST_SKIP();
}

std::vector<uint64_t> op{18399319504785536384ULL, 17772833711639413686ULL,
12597119745262224203ULL, 1504294004559805751ULL,
11357185129558358846ULL, 15524763729212309524ULL,
15578066193709346988ULL, 9262080163435001663ULL};

std::vector<uint64_t> exp_out{1282985348605, 701667589612, 1154521301334,
519986957540, 1153052298859, 914113932554,
1255706689604, 1229762981307};
std::vector<uint64_t> result{0, 0, 0, 0, 0, 0, 0, 0};

uint64_t modulus = 1099511590913;
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 2;
EltwiseReduceModAVX512<64>(result.data(), op.data(), op.size(), modulus,
input_mod_factor, output_mod_factor);
CheckEqual(result, exp_out);
}

TEST(EltwiseReduceModMontInOut, avx512_64_mod_1) {
if (!has_avx512dq) {
GTEST_SKIP();
Expand Down Expand Up @@ -83,6 +107,29 @@ TEST(EltwiseReduceMod, avx512_52_mod_1) {
CheckEqual(result, exp_out);
}

TEST(EltwiseReduceMod, avx512_52_mod_2) {
if (!has_avx512ifma) {
GTEST_SKIP();
}

std::vector<uint64_t> op{18399319504785536384ULL, 17772833711639413686ULL,
12597119745262224203ULL, 1504294004559805751ULL,
11357185129558358846ULL, 15524763729212309524ULL,
15578066193709346988ULL, 9262080163435001663ULL};

std::vector<uint64_t> exp_out{183473757692, 701667589612, 1154521301334,
519986957540, 1153052298859, 914113932554,
1255706689604, 130251390394};
std::vector<uint64_t> result{0, 0, 0, 0, 0, 0, 0, 0};

uint64_t modulus = 1099511590913;
const uint64_t input_mod_factor = modulus;
const uint64_t output_mod_factor = 2;
EltwiseReduceModAVX512<52>(result.data(), op.data(), op.size(), modulus,
input_mod_factor, output_mod_factor);
CheckEqual(result, exp_out);
}

TEST(EltwiseReduceMod, avx512_52_Big_mod_1) {
if (!has_avx512ifma) {
GTEST_SKIP();
Expand Down

0 comments on commit b21c380

Please sign in to comment.