From b21c380884be09bde0fb6b0498251a4572d7246d Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Sat, 26 Oct 2024 00:32:16 -0400 Subject: [PATCH] Fix issue #146 - Check Bounds for avx512 reduction mod_2 --- hexl/eltwise/eltwise-reduce-mod-avx512.hpp | 4 +- hexl/eltwise/eltwise-reduce-mod.cpp | 2 +- test/test-eltwise-reduce-mod-avx512.cpp | 47 ++++++++++++++++++++++ 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/hexl/eltwise/eltwise-reduce-mod-avx512.hpp b/hexl/eltwise/eltwise-reduce-mod-avx512.hpp index 5374c9c8..87ad2c6d 100644 --- a/hexl/eltwise/eltwise-reduce-mod-avx512.hpp +++ b/hexl/eltwise/eltwise-reduce-mod-avx512.hpp @@ -81,8 +81,8 @@ void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, __m512i v_op = _mm512_loadu_si512(v_operand); v_op = _mm512_hexl_barrett_reduce64( v_op, v_modulus, v_bf, v_bf_52, prod_right_shift, v_neg_mod); - HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, - "v_op exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, twice_mod, + "v_op exceeds bound " << twice_mod); _mm512_storeu_si512(v_result, v_op); ++v_operand; ++v_result; diff --git a/hexl/eltwise/eltwise-reduce-mod.cpp b/hexl/eltwise/eltwise-reduce-mod.cpp index accfe938..e7b9dd06 100644 --- a/hexl/eltwise/eltwise-reduce-mod.cpp +++ b/hexl/eltwise/eltwise-reduce-mod.cpp @@ -35,7 +35,7 @@ void EltwiseReduceModNative(uint64_t* result, const uint64_t* operand, if (input_mod_factor == modulus) { if (output_mod_factor == 2) { for (size_t i = 0; i < n; ++i) { - if (operand[i] >= modulus) { + if (operand[i] >= twice_modulus) { result[i] = BarrettReduce64<2>(operand[i], modulus, barrett_factor); } else { result[i] = operand[i]; diff --git a/test/test-eltwise-reduce-mod-avx512.cpp b/test/test-eltwise-reduce-mod-avx512.cpp index c2bb18c7..dbb2a162 100644 --- a/test/test-eltwise-reduce-mod-avx512.cpp +++ b/test/test-eltwise-reduce-mod-avx512.cpp @@ -35,6 +35,30 @@ TEST(EltwiseReduceMod, avx512_64_mod_1) { CheckEqual(result, exp_out); } + +TEST(EltwiseReduceMod, avx512_64_mod_2) { + if (!has_avx512dq) { + GTEST_SKIP(); + } + + std::vector op{18399319504785536384ULL, 17772833711639413686ULL, + 12597119745262224203ULL, 1504294004559805751ULL, + 11357185129558358846ULL, 15524763729212309524ULL, + 15578066193709346988ULL, 9262080163435001663ULL}; + + std::vector exp_out{1282985348605, 701667589612, 1154521301334, + 519986957540, 1153052298859, 914113932554, + 1255706689604, 1229762981307}; + std::vector result{0, 0, 0, 0, 0, 0, 0, 0}; + + uint64_t modulus = 1099511590913; + const uint64_t input_mod_factor = modulus; + const uint64_t output_mod_factor = 2; + EltwiseReduceModAVX512<64>(result.data(), op.data(), op.size(), modulus, + input_mod_factor, output_mod_factor); + CheckEqual(result, exp_out); +} + TEST(EltwiseReduceModMontInOut, avx512_64_mod_1) { if (!has_avx512dq) { GTEST_SKIP(); @@ -83,6 +107,29 @@ TEST(EltwiseReduceMod, avx512_52_mod_1) { CheckEqual(result, exp_out); } +TEST(EltwiseReduceMod, avx512_52_mod_2) { + if (!has_avx512ifma) { + GTEST_SKIP(); + } + + std::vector op{18399319504785536384ULL, 17772833711639413686ULL, + 12597119745262224203ULL, 1504294004559805751ULL, + 11357185129558358846ULL, 15524763729212309524ULL, + 15578066193709346988ULL, 9262080163435001663ULL}; + + std::vector exp_out{183473757692, 701667589612, 1154521301334, + 519986957540, 1153052298859, 914113932554, + 1255706689604, 130251390394}; + std::vector result{0, 0, 0, 0, 0, 0, 0, 0}; + + uint64_t modulus = 1099511590913; + const uint64_t input_mod_factor = modulus; + const uint64_t output_mod_factor = 2; + EltwiseReduceModAVX512<52>(result.data(), op.data(), op.size(), modulus, + input_mod_factor, output_mod_factor); + CheckEqual(result, exp_out); +} + TEST(EltwiseReduceMod, avx512_52_Big_mod_1) { if (!has_avx512ifma) { GTEST_SKIP();