diff --git a/CMakeLists.txt b/CMakeLists.txt
index b0da96f7..85088575 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -53,12 +53,14 @@ set(CMAKE_INSTALL_RPATH "\$ORIGIN")
#------------------------------------------------------------------------------
# Compiler options...
#------------------------------------------------------------------------------
+#twy
+option(MY_OPENMP "Enable multithreading" ON)
option(HEXL_BENCHMARK "Enable benchmarking" ON)
option(HEXL_COVERAGE "Enables coverage for unit tests" OFF)
option(HEXL_DOCS "Enable documentation building" OFF)
option(HEXL_EXPERIMENTAL "Enable experimental features" OFF)
-option(HEXL_SHARED_LIB "Generate a shared library" OFF)
+option(HEXL_SHARED_LIB "Generate a shared library" ON)
option(HEXL_TESTING "Enables unit-tests" ON)
option(HEXL_TREAT_WARNING_AS_ERROR "Treat all compile-time warnings as errors" OFF)
@@ -66,6 +68,9 @@ if (NOT HEXL_FPGA_COMPATIBILITY)
  set(HEXL_FPGA_COMPATIBILITY "0" CACHE INTERNAL "Set FPGA compatibility mask" FORCE)
endif()
+#twy
+message(STATUS "MY_OPENMP: ${MY_OPENMP}")
+
message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
message(STATUS "CMAKE_C_COMPILER: ${CMAKE_C_COMPILER}")
message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")
@@ -82,6 +87,11 @@ message(STATUS "HEXL_FPGA_COMPATIBILITY: ${HEXL_FPGA_COMPATIBILITY}")
hexl_check_compiler_version()
hexl_add_compiler_definition()
+#twy
+if (MY_OPENMP)
+  add_compile_options(-fopenmp)
+endif()
+
if (HEXL_COVERAGE)
  if (NOT HEXL_USE_GNU)
    message(FATAL_ERROR "HEXL_COVERAGE only supported on GCC.")
@@ -183,6 +193,12 @@ if (HEXL_DEBUG)
  endif()
endif()
+find_package(OpenMP)
+if (OpenMP_FOUND)
+  message(STATUS "OpenMP_CXX_INCLUDE_DIRS: ${OpenMP_CXX_INCLUDE_DIRS}")
+  message(STATUS "OpenMP_CXX_LIBRARIES: ${OpenMP_CXX_LIBRARIES}")
+endif()
+
#------------------------------------------------------------------------------
# Subfolders...
#------------------------------------------------------------------------------
diff --git a/README.md b/README.md
index 5113f6f0..b9626f5f 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
[![Build and Test](https://github.com/intel/hexl/actions/workflows/github-ci.yml/badge.svg?branch=main)](https://github.com/intel/hexl/actions/workflows/github-ci.yml)
-# Intel Homomorphic Encryption (HE) Acceleration Library
+# OpenMP-enabled Intel Homomorphic Encryption (HE) Acceleration Library
Intel:registered: HE Acceleration Library is an open-source library which
provides efficient implementations of integer arithmetic on Galois fields.
Such arithmetic is prevalent in cryptography, particularly in homomorphic encryption
@@ -11,6 +11,8 @@ Intel HE Acceleration Library, see our
[whitepaper](https://arxiv.org/abs/2103.16400.pdf). For tips on best
performance, see [Performance](#performance).
+The project extends HEXL's capabilities by incorporating OpenMP for multi-threaded parallelization, in order to reduce computational latency without compromising security.
+
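For a quick illustration of the new build option (the `MY_OPENMP` option defined in the CMakeLists.txt change above, ON by default; the value shown here is only an example), the global `-fopenmp` compile flag can be toggled at configure time:

```bash
# MY_OPENMP is ON by default; setting it OFF skips adding -fopenmp as a global compile option.
cmake -S . -B build -DMY_OPENMP=OFF
```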
## Contents
- [Intel Homomorphic Encryption (HE) Acceleration Library](#intel-homomorphic-encryption-he-acceleration-library)
- [Contents](#contents)
@@ -20,7 +22,7 @@ performance, see [Performance](#performance).
- [Compile-time options](#compile-time-options)
- [Compiling Intel HE Acceleration Library](#compiling-intel-he-acceleration-library)
  - [Linux and Mac](#linux-and-mac)
-  - [Windows](#windows)
+  - [Build the OpenMP-enabled Version](#build-the-openmp-enabled-version)
- [Performance](#performance)
- [Testing Intel HE Acceleration Library](#testing-intel-he-acceleration-library)
- [Benchmarking Intel HE Acceleration Library](#benchmarking-intel-he-acceleration-library)
@@ -98,20 +100,15 @@ from source using CMake.
### Dependencies
We have tested Intel HE Acceleration Library on the following operating systems:
- Ubuntu 20.04
-- macOS 10.15 Catalina
-- Microsoft Windows 10
+- macOS 13.4.1 (22F82) Ventura

Intel HE Acceleration Library requires the following dependencies:

| Dependency  | Version                                      |
|-------------|----------------------------------------------|
| CMake       | >= 3.13 \*                                   |
-| Compiler    | gcc >= 7.0, clang++ >= 5.0, MSVC >= 2019     |
+| Compiler    | gcc >= 7.0, clang++ >= 5.0                   |

-\* For Windows 10, you must check whether the version on CMake you have can
-generate the necessary Visual Studio project files. For example, only from
-[CMake 3.14 onwards can MSVC 2019 project files be
-generated](https://cmake.org/cmake/help/git-stage/generator/Visual%20Studio%2016%202019.html).

### Compile-time options

@@ -162,33 +159,151 @@
To install Intel HE Acceleration Library to the installation directory, run
```bash
cmake --install build
```
+## Build the OpenMP-enabled Version
+1. First start from a clean project in the ‘hexl’ folder
+```bash
+ rm -rf CMakeCache.txt CMakeFiles/
+```
+2. Go to the build folder
+```bash
+ cd build
+```
+3. Repeat the first command to clean the CMake cache in the build folder
+```bash
+ rm -rf CMakeCache.txt CMakeFiles/
+```
-#### Windows
-To compile Intel HE Acceleration Library on Windows using Visual Studio in
-Release mode, configure the build via
+4. Return to the ‘hexl’ root folder and configure the build
```bash
-cmake -S . -B build -G "Visual Studio 16 2019" -DCMAKE_BUILD_TYPE=Release
+ cmake -S . -B build -DCMAKE_INSTALL_PREFIX=/path/to/install/hexl_v2 -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++
+```
+ or
+```bash
+ cmake -S . -B build -DCMAKE_INSTALL_PREFIX=/path/to/install/hexl_v1 -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++
+```
+ or
+```bash
+ cmake -S . -B build -DCMAKE_INSTALL_PREFIX=/path/to/install/hexl_v3 -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++
```
-adding the desired compile-time options with a `-D` flag (see [Compile-time
-options](#compile-time-options)). For instance, to use a non-standard
-installation directory, configure the build with
+Note that hexl_v0 is the original serial version; hexl_v1 is the serial version with the loop-unrolling techniques removed; hexl_v2 is the OpenMP-enabled multi-threading version; and hexl_v3 is the OpenMP-enabled multi-threading version with OpenMP region timing enabled.
+5. Build and install
```bash
-cmake -S . -B build -G "Visual Studio 16 2019" -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/path/to/install
+ cmake --build build
+ cmake --install build
```
-To specify the desired build configuration, pass either `--config Debug` or
-`--config Release` to the build step and install steps. For instance, to build
-Intel HE Acceleration Library in Release mode, call
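At run time, the number of OpenMP threads used by the installed library (for example, when running the benchmark suite) can be controlled with the standard `OMP_NUM_THREADS` environment variable; the binary path below assumes the default build layout and may differ on your system:

```bash
# Illustrative: run the HEXL benchmarks with 8 OpenMP threads
OMP_NUM_THREADS=8 ./build/benchmark/bench_hexl
```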
+## Test the parallel performance of the OpenMP-enabled Version
+1. Go into the OpenMP enabled example test directory
```bash
-cmake --build build --config Release
+ cd omp_example
+```
+2. Define the directory containing HEXLConfig.cmake (save but don’t close)
+```bash
+ vim cmake/CMakeLists.txt
+```
+```bash
+ set(HEXL_HINT_DIR "path/to/install/hexl_v2/lib/cmake/hexl-1.2.5")
```
-This will build the Intel HE Acceleration Library library in the
-`build/hexl/lib/` or `build/hexl/Release/lib` directory.
-To install Intel HE Acceleration Library to the installation directory, run
+3. Help CMake find OpenMP by adding the following to the CMakeLists.txt
+```bash
+ find_package(OpenMP)
+ if (OpenMP_FOUND)
+   message(STATUS "OpenMP_CXX_INCLUDE_DIRS: ${OpenMP_CXX_INCLUDE_DIRS}")
+   message(STATUS "OpenMP_CXX_LIBRARIES: ${OpenMP_CXX_LIBRARIES}")
+ endif()
+```
+
+ N.B. Don’t forget to edit the executable file name!
+
+4. Build the example
+```bash
+ cmake -S cmake -B build
+ cmake --build build -j
+```
+
+5. Set the desired thread number and execute
+
+```bash
+ ./build/omp_example 1000 4 1234
+ time ./build/omp_example 1000 1,2,4,8,16 16384
+```
+ or directly run the script
+ ```bash
+ cd ..
+ ./hexl_omp.sh
+ ```
+
+## Test the serial performance of the OpenMP-enabled Version
+1. Go into the serial example test directory
+```bash
+ cd ser_example
+```
+2. Define the directory containing HEXLConfig.cmake (save but don’t close)
+```bash
+ vim cmake/CMakeLists.txt
+```
+```bash
+ set(HEXL_HINT_DIR "path/to/install/hexl_v1/lib/cmake/hexl-1.2.5")
+```
+ N.B. Don’t forget to edit the executable file name!
+
+3. Build the example
+```bash
+ cmake -S cmake -B build
+ cmake --build build -j
+```
+
+4. Set the desired thread number and execute
+
+```bash
+ ./build/example 1000 4096,65536
+ time ./build/example 1000 4096,65536
+```
+ or directly run the script
+ ```bash
+ cd ..
+ ./hexl_serial.sh
+ ```
+
+## Test the OpenMP region timing of the OpenMP-enabled Version
+1. Go into the OpenMP region timing example test directory
+```bash
+ cd time_example
+```
+2. Define the directory containing HEXLConfig.cmake (save but don’t close)
+```bash
+ vim cmake/CMakeLists.txt
+```
+```bash
+ set(HEXL_HINT_DIR "path/to/install/hexl_v3/lib/cmake/hexl-1.2.5")
+```
+
+3. Help CMake find OpenMP by adding the following to the CMakeLists.txt
+```bash
+ find_package(OpenMP)
+ if (OpenMP_FOUND)
+   message(STATUS "OpenMP_CXX_INCLUDE_DIRS: ${OpenMP_CXX_INCLUDE_DIRS}")
+   message(STATUS "OpenMP_CXX_LIBRARIES: ${OpenMP_CXX_LIBRARIES}")
+ endif()
+```
+
+4. Build the example
+```bash
+ cmake -S cmake -B build
+ cmake --build build -j
+```
+
+5. Set the desired thread number and execute
 ```bash
-cmake --build build --target install --config Release
+ ./build/time_example 100 4 1024 0
+ time ./build/time_example 100 4 1024 0
```
+ or directly run the script
+ ```bash
+ cd ..
+ ./hexl_time.sh
+ ```

## Performance
For best performance, we recommend using Intel HE Acceleration Library on a
@@ -352,7 +467,7 @@
To cite Intel HE Acceleration Library, please use the following BibTeX entry.
``` # Contributors -The Intel contributors to this project, sorted by last name, are +The Intel contributors to the original HEXL project, sorted by last name, are - [Paky Abu-Alam](https://www.linkedin.com/in/paky-abu-alam-89797710/) - [Flavio Bergamaschi](https://www.linkedin.com/in/flavio-bergamaschi-1634141/) - [Fabian Boemer](https://www.linkedin.com/in/fabian-boemer-5a40a9102/) @@ -370,5 +485,5 @@ The Intel contributors to this project, sorted by last name, are - [Gelila Seifu](https://www.linkedin.com/in/gelila-seifu/) In addition to the Intel contributors listed, we are also grateful to -contributions to this project that are not reflected in the Git history: +contributions to the original HEXL project that are not reflected in the Git history: - [Antonis Papadimitriou](https://www.linkedin.com/in/apapadimitriou/) diff --git a/benchmark/main.cpp b/benchmark/main.cpp index c8592bb8..99aa2371 100644 --- a/benchmark/main.cpp +++ b/benchmark/main.cpp @@ -2,10 +2,14 @@ // SPDX-License-Identifier: Apache-2.0 #include - +#include +#include #include "hexl/logging/logging.hpp" int main(int argc, char** argv) { + int max_threads = omp_get_max_threads(); + std::cout << "Maximum number of threads = " << max_threads << std::endl; + START_EASYLOGGINGPP(argc, argv); benchmark::Initialize(&argc, argv); diff --git a/example/Serial_result b/example/Serial_result new file mode 100644 index 00000000..7aa525ca --- /dev/null +++ b/example/Serial_result @@ -0,0 +1,34 @@ +Method Input_size=65536Input_size=1048576Input_size=16777216Input_size=268435456 +BM_EltwiseCmpAdd 0.017372 0.360454 10.3044 101.545 +BM_EltwiseCmpSubMod 0.293158 5.49905 37.7997 569.832 +BM_EltwiseFMAModAdd 0.403343 3.77922 48.3154 736.565 +BM_EltwiseMultMod 0.225371 2.10433 49.6537 917.409 +BM_EltwiseReduceModInPlace 0.324567 2.90482 45.3023 708.633 +BM_EltwiseVectorScalarAddMod 0.035236 0.67465 16.7397 157.531 +BM_EltwiseVectorVectorAddMod 0.054301 0.941839 20.3437 221.558 +BM_EltwiseVectorVectorSubMod 0.05375 0.946959 21.3859 216.095 +BM_NTTInPlace 0 0 0 0 + +s1820742psd@eidf018-s1820742msc:~/hexl/example$ ./build/example 1 4096,65536,1048576,16777216,268435456 +Method Threads=4096 Threads=65536 Threads=1048576Threads=16777216Threads=268435456 +BM_EltwiseCmpAdd 0.002144 0.01595 0.406509 10.3117 113.003 +BM_EltwiseCmpSubMod 0.039855 0.318033 3.72818 36.4171 644.48 +BM_EltwiseFMAModAdd 0.01605 0.248584 3.99384 57.2828 724.992 +BM_EltwiseMultMod 0.015299 0.168785 3.30193 48.9081 728.924 +BM_EltwiseReduceModInPlace 0.016561 0.234798 3.21797 38.945 659.401 +BM_EltwiseVectorScalarAddMod 0.001934 0.026049 0.538125 15.2527 175.121 +BM_EltwiseVectorVectorAddMod 0.003437 0.039815 0.918195 19.7483 212.407 +BM_EltwiseVectorVectorSubMod 0.001643 0.039985 0.743188 20.7958 196.669 +BM_NTTInPlace 0 0 0 0 0 + +s1820742psd@eidf018-s1820742msc:~/hexl/example$ ./build/example 10 4096,65536,1048576,16777216,268435456 +Method Threads=4096 Threads=65536 Threads=1048576 Threads=16777216 Threads=268435456 +BM_EltwiseCmpAdd 0.0013718 0.0193352 0.208849 10.9892 108.882 +BM_EltwiseCmpSubMod 0.0277949 0.370251 3.56102 56.7082 885.602 +BM_EltwiseFMAModAdd 0.0240648 0.351835 2.94638 54.911 1103.29 +BM_EltwiseMultMod 0.0194895 0.241749 2.1714 61.7039 793.763 +BM_EltwiseReduceModInPlace 0.0193062 0.287436 2.48417 44.0457 660.978 +BM_EltwiseVectorScalarAddMod 0.001593 0.0340157 0.425441 16.1058 172.764 +BM_EltwiseVectorVectorAddMod 0.0021721 0.050867 0.720809 20.2879 224.474 +BM_EltwiseVectorVectorSubMod 0.002193 0.0495797 0.733838 20.3605 215.544 +BM_NTTInPlace 
0.0003798 0.0007619 0.0090708 0.171779 3.80078 \ No newline at end of file diff --git a/example/cmake/CMakeLists.txt b/example/cmake/CMakeLists.txt index ec7bf652..43963bcd 100644 --- a/example/cmake/CMakeLists.txt +++ b/example/cmake/CMakeLists.txt @@ -5,6 +5,9 @@ project(hexl_example LANGUAGES C CXX) cmake_minimum_required(VERSION 3.13) set(CMAKE_CXX_STANDARD 17) +# Define the directory containing HEXLConfig.cmake +set(HEXL_HINT_DIR "/home/eidf018/eidf018/s1820742psd/hexl/hexl_v1") + # Example using source find_package(HEXL 1.2.5 HINTS ${HEXL_HINT_DIR} diff --git a/example/example.cpp b/example/example.cpp index 9e6bf659..50876fca 100644 --- a/example/example.cpp +++ b/example/example.cpp @@ -1,157 +1,319 @@ -// Copyright (C) 2020 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 +// #include +#include #include +#include +#include #include +#include +#include +#include +#include #include -#include "hexl/hexl.hpp" +#include "../hexl/include/hexl/hexl.hpp" +// #include "../hexl/include/hexl/util/util.hpp" +#include "../hexl/util/util-internal.hpp" +// #include "../hexl/include/hexl/experimental/fft-like/fft-like.hpp" +// #include "../hexl/include/hexl/experimental/fft-like/fft-like-native.hpp" -bool CheckEqual(const std::vector& x, - const std::vector& y) { - if (x.size() != y.size()) { - std::cout << "Not equal in size\n"; - return false; - } - uint64_t N = x.size(); - bool is_match = true; - for (size_t i = 0; i < N; ++i) { - if (x[i] != y[i]) { - std::cout << "Not equal at index " << i << "\n"; - is_match = false; - } +template +double TimeFunction(Func&& f) { + auto start_time = std::chrono::high_resolution_clock::now(); + f(); + auto end_time = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end_time - start_time; + return duration.count(); +} + +std::vector split(const std::string& s, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(s); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(std::stoi(token)); } - return is_match; + return tokens; } -void ExampleEltwiseVectorVectorAddMod() { - std::cout << "Running ExampleEltwiseVectorVectorAddMod...\n"; +double BM_EltwiseVectorVectorAddMod(size_t input_size) { + uint64_t modulus = 0xffffffffffc0001ULL; - std::vector op1{1, 2, 3, 4, 5, 6, 7, 8}; - std::vector op2{1, 3, 5, 7, 2, 4, 6, 8}; - uint64_t modulus = 10; - std::vector exp_out{2, 5, 8, 1, 7, 0, 3, 6}; + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + auto input2 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + intel::hexl::AlignedVector64 output(input_size, 0); - intel::hexl::EltwiseAddMod(op1.data(), op1.data(), op2.data(), op1.size(), - modulus); + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseAddMod(output.data(), input1.data(), input2.data(), + input_size, modulus); + }); - CheckEqual(op1, exp_out); - std::cout << "Done running ExampleEltwiseVectorVectorAddMod\n"; + return time_taken; } -void ExampleEltwiseVectorScalarAddMod() { - std::cout << "Running ExampleEltwiseVectorScalarAddMod...\n"; +double BM_EltwiseVectorScalarAddMod(size_t input_size) { + uint64_t modulus = 0xffffffffffc0001ULL; - std::vector op1{1, 2, 3, 4, 5, 6, 7, 8}; - uint64_t op2{3}; - uint64_t modulus = 10; - std::vector exp_out{4, 5, 6, 7, 8, 9, 0, 1}; + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + uint64_t input2 = + 
intel::hexl::GenerateInsecureUniformIntRandomValue(0, modulus); + intel::hexl::AlignedVector64 output(input_size, 0); - intel::hexl::EltwiseAddMod(op1.data(), op1.data(), op2, op1.size(), modulus); + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseAddMod(output.data(), input1.data(), input2, input_size, + modulus); + }); - CheckEqual(op1, exp_out); - std::cout << "Done running ExampleEltwiseVectorScalarAddMod\n"; + return time_taken; } -void ExampleEltwiseCmpAdd() { - std::cout << "Running ExampleEltwiseCmpAdd...\n"; +double BM_EltwiseCmpAdd(size_t input_size, intel::hexl::CMPINT chosenCMP) { + uint64_t modulus = 100; - std::vector op1{1, 2, 3, 4, 5, 6, 7, 8}; - uint64_t cmp = 3; - uint64_t diff = 5; - std::vector exp_out{1, 2, 3, 9, 10, 11, 12, 13}; + uint64_t bound = + intel::hexl::GenerateInsecureUniformIntRandomValue(0, modulus); + uint64_t diff = + intel::hexl::GenerateInsecureUniformIntRandomValue(1, modulus - 1); + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); - intel::hexl::EltwiseCmpAdd(op1.data(), op1.data(), op1.size(), - intel::hexl::CMPINT::NLE, cmp, diff); + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseCmpAdd(input1.data(), input1.data(), input_size, + chosenCMP, bound, diff); + }); - CheckEqual(op1, exp_out); - std::cout << "Done running ExampleEltwiseCmpAdd\n"; + return time_taken; } -void ExampleEltwiseCmpSubMod() { - std::cout << "Running ExampleEltwiseCmpSubMod...\n"; +double BM_EltwiseCmpSubMod(size_t input_size, intel::hexl::CMPINT chosenCMP) { + uint64_t modulus = 100; - std::vector op1{1, 2, 3, 4, 5, 6, 7}; - uint64_t bound = 4; - uint64_t diff = 5; - std::vector exp_out{1, 2, 3, 4, 0, 1, 2}; + uint64_t bound = + intel::hexl::GenerateInsecureUniformIntRandomValue(1, modulus); + uint64_t diff = + intel::hexl::GenerateInsecureUniformIntRandomValue(1, modulus); + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); - uint64_t modulus = 10; + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseCmpSubMod(input1.data(), input1.data(), input_size, + modulus, chosenCMP, bound, diff); + }); - intel::hexl::EltwiseCmpSubMod(op1.data(), op1.data(), op1.size(), modulus, - intel::hexl::CMPINT::NLE, bound, diff); - CheckEqual(op1, exp_out); - std::cout << "Done running ExampleEltwiseCmpSubMod\n"; + return time_taken; } -void ExampleEltwiseFMAMod() { - std::cout << "Running ExampleEltwiseFMAMod...\n"; - - std::vector arg1{1, 2, 3, 4, 5, 6, 7, 8, 9}; - uint64_t arg2 = 1; - std::vector exp_out{1, 2, 3, 4, 5, 6, 7, 8, 9}; - uint64_t modulus = 769; +double BM_EltwiseFMAModAdd(size_t input_size, bool add) { + uint64_t modulus = 0xffffffffffc0001ULL; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + uint64_t input2 = + intel::hexl::GenerateInsecureUniformIntRandomValue(0, modulus); + intel::hexl::AlignedVector64 input3 = + intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, 0, + modulus); + uint64_t* arg3 = add ? 
input3.data() : nullptr; + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseFMAMod(input1.data(), input1.data(), input2, arg3, + input1.size(), modulus, 1); + }); + return time_taken; +} - intel::hexl::EltwiseFMAMod(arg1.data(), arg1.data(), arg2, nullptr, - arg1.size(), modulus, 1); - CheckEqual(arg1, exp_out); - std::cout << "Done running ExampleEltwiseFMAMod\n"; +double BM_EltwiseMultMod(size_t input_size, size_t bit_width, + size_t input_mod_factor) { + uint64_t modulus = (1ULL << bit_width) + 7; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + auto input2 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + intel::hexl::AlignedVector64 output(input_size, 2); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseMultMod(output.data(), input1.data(), input2.data(), + input_size, modulus, input_mod_factor); + }); + return time_taken; } -void ExampleEltwiseMultMod() { - std::cout << "Running ExampleEltwiseMultMod...\n"; +double BM_EltwiseReduceModInPlace(size_t input_size) { + uint64_t modulus = 0xffffffffffc0001ULL; - std::vector op1{2, 4, 3, 2}; - std::vector op2{2, 1, 2, 0}; - std::vector exp_out{4, 4, 6, 0}; + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues( + input_size, 0, 100 * modulus); - uint64_t modulus = 769; + const uint64_t input_mod_factor = modulus; + const uint64_t output_mod_factor = 1; - intel::hexl::EltwiseMultMod(op1.data(), op1.data(), op2.data(), op1.size(), - modulus, 1); - CheckEqual(op1, exp_out); - std::cout << "Done running ExampleEltwiseMultMod\n"; + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseReduceMod(input1.data(), input1.data(), input_size, + modulus, input_mod_factor, output_mod_factor); + }); + return time_taken; } -void ExampleNTT() { - std::cout << "Running ExampleNTT...\n"; +double BM_EltwiseVectorVectorSubMod(size_t input_size) { + uint64_t modulus = 0xffffffffffc0001ULL; - uint64_t N = 8; - uint64_t modulus = 769; - std::vector arg{1, 2, 3, 4, 5, 6, 7, 8}; - auto exp_out = arg; - intel::hexl::NTT ntt(N, modulus); + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + auto input2 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + intel::hexl::AlignedVector64 output(input_size, 0); - ntt.ComputeForward(arg.data(), arg.data(), 1, 1); - ntt.ComputeInverse(arg.data(), arg.data(), 1, 1); + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseSubMod(output.data(), input1.data(), input2.data(), + input_size, modulus); + }); + return time_taken; +} - CheckEqual(arg, exp_out); - std::cout << "Done running ExampleNTT\n"; +// double BM_FwdFFTLikeRd2InPlaceSmall(const size_t fft_like_size) { +// const size_t bound = 1 << 30; +// const double scale = 10; +// const double scalar = scale / static_cast(fft_like_size); +// intel::hexl::FFTLike fft_like(fft_like_size, nullptr); +// +// intel::hexl::AlignedVector64> input(fft_like_size); +// for (size_t i = 0; i < fft_like_size; i++) { +// input[i] = std::complex( +// intel::hexl::GenerateInsecureUniformRealRandomValue(0, bound), +// intel::hexl::GenerateInsecureUniformRealRandomValue(0, bound)); +// } +// +// intel::hexl::AlignedVector64> root_powers = +// fft_like.GetComplexRootsOfUnity(); +// +// double time_taken = TimeFunction([&]() { +// intel::hexl::Forward_FFTLike_ToBitReverseRadix2( +// input.data(), input.data(), root_powers.data(), fft_like_size, +// &scalar); +// }); +// return 
time_taken; +// } +double BM_NTTInPlace(size_t ntt_size) { + size_t modulus = intel::hexl::GeneratePrimes(1, 45, true, ntt_size)[0]; + + auto input = + intel::hexl::GenerateInsecureUniformIntRandomValues(ntt_size, 0, modulus); + intel::hexl::NTT ntt(ntt_size, modulus); + + double time_taken = TimeFunction([&]() { + ntt.ComputeForward(input.data(), input.data(), 1, 1); + }) + + TimeFunction([&]() { + ntt.ComputeInverse(input.data(), input.data(), 2, 1); + }); + return time_taken; } -void ExampleReduceMod() { - std::cout << "Running ExampleReduceMod...\n"; +int main(int argc, char** argv) { + if (argc != 3) { + std::cerr << "Usage: " << argv[0] + << " " + << std::endl; + return 1; + } - uint64_t modulus = 5; - std::vector arg{1, 2, 3, 4, 5, 6, 7, 8}; - std::vector exp_out{1, 2, 3, 4, 0, 1, 2, 3}; - std::vector result{0, 0, 0, 0, 0, 0, 0, 0}; - intel::hexl::EltwiseReduceMod(result.data(), arg.data(), arg.size(), modulus, - 2, 1); + int num_iterations = std::stoi(argv[1]); + std::vector input_size = split(argv[2], ','); + + // Using a map to store the results + std::map> results; + + // Initialize the results map + results["BM_EltwiseVectorVectorAddMod"] = + std::vector(input_size.size(), 0.0); + results["BM_EltwiseVectorScalarAddMod"] = + std::vector(input_size.size(), 0.0); + results["BM_EltwiseCmpAdd"] = std::vector(input_size.size(), 0.0); + results["BM_EltwiseCmpSubMod"] = std::vector(input_size.size(), 0.0); + results["BM_EltwiseFMAModAdd"] = std::vector(input_size.size(), 0.0); + results["BM_EltwiseMultMod"] = std::vector(input_size.size(), 0.0); + results["BM_EltwiseReduceModInPlace"] = + std::vector(input_size.size(), 0.0); + results["BM_EltwiseVectorVectorSubMod"] = + std::vector(input_size.size(), 0.0); + results["BM_NTTInPlace"] = std::vector(input_size.size(), 0.0); + + // Execute each method for all thread numbers + for (size_t j = 0; j < input_size.size(); ++j) { + int size = input_size[j]; + // omp_set_num_threads(num_threads); + + results["BM_EltwiseVectorVectorAddMod"][j] = 0; + results["BM_EltwiseVectorScalarAddMod"][j] = 0; + results["BM_EltwiseCmpAdd"][j] = 0; + results["BM_EltwiseCmpSubMod"][j] = 0; + results["BM_EltwiseFMAModAdd"][j] = 0; + results["BM_EltwiseMultMod"][j] = 0; + results["BM_EltwiseReduceModInPlace"][j] = 0; + results["BM_EltwiseVectorVectorSubMod"][j] = 0; + results["BM_NTTInPlace"][j] = 0; + + bool add_choices[] = {false, true}; + int bit_width_choices[] = {48, 60}; + int mod_factor_choices[] = {1, 2, 4}; + + for (int i = 0; i < num_iterations; i++) { + // There are CMPINT possibilities, should be chosen randomly for testing + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis_cmpint( + 0, 7); // 8 enum values, from 0 to 7 + std::uniform_int_distribution<> dis_add(0, 1); + std::uniform_int_distribution<> dis_factor(0, 2); + + intel::hexl::CMPINT chosenCMP = + static_cast(dis_cmpint(gen)); + + bool add = add_choices[dis_add(gen)]; + + size_t bit_width = bit_width_choices[dis_add(gen)]; + + size_t input_mod_factor = mod_factor_choices[dis_factor(gen)]; + + results["BM_EltwiseVectorVectorAddMod"][j] += + BM_EltwiseVectorVectorAddMod(size); + results["BM_EltwiseVectorScalarAddMod"][j] += + BM_EltwiseVectorScalarAddMod(size); + results["BM_EltwiseCmpAdd"][j] += BM_EltwiseCmpAdd(size, chosenCMP); + results["BM_EltwiseCmpSubMod"][j] += + BM_EltwiseCmpSubMod(size, chosenCMP); + results["BM_EltwiseFMAModAdd"][j] += BM_EltwiseFMAModAdd(size, add); + results["BM_EltwiseMultMod"][j] += + BM_EltwiseMultMod(size, bit_width, 
input_mod_factor); + results["BM_EltwiseReduceModInPlace"][j] += + BM_EltwiseReduceModInPlace(size); + results["BM_EltwiseVectorVectorSubMod"][j] += + BM_EltwiseVectorVectorSubMod(size); + results["BM_NTTInPlace"][j] += BM_NTTInPlace(size/4096); + } + } - CheckEqual(result, exp_out); - std::cout << "Done running ExampleReduceMod\n"; -} + // Print the table -int main() { - ExampleEltwiseVectorVectorAddMod(); - ExampleEltwiseVectorScalarAddMod(); - ExampleEltwiseCmpAdd(); - ExampleEltwiseCmpSubMod(); - ExampleEltwiseFMAMod(); - ExampleEltwiseMultMod(); - ExampleNTT(); - ExampleReduceMod(); + // Print headers + std::cout << std::left << std::setw(40) << "Method"; + for (int size : input_size) { + std::cout << std::setw(20) << ("Input_size=" + std::to_string(size)); + } + std::cout << std::endl; + + // Print results + for (auto& [method, times] : results) { + std::cout << std::left << std::setw(40) << method; + for (double time : times) { + std::cout << std::setw(20) << time/num_iterations; + } + std::cout << std::endl; + } return 0; } diff --git a/example/o_example.cpp b/example/o_example.cpp new file mode 100644 index 00000000..9e6bf659 --- /dev/null +++ b/example/o_example.cpp @@ -0,0 +1,157 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include + +#include "hexl/hexl.hpp" + +bool CheckEqual(const std::vector& x, + const std::vector& y) { + if (x.size() != y.size()) { + std::cout << "Not equal in size\n"; + return false; + } + uint64_t N = x.size(); + bool is_match = true; + for (size_t i = 0; i < N; ++i) { + if (x[i] != y[i]) { + std::cout << "Not equal at index " << i << "\n"; + is_match = false; + } + } + return is_match; +} + +void ExampleEltwiseVectorVectorAddMod() { + std::cout << "Running ExampleEltwiseVectorVectorAddMod...\n"; + + std::vector op1{1, 2, 3, 4, 5, 6, 7, 8}; + std::vector op2{1, 3, 5, 7, 2, 4, 6, 8}; + uint64_t modulus = 10; + std::vector exp_out{2, 5, 8, 1, 7, 0, 3, 6}; + + intel::hexl::EltwiseAddMod(op1.data(), op1.data(), op2.data(), op1.size(), + modulus); + + CheckEqual(op1, exp_out); + std::cout << "Done running ExampleEltwiseVectorVectorAddMod\n"; +} + +void ExampleEltwiseVectorScalarAddMod() { + std::cout << "Running ExampleEltwiseVectorScalarAddMod...\n"; + + std::vector op1{1, 2, 3, 4, 5, 6, 7, 8}; + uint64_t op2{3}; + uint64_t modulus = 10; + std::vector exp_out{4, 5, 6, 7, 8, 9, 0, 1}; + + intel::hexl::EltwiseAddMod(op1.data(), op1.data(), op2, op1.size(), modulus); + + CheckEqual(op1, exp_out); + std::cout << "Done running ExampleEltwiseVectorScalarAddMod\n"; +} + +void ExampleEltwiseCmpAdd() { + std::cout << "Running ExampleEltwiseCmpAdd...\n"; + + std::vector op1{1, 2, 3, 4, 5, 6, 7, 8}; + uint64_t cmp = 3; + uint64_t diff = 5; + std::vector exp_out{1, 2, 3, 9, 10, 11, 12, 13}; + + intel::hexl::EltwiseCmpAdd(op1.data(), op1.data(), op1.size(), + intel::hexl::CMPINT::NLE, cmp, diff); + + CheckEqual(op1, exp_out); + std::cout << "Done running ExampleEltwiseCmpAdd\n"; +} + +void ExampleEltwiseCmpSubMod() { + std::cout << "Running ExampleEltwiseCmpSubMod...\n"; + + std::vector op1{1, 2, 3, 4, 5, 6, 7}; + uint64_t bound = 4; + uint64_t diff = 5; + std::vector exp_out{1, 2, 3, 4, 0, 1, 2}; + + uint64_t modulus = 10; + + intel::hexl::EltwiseCmpSubMod(op1.data(), op1.data(), op1.size(), modulus, + intel::hexl::CMPINT::NLE, bound, diff); + CheckEqual(op1, exp_out); + std::cout << "Done running ExampleEltwiseCmpSubMod\n"; +} + +void ExampleEltwiseFMAMod() { + std::cout << "Running 
ExampleEltwiseFMAMod...\n"; + + std::vector arg1{1, 2, 3, 4, 5, 6, 7, 8, 9}; + uint64_t arg2 = 1; + std::vector exp_out{1, 2, 3, 4, 5, 6, 7, 8, 9}; + uint64_t modulus = 769; + + intel::hexl::EltwiseFMAMod(arg1.data(), arg1.data(), arg2, nullptr, + arg1.size(), modulus, 1); + CheckEqual(arg1, exp_out); + std::cout << "Done running ExampleEltwiseFMAMod\n"; +} + +void ExampleEltwiseMultMod() { + std::cout << "Running ExampleEltwiseMultMod...\n"; + + std::vector op1{2, 4, 3, 2}; + std::vector op2{2, 1, 2, 0}; + std::vector exp_out{4, 4, 6, 0}; + + uint64_t modulus = 769; + + intel::hexl::EltwiseMultMod(op1.data(), op1.data(), op2.data(), op1.size(), + modulus, 1); + CheckEqual(op1, exp_out); + std::cout << "Done running ExampleEltwiseMultMod\n"; +} + +void ExampleNTT() { + std::cout << "Running ExampleNTT...\n"; + + uint64_t N = 8; + uint64_t modulus = 769; + std::vector arg{1, 2, 3, 4, 5, 6, 7, 8}; + auto exp_out = arg; + intel::hexl::NTT ntt(N, modulus); + + ntt.ComputeForward(arg.data(), arg.data(), 1, 1); + ntt.ComputeInverse(arg.data(), arg.data(), 1, 1); + + CheckEqual(arg, exp_out); + std::cout << "Done running ExampleNTT\n"; +} + +void ExampleReduceMod() { + std::cout << "Running ExampleReduceMod...\n"; + + uint64_t modulus = 5; + std::vector arg{1, 2, 3, 4, 5, 6, 7, 8}; + std::vector exp_out{1, 2, 3, 4, 0, 1, 2, 3}; + std::vector result{0, 0, 0, 0, 0, 0, 0, 0}; + intel::hexl::EltwiseReduceMod(result.data(), arg.data(), arg.size(), modulus, + 2, 1); + + CheckEqual(result, exp_out); + std::cout << "Done running ExampleReduceMod\n"; +} + +int main() { + ExampleEltwiseVectorVectorAddMod(); + ExampleEltwiseVectorScalarAddMod(); + ExampleEltwiseCmpAdd(); + ExampleEltwiseCmpSubMod(); + ExampleEltwiseFMAMod(); + ExampleEltwiseMultMod(); + ExampleNTT(); + ExampleReduceMod(); + + return 0; +} diff --git a/hexl/CMakeLists.txt b/hexl/CMakeLists.txt index 7c660a02..2c92c389 100644 --- a/hexl/CMakeLists.txt +++ b/hexl/CMakeLists.txt @@ -119,6 +119,12 @@ install(DIRECTORY ${HEXL_INC_ROOT_DIR}/ PATTERN "*.hpp" PATTERN "*.h") +#twy +find_package(OpenMP) +if (OpenMP_CXX_FOUND) + target_link_libraries(hexl PUBLIC OpenMP::OpenMP_CXX) +endif() + if (HEXL_SHARED_LIB) target_link_libraries(hexl PRIVATE cpu_features) if (HEXL_DEBUG) diff --git a/hexl/eltwise/eltwise-add-mod.cpp b/hexl/eltwise/eltwise-add-mod.cpp index 6fc4c667..f3d99639 100644 --- a/hexl/eltwise/eltwise-add-mod.cpp +++ b/hexl/eltwise/eltwise-add-mod.cpp @@ -3,13 +3,17 @@ #include "hexl/eltwise/eltwise-add-mod.hpp" +#include + +#include +#include + #include "eltwise/eltwise-add-mod-avx512.hpp" #include "eltwise/eltwise-add-mod-internal.hpp" #include "hexl/logging/logging.hpp" #include "hexl/number-theory/number-theory.hpp" #include "hexl/util/check.hpp" #include "util/cpu-features.hpp" - namespace intel { namespace hexl { @@ -27,19 +31,35 @@ void EltwiseAddModNative(uint64_t* result, const uint64_t* operand1, HEXL_CHECK_BOUNDS(operand2, n, modulus, "pre-add value in operand2 exceeds bound " << modulus); - HEXL_LOOP_UNROLL_4 - for (size_t i = 0; i < n; ++i) { - uint64_t sum = *operand1 + *operand2; - if (sum >= modulus) { - *result = sum - modulus; - } else { - *result = sum; - } + int thread_count; - ++operand1; - ++operand2; - ++result; + double start_time = omp_get_wtime(); + +#pragma omp parallel + { + thread_count = omp_get_num_threads(); + + #pragma omp for + for (size_t i = 0; i < n; ++i) { + uint64_t sum = operand1[i] + operand2[i]; + if (sum >= modulus) { + result[i] = sum - modulus; + } else { + result[i] = sum; + } + } } + // double 
start_time = omp_get_wtime(); + // Record the end time(timer2) + double end_time = omp_get_wtime(); + + // Calculate and print the elapsed time + double elapsed_time = end_time - start_time; + + // std::cout << "EltwiseVectorVectorAddModNative Thread count: " << thread_count + // << " time: " << elapsed_time << " seconds." << std::endl; + std::cout << thread_count << " " << std::fixed << elapsed_time + << std::setprecision(5) << std::endl; } void EltwiseAddModNative(uint64_t* result, const uint64_t* operand1, @@ -54,18 +74,30 @@ void EltwiseAddModNative(uint64_t* result, const uint64_t* operand1, HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); uint64_t diff = modulus - operand2; - - HEXL_LOOP_UNROLL_4 - for (size_t i = 0; i < n; ++i) { - if (*operand1 >= diff) { - *result = *operand1 - diff; - } else { - *result = *operand1 + operand2; + int thread_count; + // Record the start time (timer1) + double start_time = omp_get_wtime(); +#pragma omp parallel + { + thread_count = omp_get_num_threads(); + + #pragma omp for + for (size_t i = 0; i < n; ++i) { + if (operand1[i] >= diff) { + result[i] = operand1[i] - diff; + } else { + result[i] = operand1[i] + operand2; + } + } } + // Record the end time(timer2) + double end_time = omp_get_wtime(); - ++operand1; - ++result; - } + // Calculate and print the elapsed time + double elapsed_time = end_time - start_time; + + std::cout << thread_count << " " << std::fixed << elapsed_time + << std::setprecision(5) << std::endl; } void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, @@ -114,3 +146,4 @@ void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, } // namespace hexl } // namespace intel + diff --git a/hexl/eltwise/eltwise-cmp-add.cpp b/hexl/eltwise/eltwise-cmp-add.cpp index 399f993c..214e537c 100644 --- a/hexl/eltwise/eltwise-cmp-add.cpp +++ b/hexl/eltwise/eltwise-cmp-add.cpp @@ -3,13 +3,17 @@ #include "hexl/eltwise/eltwise-cmp-add.hpp" +#include + +#include +#include + #include "eltwise/eltwise-cmp-add-avx512.hpp" #include "eltwise/eltwise-cmp-add-internal.hpp" #include "hexl/logging/logging.hpp" #include "hexl/number-theory/number-theory.hpp" #include "hexl/util/check.hpp" #include "util/cpu-features.hpp" - namespace intel { namespace hexl { @@ -35,9 +39,16 @@ void EltwiseCmpAddNative(uint64_t* result, const uint64_t* operand1, uint64_t n, HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); HEXL_CHECK(n != 0, "Require n != 0"); HEXL_CHECK(diff != 0, "Require diff != 0"); + + int thread_count; + double start_time = omp_get_wtime(); switch (cmp) { case CMPINT::EQ: { +#pragma omp parallel + { + thread_count = omp_get_num_threads(); + #pragma omp for for (size_t i = 0; i < n; ++i) { if (operand1[i] == bound) { result[i] = operand1[i] + diff; @@ -45,9 +56,15 @@ void EltwiseCmpAddNative(uint64_t* result, const uint64_t* operand1, uint64_t n, result[i] = operand1[i]; } } + } break; } + case CMPINT::LT: +#pragma omp parallel + { + thread_count = omp_get_num_threads(); + #pragma omp for for (size_t i = 0; i < n; ++i) { if (operand1[i] < bound) { result[i] = operand1[i] + diff; @@ -55,8 +72,14 @@ void EltwiseCmpAddNative(uint64_t* result, const uint64_t* operand1, uint64_t n, result[i] = operand1[i]; } } + } break; + case CMPINT::LE: +#pragma omp parallel + { + thread_count = omp_get_num_threads(); + #pragma omp for for (size_t i = 0; i < n; ++i) { if (operand1[i] <= bound) { result[i] = operand1[i] + diff; @@ -64,13 +87,25 @@ void EltwiseCmpAddNative(uint64_t* result, const uint64_t* operand1, uint64_t n, result[i] 
= operand1[i]; } } + } break; + case CMPINT::FALSE: +#pragma omp parallel + { + thread_count = omp_get_num_threads(); + #pragma omp for for (size_t i = 0; i < n; ++i) { result[i] = operand1[i]; } + } break; + case CMPINT::NE: +#pragma omp parallel + { + thread_count = omp_get_num_threads(); + #pragma omp for for (size_t i = 0; i < n; ++i) { if (operand1[i] != bound) { result[i] = operand1[i] + diff; @@ -78,8 +113,14 @@ void EltwiseCmpAddNative(uint64_t* result, const uint64_t* operand1, uint64_t n, result[i] = operand1[i]; } } + } break; + case CMPINT::NLT: +#pragma omp parallel + { + thread_count = omp_get_num_threads(); + #pragma omp for for (size_t i = 0; i < n; ++i) { if (operand1[i] >= bound) { result[i] = operand1[i] + diff; @@ -87,8 +128,14 @@ void EltwiseCmpAddNative(uint64_t* result, const uint64_t* operand1, uint64_t n, result[i] = operand1[i]; } } + } break; + case CMPINT::NLE: +#pragma omp parallel + { + thread_count = omp_get_num_threads(); + #pragma omp for for (size_t i = 0; i < n; ++i) { if (operand1[i] > bound) { result[i] = operand1[i] + diff; @@ -96,13 +143,29 @@ void EltwiseCmpAddNative(uint64_t* result, const uint64_t* operand1, uint64_t n, result[i] = operand1[i]; } } + } break; + case CMPINT::TRUE: +#pragma omp parallel + { + thread_count = omp_get_num_threads(); + #pragma omp for for (size_t i = 0; i < n; ++i) { result[i] = operand1[i] + diff; } + } break; + } + // Record the end time(timer2) + double end_time = omp_get_wtime(); + + // Calculate and print the elapsed time + double elapsed_time = end_time - start_time; + + std::cout << thread_count << " " << std::fixed << elapsed_time + << std::setprecision(5) << std::endl; } } // namespace hexl diff --git a/hexl/eltwise/eltwise-cmp-sub-mod.cpp b/hexl/eltwise/eltwise-cmp-sub-mod.cpp index b3526bd1..ce635e3f 100644 --- a/hexl/eltwise/eltwise-cmp-sub-mod.cpp +++ b/hexl/eltwise/eltwise-cmp-sub-mod.cpp @@ -3,6 +3,11 @@ #include "hexl/eltwise/eltwise-cmp-sub-mod.hpp" +#include + +#include +#include + #include "eltwise/eltwise-cmp-sub-mod-avx512.hpp" #include "eltwise/eltwise-cmp-sub-mod-internal.hpp" #include "hexl/logging/logging.hpp" @@ -53,17 +58,33 @@ void EltwiseCmpSubModNative(uint64_t* result, const uint64_t* operand1, HEXL_CHECK(modulus > 1, "Require modulus > 1"); HEXL_CHECK(diff != 0, "Require diff != 0"); HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus); + int thread_count; + double start_time = omp_get_wtime(); - for (size_t i = 0; i < n; ++i) { - uint64_t op = operand1[i]; - bool op_cmp = Compare(cmp, op, bound); - op %= modulus; - if (op_cmp) { - op = SubUIntMod(op, diff, modulus); +#pragma omp parallel + { + thread_count = omp_get_num_threads(); + #pragma omp for + for (size_t i = 0; i < n; ++i) { + uint64_t op = operand1[i]; + bool op_cmp = Compare(cmp, op, bound); + op %= modulus; + if (op_cmp) { + op = SubUIntMod(op, diff, modulus); + } + result[i] = op; } - result[i] = op; } -} + + // Record the end time(timer2) + double end_time = omp_get_wtime(); + + // Calculate and print the elapsed time + double elapsed_time = end_time - start_time; + + std::cout << thread_count << " " << std::fixed << elapsed_time + << std::setprecision(5) << std::endl; + } } // namespace hexl } // namespace intel diff --git a/hexl/eltwise/eltwise-fma-mod-internal.hpp b/hexl/eltwise/eltwise-fma-mod-internal.hpp index 673ab61f..557e94ef 100644 --- a/hexl/eltwise/eltwise-fma-mod-internal.hpp +++ b/hexl/eltwise/eltwise-fma-mod-internal.hpp @@ -3,6 +3,11 @@ #pragma once +#include + +#include +#include + 
#include "hexl/number-theory/number-theory.hpp" namespace intel { @@ -17,26 +22,47 @@ void EltwiseFMAModNative(uint64_t* result, const uint64_t* arg1, uint64_t arg2, &four_times_modulus); MultiplyFactor mf(arg2, 64, modulus); + int thread_count; + double start_time = omp_get_wtime(); + if (arg3) { +#pragma omp parallel +{ + thread_count = omp_get_num_threads(); +#pragma omp for for (size_t i = 0; i < n; ++i) { uint64_t arg1_val = ReduceMod( - *arg1++, modulus, &twice_modulus, &four_times_modulus); + arg1[i], modulus, &twice_modulus, &four_times_modulus); uint64_t arg3_val = ReduceMod( - *arg3++, modulus, &twice_modulus, &four_times_modulus); + arg3[i], modulus, &twice_modulus, &four_times_modulus); uint64_t result_val = MultiplyMod(arg1_val, arg2, mf.BarrettFactor(), modulus); - *result = AddUIntMod(result_val, arg3_val, modulus); - result++; + result[i] = AddUIntMod(result_val, arg3_val, modulus); } - } else { // arg3 == nullptr - for (size_t i = 0; i < n; ++i) { - uint64_t arg1_val = ReduceMod( - *arg1++, modulus, &twice_modulus, &four_times_modulus); - *result++ = MultiplyMod(arg1_val, arg2, mf.BarrettFactor(), modulus); +} + } + else { // arg3 == nullptr +#pragma omp parallel + { + thread_count = omp_get_num_threads(); +#pragma omp for + for (size_t i = 0; i < n; ++i) { + uint64_t arg1_val = ReduceMod( + arg1[i], modulus, &twice_modulus, &four_times_modulus); + result[i] = MultiplyMod(arg1_val, arg2, mf.BarrettFactor(), modulus); + } } } -} + // Record the end time(timer2) + double end_time = omp_get_wtime(); + + // Calculate and print the elapsed time + double elapsed_time = end_time - start_time; + + std::cout << thread_count << " " << std::fixed << elapsed_time + << std::setprecision(5) << std::endl; + } } // namespace hexl } // namespace intel diff --git a/hexl/eltwise/eltwise-mult-mod-internal.hpp b/hexl/eltwise/eltwise-mult-mod-internal.hpp index cd2a12f0..4bebb668 100644 --- a/hexl/eltwise/eltwise-mult-mod-internal.hpp +++ b/hexl/eltwise/eltwise-mult-mod-internal.hpp @@ -3,9 +3,12 @@ #pragma once +#include #include #include +#include +#include #include "eltwise/eltwise-mult-mod-internal.hpp" #include "hexl/eltwise/eltwise-reduce-mod.hpp" @@ -67,38 +70,51 @@ void EltwiseMultModNative(uint64_t* result, const uint64_t* operand1, .BarrettFactor(); const uint64_t twice_modulus = 2 * modulus; + int thread_count; - HEXL_LOOP_UNROLL_4 - for (size_t i = 0; i < n; ++i) { - uint64_t prod_hi, prod_lo, c2_hi, c2_lo, Z; + double start_time = omp_get_wtime(); - uint64_t x = ReduceMod(*operand1, modulus, &twice_modulus); - uint64_t y = ReduceMod(*operand2, modulus, &twice_modulus); +#pragma omp parallel + { + thread_count = omp_get_num_threads(); +#pragma omp for + for (size_t i = 0; i < n; ++i) { + uint64_t prod_hi, prod_lo, c2_hi, c2_lo, Z; - // Multiply inputs - MultiplyUInt64(x, y, &prod_hi, &prod_lo); + uint64_t x = + ReduceMod(operand1[i], modulus, &twice_modulus); + uint64_t y = + ReduceMod(operand2[i], modulus, &twice_modulus); - // floor(U / 2^{n + beta}) - uint64_t c1 = (prod_lo >> (prod_right_shift)) + - (prod_hi << (64 - (prod_right_shift))); + // Multiply inputs + MultiplyUInt64(x, y, &prod_hi, &prod_lo); - // c2 = floor(U / 2^{n + beta}) * mu - MultiplyUInt64(c1, barr_lo, &c2_hi, &c2_lo); + // floor(U / 2^{n + beta}) + uint64_t c1 = (prod_lo >> (prod_right_shift)) + + (prod_hi << (64 - (prod_right_shift))); - // alpha - beta == 64, so we only need high 64 bits - uint64_t q_hat = c2_hi; + // c2 = floor(U / 2^{n + beta}) * mu + MultiplyUInt64(c1, barr_lo, &c2_hi, &c2_lo); - // only 
compute low bits, since we know high bits will be 0 - Z = prod_lo - q_hat * modulus; + // alpha - beta == 64, so we only need high 64 bits + uint64_t q_hat = c2_hi; - // Conditional subtraction - *result = (Z >= modulus) ? (Z - modulus) : Z; + // only compute low bits, since we know high bits will be 0 + Z = prod_lo - q_hat * modulus; - ++operand1; - ++operand2; - ++result; + // Conditional subtraction + result[i] = (Z >= modulus) ? (Z - modulus) : Z; + } + } + // Record the end time(timer2) + double end_time = omp_get_wtime(); + + // Calculate and print the elapsed time + double elapsed_time = end_time - start_time; + + std::cout << thread_count << " " << std::fixed << elapsed_time + << std::setprecision(5) << std::endl; } -} } // namespace hexl } // namespace intel diff --git a/hexl/eltwise/eltwise-reduce-mod.cpp b/hexl/eltwise/eltwise-reduce-mod.cpp index accfe938..14f5fcd4 100644 --- a/hexl/eltwise/eltwise-reduce-mod.cpp +++ b/hexl/eltwise/eltwise-reduce-mod.cpp @@ -3,6 +3,11 @@ #include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include + +#include +#include + #include "eltwise/eltwise-reduce-mod-avx512.hpp" #include "eltwise/eltwise-reduce-mod-internal.hpp" #include "hexl/logging/logging.hpp" @@ -32,21 +37,34 @@ void EltwiseReduceModNative(uint64_t* result, const uint64_t* operand, uint64_t barrett_factor = MultiplyFactor(1, 64, modulus).BarrettFactor(); uint64_t twice_modulus = modulus << 1; + int thread_count; + double start_time = omp_get_wtime(); + if (input_mod_factor == modulus) { if (output_mod_factor == 2) { - for (size_t i = 0; i < n; ++i) { - if (operand[i] >= modulus) { - result[i] = BarrettReduce64<2>(operand[i], modulus, barrett_factor); - } else { - result[i] = operand[i]; +#pragma omp parallel + { + thread_count = omp_get_num_threads(); +#pragma omp for + for (size_t i = 0; i < n; ++i) { + if (operand[i] >= modulus) { + result[i] = BarrettReduce64<2>(operand[i], modulus, barrett_factor); + } else { + result[i] = operand[i]; + } } } } else { - for (size_t i = 0; i < n; ++i) { - if (operand[i] >= modulus) { - result[i] = BarrettReduce64<1>(operand[i], modulus, barrett_factor); - } else { - result[i] = operand[i]; +#pragma omp parallel + { + thread_count = omp_get_num_threads(); +#pragma omp for + for (size_t i = 0; i < n; ++i) { + if (operand[i] >= modulus) { + result[i] = BarrettReduce64<1>(operand[i], modulus, barrett_factor); + } else { + result[i] = operand[i]; + } } } @@ -55,27 +73,51 @@ void EltwiseReduceModNative(uint64_t* result, const uint64_t* operand, } if (input_mod_factor == 2) { - for (size_t i = 0; i < n; ++i) { - result[i] = ReduceMod<2>(operand[i], modulus); +#pragma omp parallel + { + thread_count = omp_get_num_threads(); +#pragma omp for + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<2>(operand[i], modulus); + } } HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); } if (input_mod_factor == 4) { if (output_mod_factor == 1) { - for (size_t i = 0; i < n; ++i) { - result[i] = ReduceMod<4>(operand[i], modulus, &twice_modulus); +#pragma omp parallel + { + thread_count = omp_get_num_threads(); +#pragma omp for + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<4>(operand[i], modulus, &twice_modulus); + } } HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); } if (output_mod_factor == 2) { - for (size_t i = 0; i < n; ++i) { - result[i] = ReduceMod<2>(operand[i], twice_modulus); +#pragma omp parallel + { + thread_count = omp_get_num_threads(); +#pragma omp for + for (size_t i = 0; i < n; ++i) { + 
result[i] = ReduceMod<2>(operand[i], twice_modulus); + } } HEXL_CHECK_BOUNDS(result, n, twice_modulus, "result exceeds bound " << twice_modulus); } } + + // Record the end time(timer2) + double end_time = omp_get_wtime(); + + // Calculate and print the elapsed time + double elapsed_time = end_time - start_time; + + std::cout << thread_count << " " << std::fixed << elapsed_time + << std::setprecision(5) << std::endl; } void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, @@ -90,11 +132,27 @@ void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, "input_mod_factor must be modulus or 2 or 4" << input_mod_factor); HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, "output_mod_factor must be 1 or 2 " << output_mod_factor); - + int thread_count; if (input_mod_factor == output_mod_factor && (operand != result)) { - for (size_t i = 0; i < n; ++i) { - result[i] = operand[i]; + double start_time = omp_get_wtime(); + +#pragma omp parallel + { + thread_count = omp_get_num_threads(); +#pragma omp for + for (size_t i = 0; i < n; ++i) { + result[i] = operand[i]; + } } + + // Record the end time(timer2) + double end_time = omp_get_wtime(); + + // Calculate and print the elapsed time + double elapsed_time = end_time - start_time; + + std::cout << thread_count << " " << std::fixed << elapsed_time + << std::setprecision(5) << std::endl; return; } diff --git a/hexl/eltwise/eltwise-sub-mod.cpp b/hexl/eltwise/eltwise-sub-mod.cpp index ae52d401..11c2d85e 100644 --- a/hexl/eltwise/eltwise-sub-mod.cpp +++ b/hexl/eltwise/eltwise-sub-mod.cpp @@ -1,6 +1,11 @@ // Copyright (C) 2020 Intel Corporation // SPDX-License-Identifier: Apache-2.0 +#include + +#include +#include + #include "eltwise/eltwise-sub-mod-avx512.hpp" #include "eltwise/eltwise-sub-mod-internal.hpp" #include "hexl/eltwise/eltwise-add-mod.hpp" @@ -25,19 +30,31 @@ void EltwiseSubModNative(uint64_t* result, const uint64_t* operand1, "pre-sub value in operand1 exceeds bound " << modulus); HEXL_CHECK_BOUNDS(operand2, n, modulus, "pre-sub value in operand2 exceeds bound " << modulus); + + int thread_count; - HEXL_LOOP_UNROLL_4 - for (size_t i = 0; i < n; ++i) { - if (*operand1 >= *operand2) { - *result = *operand1 - *operand2; - } else { - *result = *operand1 + modulus - *operand2; - } + double start_time = omp_get_wtime(); - ++operand1; - ++operand2; - ++result; +#pragma omp parallel + { + thread_count = omp_get_num_threads(); +#pragma omp for + for (size_t i = 0; i < n; ++i) { + if (operand1[i] >= operand2[i]) { + result[i] = operand1[i] - operand2[i]; + } else { + result[i] = operand1[i] + modulus - operand2[i]; + } + } } + // Record the end time(timer2) + double end_time = omp_get_wtime(); + + // Calculate and print the elapsed time + double elapsed_time = end_time - start_time; + + std::cout << thread_count << " " << std::fixed << elapsed_time + << std::setprecision(5) << std::endl; } void EltwiseSubModNative(uint64_t* result, const uint64_t* operand1, @@ -50,18 +67,29 @@ void EltwiseSubModNative(uint64_t* result, const uint64_t* operand1, HEXL_CHECK_BOUNDS(operand1, n, modulus, "pre-sub value in operand1 exceeds bound " << modulus); HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + int thread_count; + double start_time = omp_get_wtime(); - HEXL_LOOP_UNROLL_4 - for (size_t i = 0; i < n; ++i) { - if (*operand1 >= operand2) { - *result = *operand1 - operand2; +#pragma omp parallel + { + thread_count = omp_get_num_threads(); +#pragma omp for + for (size_t i = 0; i < n; ++i) { + if 
(operand1[i] >= operand2) { + result[i] = operand1[i] - operand2; } else { - *result = *operand1 + modulus - operand2; + result[i] = operand1[i] + modulus - operand2; } - - ++operand1; - ++result; +} } +// Record the end time(timer2) +double end_time = omp_get_wtime(); + +// Calculate and print the elapsed time +double elapsed_time = end_time - start_time; + +std::cout << thread_count << " " << std::fixed << elapsed_time + << std::setprecision(5) << std::endl; } void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, diff --git a/hexl/ntt/ntt-radix-4.cpp b/hexl/ntt/ntt-radix-4.cpp index 5c6df5d4..1d21fece 100644 --- a/hexl/ntt/ntt-radix-4.cpp +++ b/hexl/ntt/ntt-radix-4.cpp @@ -60,7 +60,7 @@ void ForwardTransformToBitReverseRadix4( const uint64_t* X_op = operand; const uint64_t* Y_op = X_op + t; - HEXL_LOOP_UNROLL_8 + // HEXL_LOOP_UNROLL_8 for (size_t j = 0; j < t; j++) { FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, twice_modulus); @@ -206,7 +206,7 @@ void ForwardTransformToBitReverseRadix4( switch (t) { case 4: { - HEXL_LOOP_UNROLL_8 + // HEXL_LOOP_UNROLL_8 for (size_t i = 0; i < m; i++) { if (i != 0) { X0_offset += 4 * t; @@ -252,7 +252,7 @@ void ForwardTransformToBitReverseRadix4( break; } case 1: { - HEXL_LOOP_UNROLL_8 + // HEXL_LOOP_UNROLL_8 for (size_t i = 0; i < m; i++) { if (i != 0) { X0_offset += 4 * t; @@ -429,7 +429,7 @@ void InverseTransformFromBitReverseRadix4( const uint64_t* W = inv_root_of_unity_powers + 1; const uint64_t* W_precon = precon_inv_root_of_unity_powers + 1; - HEXL_LOOP_UNROLL_8 + // HEXL_LOOP_UNROLL_8 for (size_t j = 0; j < n / 2; j++) { InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, *W++, *W_precon++, modulus, twice_modulus); @@ -534,7 +534,7 @@ void InverseTransformFromBitReverseRadix4( break; } default: { - HEXL_LOOP_UNROLL_4 + // HEXL_LOOP_UNROLL_4 for (size_t i = 0; i < m; i++) { HEXL_VLOG(4, "i " << i); if (i != 0) { diff --git a/hexl_bench b/hexl_bench new file mode 100644 index 00000000..2169ad63 --- /dev/null +++ b/hexl_bench @@ -0,0 +1,102 @@ +Running benchmark/bench_hexl +Run on (8 X 1996.25 MHz CPU s) +CPU Caches: + L1 Data 32 KiB (x8) + L1 Instruction 32 KiB (x8) + L2 Unified 512 KiB (x8) + L3 Unified 32768 KiB (x8) +Load Average: 0.00, 0.00, 0.00 +----------------------------------------------------------------------------------- +Benchmark Time CPU Iterations +----------------------------------------------------------------------------------- +BM_FwdNTTNativeRadix2InPlace/1024 1130 us 1130 us 621 +BM_FwdNTTNativeRadix2InPlace/4096 5310 us 5309 us 130 +BM_FwdNTTNativeRadix2InPlace/16384 15993 us 15992 us 33 +BM_FwdNTTNativeRadix2Copy/1024 641 us 641 us 1083 +BM_FwdNTTNativeRadix2Copy/4096 2886 us 2886 us 228 +BM_FwdNTTNativeRadix2Copy/16384 12516 us 12516 us 53 +BM_FwdNTTNativeRadix4InPlace/1024 619 us 619 us 1072 +BM_FwdNTTNativeRadix4InPlace/4096 2974 us 2974 us 242 +BM_FwdNTTNativeRadix4InPlace/16384 13895 us 13895 us 50 +BM_FwdNTTNativeRadix4Copy/1024 626 us 626 us 1199 +BM_FwdNTTNativeRadix4Copy/4096 2909 us 2909 us 220 +BM_FwdNTTNativeRadix4Copy/16384 13916 us 13916 us 52 +BM_FwdNTTInPlace/1024 582 us 582 us 1245 +BM_FwdNTTInPlace/4096 2736 us 2736 us 256 +BM_FwdNTTInPlace/16384 13035 us 13034 us 52 +BM_FwdNTTCopy/1024 600 us 600 us 1283 +BM_FwdNTTCopy/4096 2822 us 2822 us 261 +BM_FwdNTTCopy/16384 12812 us 12812 us 54 +BM_InvNTTInPlace/1024 294 us 294 us 2535 +BM_InvNTTInPlace/4096 1386 us 1386 us 511 +BM_InvNTTInPlace/16384 6152 us 6152 us 117 +BM_InvNTTCopy/1024 278 us 277 us 2541 +BM_InvNTTCopy/4096 1355 
us 1355 us 545 +BM_InvNTTCopy/16384 6399 us 6398 us 119 +BM_InvNTTNativeRadix2InPlace/1024 286 us 286 us 2514 +BM_InvNTTNativeRadix2InPlace/4096 1405 us 1405 us 476 +BM_InvNTTNativeRadix2InPlace/16384 6785 us 6785 us 111 +BM_InvNTTNativeRadix2Copy/1024 278 us 278 us 2435 +BM_InvNTTNativeRadix2Copy/4096 1356 us 1356 us 514 +BM_InvNTTNativeRadix2Copy/16384 6293 us 6293 us 117 +BM_InvNTTNativeRadix4InPlace/1024 375 us 375 us 1810 +BM_InvNTTNativeRadix4InPlace/4096 1792 us 1792 us 401 +BM_InvNTTNativeRadix4InPlace/16384 8502 us 8501 us 87 +BM_InvNTTNativeRadix4Copy/1024 372 us 372 us 1791 +BM_InvNTTNativeRadix4Copy/4096 1870 us 1870 us 401 +BM_InvNTTNativeRadix4Copy/16384 8573 us 8573 us 84 +BM_EltwiseVectorVectorAddModNative/1024 2.66 us 2.66 us 247844 +BM_EltwiseVectorVectorAddModNative/4096 11.4 us 11.4 us 58583 +BM_EltwiseVectorVectorAddModNative/16384 43.0 us 43.0 us 15406 +BM_EltwiseVectorScalarAddModNative/1024 1.63 us 1.63 us 404432 +BM_EltwiseVectorScalarAddModNative/4096 7.47 us 7.47 us 91480 +BM_EltwiseVectorScalarAddModNative/16384 34.0 us 34.0 us 25964 +BM_EltwiseCmpAddNative/1024 1.13 us 1.13 us 619722 +BM_EltwiseCmpAddNative/4096 4.46 us 4.46 us 161229 +BM_EltwiseCmpAddNative/16384 17.3 us 17.3 us 38988 +BM_EltwiseCmpSubModNative/1024 2.10 us 2.10 us 324030 +BM_EltwiseCmpSubModNative/4096 8.47 us 8.47 us 84155 +BM_EltwiseCmpSubModNative/16384 34.5 us 34.4 us 20466 +BM_EltwiseFMAModAddNative/1024/0 4.26 us 4.26 us 180781 +BM_EltwiseFMAModAddNative/4096/0 15.2 us 15.2 us 43828 +BM_EltwiseFMAModAddNative/16384/0 57.2 us 57.2 us 12742 +BM_EltwiseFMAModAddNative/1024/1 7.71 us 7.71 us 92500 +BM_EltwiseFMAModAddNative/4096/1 30.1 us 30.1 us 23328 +BM_EltwiseFMAModAddNative/16384/1 120 us 120 us 6047 +BM_EltwiseMultMod/1024/48/1 5.27 us 5.27 us 131503 +BM_EltwiseMultMod/4096/48/1 21.6 us 21.6 us 31372 +BM_EltwiseMultMod/16384/48/1 87.8 us 87.8 us 8129 +BM_EltwiseMultMod/1024/60/1 5.54 us 5.54 us 122171 +BM_EltwiseMultMod/4096/60/1 22.1 us 22.1 us 33929 +BM_EltwiseMultMod/16384/60/1 82.9 us 82.9 us 8203 +BM_EltwiseMultMod/1024/48/2 5.74 us 5.74 us 119526 +BM_EltwiseMultMod/4096/48/2 21.2 us 21.2 us 30096 +BM_EltwiseMultMod/16384/48/2 87.0 us 87.0 us 8031 +BM_EltwiseMultMod/1024/60/2 5.48 us 5.48 us 129185 +BM_EltwiseMultMod/4096/60/2 20.5 us 20.5 us 33219 +BM_EltwiseMultMod/16384/60/2 83.4 us 83.4 us 8801 +BM_EltwiseMultMod/1024/48/4 13.9 us 13.9 us 51036 +BM_EltwiseMultMod/4096/48/4 58.0 us 58.0 us 12081 +BM_EltwiseMultMod/16384/48/4 229 us 229 us 3265 +BM_EltwiseMultMod/1024/60/4 14.5 us 14.5 us 50790 +BM_EltwiseMultMod/4096/60/4 59.4 us 59.4 us 12679 +BM_EltwiseMultMod/16384/60/4 216 us 216 us 3070 +BM_EltwiseMultModNative/1024 4.07 us 4.07 us 171443 +BM_EltwiseMultModNative/4096 16.2 us 16.2 us 43052 +BM_EltwiseMultModNative/16384 65.4 us 65.4 us 10941 +BM_EltwiseVectorVectorSubModNative/1024 2.52 us 2.52 us 277043 +BM_EltwiseVectorVectorSubModNative/4096 10.1 us 10.1 us 70341 +BM_EltwiseVectorVectorSubModNative/16384 43.9 us 43.9 us 16064 +BM_EltwiseVectorScalarSubModNative/1024 1.57 us 1.57 us 423180 +BM_EltwiseVectorScalarSubModNative/4096 6.56 us 6.56 us 97706 +BM_EltwiseVectorScalarSubModNative/16384 31.8 us 31.8 us 21135 +BM_EltwiseReduceModInPlace/1024 1.66 us 1.66 us 444353 +BM_EltwiseReduceModInPlace/4096 6.49 us 6.49 us 116754 +BM_EltwiseReduceModInPlace/16384 24.9 us 24.9 us 27805 +BM_EltwiseReduceModCopy/1024 4.81 us 4.81 us 162944 +BM_EltwiseReduceModCopy/4096 18.7 us 18.7 us 39751 +BM_EltwiseReduceModCopy/16384 82.6 us 82.6 us 8116 +BM_EltwiseReduceModNative/1024 5.02 
us 5.02 us 100000 +BM_EltwiseReduceModNative/4096 20.1 us 20.1 us 33003 +BM_EltwiseReduceModNative/16384 87.0 us 87.0 us 7462 +[100%] Built target bench \ No newline at end of file diff --git a/hexl_omp.sh b/hexl_omp.sh new file mode 100755 index 00000000..1da579d9 --- /dev/null +++ b/hexl_omp.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +# Command prefix +CMD_PREFIX="./omp_example/build/omp_example" +ITERATIONS=3 +THREADS="1,2,4,6,8" + +# Array of input sizes +# INPUT_SIZES=( $((2**12)) $((2**16)) $((2**20)) $((2**24)) $((2**28)) ) +INPUT_SIZES=( $((2**12)) $((2**16)) $((2**20))) + +# INPUT_SIZES=( $((2**24)) ) + +# Name of the output file +OUTPUT_FILE="hexl_out.csv" + +# Clear any previous output file to start fresh +> $OUTPUT_FILE + +# Loop through each input size and execute the command +for size in "${INPUT_SIZES[@]}"; do + echo "Running with input size: $size" + echo "Input Size = $size" >> $OUTPUT_FILE + $CMD_PREFIX $ITERATIONS $THREADS $size >> $OUTPUT_FILE + echo -e "\n" >> $OUTPUT_FILE +done + +echo "All runs completed!" diff --git a/hexl_omp/CMakeLists.txt b/hexl_omp/CMakeLists.txt new file mode 100644 index 00000000..2c92c389 --- /dev/null +++ b/hexl_omp/CMakeLists.txt @@ -0,0 +1,222 @@ +# Copyright (C) 2020 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +set(NATIVE_SRC + eltwise/eltwise-mult-mod.cpp + eltwise/eltwise-reduce-mod.cpp + eltwise/eltwise-sub-mod.cpp + eltwise/eltwise-add-mod.cpp + eltwise/eltwise-fma-mod.cpp + eltwise/eltwise-cmp-add.cpp + eltwise/eltwise-cmp-sub-mod.cpp + ntt/ntt-internal.cpp + ntt/ntt-radix-2.cpp + ntt/ntt-radix-4.cpp + number-theory/number-theory.cpp +) + +if (HEXL_EXPERIMENTAL) + list(APPEND NATIVE_SRC + experimental/seal/dyadic-multiply.cpp + experimental/seal/key-switch.cpp + experimental/seal/dyadic-multiply-internal.cpp + experimental/seal/key-switch-internal.cpp + experimental/misc/lr-mat-vec-mult.cpp + experimental/fft-like/fft-like.cpp + experimental/fft-like/fft-like-native.cpp + experimental/fft-like/fwd-fft-like-avx512.cpp + experimental/fft-like/inv-fft-like-avx512.cpp + ) +endif() + +if (HEXL_HAS_AVX512DQ) + set(AVX512_SRC + eltwise/eltwise-mult-mod-avx512dq.cpp + eltwise/eltwise-mult-mod-avx512ifma.cpp + eltwise/eltwise-reduce-mod-avx512.cpp + eltwise/eltwise-add-mod-avx512.cpp + eltwise/eltwise-cmp-sub-mod-avx512.cpp + eltwise/eltwise-cmp-add-avx512.cpp + eltwise/eltwise-sub-mod-avx512.cpp + eltwise/eltwise-fma-mod-avx512.cpp + ntt/fwd-ntt-avx512.cpp + ntt/inv-ntt-avx512.cpp + ) +endif() + +set(HEXL_SRC "${NATIVE_SRC};${AVX512_SRC}") + +if (HEXL_DEBUG) + list(APPEND HEXL_SRC logging/logging.cpp) +endif() + +if (HEXL_SHARED_LIB) + add_library(hexl SHARED ${HEXL_SRC}) +else() + add_library(hexl STATIC ${HEXL_SRC}) +endif() +add_library(HEXL::hexl ALIAS hexl) + +hexl_add_asan_flag(hexl) + +set(HEXL_DEFINES_IN_FILENAME ${CMAKE_CURRENT_SOURCE_DIR}/include/hexl/util/defines.hpp.in) +set(HEXL_DEFINES_FILENAME ${CMAKE_CURRENT_SOURCE_DIR}/include/hexl/util/defines.hpp) +configure_file(${HEXL_DEFINES_IN_FILENAME} ${HEXL_DEFINES_FILENAME}) + +set_target_properties(hexl PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(hexl PROPERTIES VERSION ${HEXL_VERSION}) +if (HEXL_DEBUG) + set_target_properties(hexl PROPERTIES OUTPUT_NAME "hexl_debug") +else() + set_target_properties(hexl PROPERTIES OUTPUT_NAME "hexl") +endif() + +target_include_directories(hexl + PRIVATE ${HEXL_SRC_ROOT_DIR} # Private headers + PUBLIC $ # Public headers + PUBLIC $ # Public headers +) +if(CpuFeatures_FOUND) + target_include_directories(hexl PUBLIC 
${CpuFeatures_INCLUDE_DIR}) # Public headers +endif() + +if (HEXL_FPGA_COMPATIBILITY STREQUAL "1") + target_compile_options(hexl PRIVATE -DHEXL_FPGA_COMPATIBLE_DYADIC_MULTIPLY) +elseif (HEXL_FPGA_COMPATIBILITY STREQUAL "2") + target_compile_options(hexl PRIVATE -DHEXL_FPGA_COMPATIBLE_KEYSWITCH) +elseif (HEXL_FPGA_COMPATIBILITY STREQUAL "3") + target_compile_options(hexl PRIVATE + -DHEXL_FPGA_COMPATIBLE_DYADIC_MULTIPLY + -DHEXL_FPGA_COMPATIBLE_KEYSWITCH + ) +endif() + +if (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") + target_compile_options(hexl PRIVATE -Wall -Wconversion -Wshadow -pedantic -Wextra + -Wno-unknown-pragmas -march=native -O3 -fomit-frame-pointer + -Wno-sign-conversion + -Wno-implicit-int-conversion + ) + # Avoid 3rd-party dependency warnings when including HEXL as a dependency + target_compile_options(hexl PUBLIC + -Wno-unknown-warning + -Wno-unknown-warning-option + ) + +elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + # Inlining causes some tests to fail on MSVC with AVX512 in Release mode, HEXL_DEBUG=OFF, + # so we disable it here + target_compile_options(hexl PRIVATE /Wall /W4 /Ob0 + /wd4127 # warning C4127: conditional expression is constant; C++11 doesn't support if constexpr + /wd5105 # warning C5105: macro expansion producing 'defined' has undefined behavior + ) + target_compile_definitions(hexl PRIVATE -D_CRT_SECURE_NO_WARNINGS) +endif() + +install(DIRECTORY ${HEXL_INC_ROOT_DIR}/ + DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/ + FILES_MATCHING + PATTERN "*.hpp" + PATTERN "*.h") + +#twy +find_package(OpenMP) +if (OpenMP_CXX_FOUND) + target_link_libraries(hexl PUBLIC OpenMP::OpenMP_CXX) +endif() + +if (HEXL_SHARED_LIB) + target_link_libraries(hexl PRIVATE cpu_features) + if (HEXL_DEBUG) + target_link_libraries(hexl PUBLIC easyloggingpp) + # Manually add logging include directory + target_include_directories(hexl + PUBLIC $> + ) + endif() +else () + # For static library, if the dependencies are not found on the system, + # we manually add the dependencies for Intel HEXL in the exported library. + + # Export logging only if in debug mode + if (HEXL_DEBUG) + # Manually add logging include directory + target_include_directories(hexl + PUBLIC $> + ) + if (EASYLOGGINGPP_FOUND) + target_link_libraries(hexl PRIVATE easyloggingpp) + else() + hexl_create_archive(hexl easyloggingpp) + endif() + endif() + + if (CpuFeatures_FOUND) + target_link_libraries(hexl PRIVATE cpu_features) + else() + hexl_create_archive(hexl cpu_features) + endif() + + # Manually add cpu_features include directory + target_include_directories(hexl + PRIVATE $) +endif() + +install(TARGETS hexl DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +#------------------------------------------------------------------------------ +# Config export... 
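Stepping back to the `hexl_omp.sh` driver added earlier in this patch: the script invokes `./omp_example/build/omp_example <iterations> <thread-list> <size>`, but the `omp_example` sources themselves are not part of this hunk. The sketch below is only a hypothetical driver with that calling convention, timing an OpenMP-parallelized stand-in kernel across the requested thread counts with `omp_get_wtime`; the kernel, modulus, and output format are illustrative and not the project's actual `omp_example`.

```cpp
// Hypothetical driver matching the calling convention hexl_omp.sh assumes:
//   ./omp_example <iterations> <comma-separated thread counts> <input size>
// Not part of this patch; shown only to illustrate per-thread-count timing.
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <omp.h>

// Stand-in OpenMP kernel; the parallelized element-wise kernels added under
// hexl_omp/ follow this same loop shape.
static void add_mod(uint64_t* r, const uint64_t* a, const uint64_t* b,
                    uint64_t n, uint64_t p) {
#pragma omp parallel for
  for (size_t i = 0; i < n; ++i) {
    uint64_t s = a[i] + b[i];
    r[i] = (s >= p) ? s - p : s;
  }
}

int main(int argc, char** argv) {
  if (argc != 4) return 1;
  const int iterations = std::atoi(argv[1]);
  const uint64_t n = std::strtoull(argv[3], nullptr, 10);
  const uint64_t p = (1ULL << 50) + 27;  // illustrative modulus
  std::vector<uint64_t> a(n, 3), b(n, 5), out(n);

  std::stringstream thread_list(argv[2]);
  std::string tok;
  while (std::getline(thread_list, tok, ',')) {
    omp_set_num_threads(std::atoi(tok.c_str()));
    double start = omp_get_wtime();
    for (int it = 0; it < iterations; ++it) {
      add_mod(out.data(), a.data(), b.data(), n, p);
    }
    double elapsed = omp_get_wtime() - start;
    std::cout << tok << " threads, " << elapsed / iterations << " s/iter\n";
  }
  return 0;
}
```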
+#------------------------------------------------------------------------------ + +# Config filenames +set(HEXL_TARGET_FILENAME ${CMAKE_CURRENT_BINARY_DIR}/cmake/hexl-${HEXL_VERSION}/HEXLTargets.cmake) +set(HEXL_CONFIG_IN_FILENAME ${HEXL_CMAKE_PATH}/HEXLConfig.cmake.in) +set(HEXL_CONFIG_FILENAME ${HEXL_ROOT_DIR}/cmake/hexl-${HEXL_VERSION}/HEXLConfig.cmake) +set(HEXL_CONFIG_VERSION_FILENAME ${CMAKE_CURRENT_BINARY_DIR}/cmake/hexl-${HEXL_VERSION}/HEXLConfigVersion.cmake) +set(HEXL_CONFIG_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR}/cmake/hexl-${HEXL_VERSION}/) + +# Create and install the CMake config and target file +install( + EXPORT HEXLTargets + NAMESPACE HEXL:: + DESTINATION ${HEXL_CONFIG_INSTALL_DIR} +) + +# Export version +write_basic_package_version_file( + ${HEXL_CONFIG_VERSION_FILENAME} + VERSION ${HEXL_VERSION} + COMPATIBILITY ExactVersion) + +include(CMakePackageConfigHelpers) + configure_package_config_file( + ${HEXL_CONFIG_IN_FILENAME} ${HEXL_CONFIG_FILENAME} + INSTALL_DESTINATION ${HEXL_CONFIG_INSTALL_DIR} + ) + +install( + TARGETS hexl + EXPORT HEXLTargets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + ) + +install(FILES ${HEXL_CONFIG_FILENAME} + ${HEXL_CONFIG_VERSION_FILENAME} + DESTINATION ${HEXL_CONFIG_INSTALL_DIR}) + +export(EXPORT HEXLTargets + FILE ${HEXL_TARGET_FILENAME}) + +# Pkgconfig +get_target_property(HEXL_TARGET_NAME hexl OUTPUT_NAME) + +configure_file(${HEXL_ROOT_DIR}/pkgconfig/hexl.pc.in + ${HEXL_ROOT_DIR}/pkgconfig/hexl.pc @ONLY) + +if(EXISTS ${HEXL_ROOT_DIR}/pkgconfig/hexl.pc) + install( + FILES ${HEXL_ROOT_DIR}/pkgconfig/hexl.pc + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) +endif() diff --git a/hexl_omp/eltwise/eltwise-add-mod-avx512.cpp b/hexl_omp/eltwise/eltwise-add-mod-avx512.cpp new file mode 100644 index 00000000..7c103ca3 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-add-mod-avx512.cpp @@ -0,0 +1,109 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-add-mod-avx512.hpp" + +#include +#include + +#include "eltwise/eltwise-add-mod-internal.hpp" +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/util/check.hpp" +#include "util/avx512-util.hpp" + +#ifdef HEXL_HAS_AVX512DQ + +namespace intel { +namespace hexl { + +void EltwiseAddModAVX512(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-add value in operand1 exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(operand2, n, modulus, + "pre-add value in operand2 exceeds bound " << modulus); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseAddModNative(result, operand1, operand2, n_mod_8, modulus); + operand1 += n_mod_8; + operand2 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + const __m512i* vp_operand1 = reinterpret_cast(operand1); + const __m512i* vp_operand2 = reinterpret_cast(operand2); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i 
v_operand1 = _mm512_loadu_si512(vp_operand1); + __m512i v_operand2 = _mm512_loadu_si512(vp_operand2); + + __m512i v_result = + _mm512_hexl_small_add_mod_epi64(v_operand1, v_operand2, v_modulus); + + _mm512_storeu_si512(vp_result, v_result); + + ++vp_result; + ++vp_operand1; + ++vp_operand2; + } + + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +void EltwiseAddModAVX512(uint64_t* result, const uint64_t* operand1, + const uint64_t operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-add value in operand1 exceeds bound " << modulus); + HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseAddModNative(result, operand1, operand2, n_mod_8, modulus); + operand1 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + const __m512i* vp_operand1 = reinterpret_cast(operand1); + const __m512i v_operand2 = _mm512_set1_epi64(static_cast(operand2)); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_operand1 = _mm512_loadu_si512(vp_operand1); + + __m512i v_result = + _mm512_hexl_small_add_mod_epi64(v_operand1, v_operand2, v_modulus); + + _mm512_storeu_si512(vp_result, v_result); + + ++vp_result; + ++vp_operand1; + } + + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +} // namespace hexl +} // namespace intel + +#endif diff --git a/hexl_omp/eltwise/eltwise-add-mod-avx512.hpp b/hexl_omp/eltwise/eltwise-add-mod-avx512.hpp new file mode 100644 index 00000000..befb9a0e --- /dev/null +++ b/hexl_omp/eltwise/eltwise-add-mod-avx512.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +void EltwiseAddModAVX512(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +void EltwiseAddModAVX512(uint64_t* result, const uint64_t* operand1, + const uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-add-mod-internal.hpp b/hexl_omp/eltwise/eltwise-add-mod-internal.hpp new file mode 100644 index 00000000..74891811 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-add-mod-internal.hpp @@ -0,0 +1,33 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace intel { +namespace hexl { + +/// @brief Adds two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add +/// @param[in] operand2 Vector of elements to add +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. 
+void EltwiseAddModNative(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +/// @brief Adds a vector and scalar elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add +/// @param[in] operand2 Scalar add +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddModNative(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-add-mod.cpp b/hexl_omp/eltwise/eltwise-add-mod.cpp new file mode 100644 index 00000000..ea3093a7 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-add-mod.cpp @@ -0,0 +1,113 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/eltwise/eltwise-add-mod.hpp" + +#include + +#include + +#include "eltwise/eltwise-add-mod-avx512.hpp" +#include "eltwise/eltwise-add-mod-internal.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" +namespace intel { +namespace hexl { + +void EltwiseAddModNative(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-add value in operand1 exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(operand2, n, modulus, + "pre-add value in operand2 exceeds bound " << modulus); + +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + uint64_t sum = operand1[i] + operand2[i]; + if (sum >= modulus) { + result[i] = sum - modulus; + } else { + result[i] = sum; + } + } +} + +void EltwiseAddModNative(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-add value in operand1 exceeds bound " << modulus); + HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + + uint64_t diff = modulus - operand2; + +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + if (operand1[i] >= diff) { + result[i] = operand1[i] - diff; + } else { + result[i] = operand1[i] + operand2; + } + } +} + +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-add value in operand1 
exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(operand2, n, modulus, + "pre-add value in operand2 exceeds bound " << modulus); + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseAddModAVX512(result, operand1, operand2, n, modulus); + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseAddModNative"); + EltwiseAddModNative(result, operand1, operand2, n, modulus); +} + +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + const uint64_t operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-add value in operand1 exceeds bound " << modulus); + HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseAddModAVX512(result, operand1, operand2, n, modulus); + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseAddModNative"); + EltwiseAddModNative(result, operand1, operand2, n, modulus); +} + +} // namespace hexl +} // namespace intel + diff --git a/hexl_omp/eltwise/eltwise-cmp-add-avx512.cpp b/hexl_omp/eltwise/eltwise-cmp-add-avx512.cpp new file mode 100644 index 00000000..7ead8ca2 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-cmp-add-avx512.cpp @@ -0,0 +1,49 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-cmp-add-avx512.hpp" + +#include +#include + +#include "eltwise/eltwise-cmp-add-internal.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/util.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ +void EltwiseCmpAddAVX512(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseCmpAddNative(result, operand1, n_mod_8, cmp, bound, diff); + operand1 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + __m512i v_bound = _mm512_set1_epi64(static_cast(bound)); + const __m512i* v_op_ptr = reinterpret_cast(operand1); + __m512i* v_result_ptr = reinterpret_cast<__m512i*>(result); + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op = _mm512_loadu_si512(v_op_ptr); + __m512i v_add_diff = _mm512_hexl_cmp_epi64(v_op, v_bound, cmp, diff); + v_op = _mm512_add_epi64(v_op, v_add_diff); + _mm512_storeu_si512(v_result_ptr, v_op); + + ++v_result_ptr; + ++v_op_ptr; + } +} +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-cmp-add-avx512.hpp b/hexl_omp/eltwise/eltwise-cmp-add-avx512.hpp new file mode 100644 index 00000000..4142325a --- /dev/null +++ b/hexl_omp/eltwise/eltwise-cmp-add-avx512.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional addition. 
+/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] cmp Comparison operation +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to conditionally add +/// @details Computes result[i] = cmp(operand1[i], bound) ? operand1[i] + +/// diff : operand1[i] for all \f$i=0, ..., n-1\f$. +void EltwiseCmpAddAVX512(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-cmp-add-internal.hpp b/hexl_omp/eltwise/eltwise-cmp-add-internal.hpp new file mode 100644 index 00000000..2a95eb5c --- /dev/null +++ b/hexl_omp/eltwise/eltwise-cmp-add-internal.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional addition. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] cmp Comparison operation +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to conditionally add +/// @details Computes result[i] = cmp(operand1[i], bound) ? operand1[i] + +/// diff : operand1[i] for all \f$i=0, ..., n-1\f$. +void EltwiseCmpAddNative(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-cmp-add.cpp b/hexl_omp/eltwise/eltwise-cmp-add.cpp new file mode 100644 index 00000000..c01da888 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-cmp-add.cpp @@ -0,0 +1,128 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/eltwise/eltwise-cmp-add.hpp" + +#include + +#include + +#include "eltwise/eltwise-cmp-add-avx512.hpp" +#include "eltwise/eltwise-cmp-add-internal.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" +namespace intel { +namespace hexl { + +void EltwiseCmpAdd(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseCmpAddAVX512(result, operand1, n, cmp, bound, diff); + return; + } +#endif + EltwiseCmpAddNative(result, operand1, n, cmp, bound, diff); +} + +void EltwiseCmpAddNative(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + + switch (cmp) { + case CMPINT::EQ: { + #pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + if (operand1[i] == bound) { + result[i] = operand1[i] + diff; + } else { + result[i] = operand1[i]; + } + } + break; + } + + case CMPINT::LT: + #pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + if (operand1[i] < bound) { + result[i] = operand1[i] + diff; 
+ } else { + result[i] = operand1[i]; + } + } + break; + + case CMPINT::LE: + #pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + if (operand1[i] <= bound) { + result[i] = operand1[i] + diff; + } else { + result[i] = operand1[i]; + } + } + break; + + case CMPINT::FALSE: + #pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + result[i] = operand1[i]; + } + break; + + case CMPINT::NE: + #pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + if (operand1[i] != bound) { + result[i] = operand1[i] + diff; + } else { + result[i] = operand1[i]; + } + } + break; + + case CMPINT::NLT: + #pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + if (operand1[i] >= bound) { + result[i] = operand1[i] + diff; + } else { + result[i] = operand1[i]; + } + } + break; + + case CMPINT::NLE: + #pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + if (operand1[i] > bound) { + result[i] = operand1[i] + diff; + } else { + result[i] = operand1[i]; + } + } + break; + + case CMPINT::TRUE: + #pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + result[i] = operand1[i] + diff; + } + break; + } + +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-cmp-sub-mod-avx512.cpp b/hexl_omp/eltwise/eltwise-cmp-sub-mod-avx512.cpp new file mode 100644 index 00000000..4cda51d3 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-cmp-sub-mod-avx512.cpp @@ -0,0 +1,39 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-cmp-sub-mod-avx512.hpp" + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional modular subtraction. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] modulus Modulus to reduce by +/// @param[in] cmp Comparison function +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to subtract by +/// @details Computes \p result[i] = (\p cmp(\p operand1, \p bound)) ? 
(\p +/// operand1 - \p diff) mod \p modulus : \p operand1 for all i=0, ..., n-1 + +#ifdef HEXL_HAS_AVX512DQ +template void EltwiseCmpSubModAVX512<64>(uint64_t* result, + const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, + uint64_t bound, uint64_t diff); +#endif + +#ifdef HEXL_HAS_AVX512IFMA +template void EltwiseCmpSubModAVX512<52>(uint64_t* result, + const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, + uint64_t bound, uint64_t diff); +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-cmp-sub-mod-avx512.hpp b/hexl_omp/eltwise/eltwise-cmp-sub-mod-avx512.hpp new file mode 100644 index 00000000..ff5a9421 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-cmp-sub-mod-avx512.hpp @@ -0,0 +1,87 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "eltwise/eltwise-cmp-sub-mod-internal.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ +template +void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1, + uint64_t n, uint64_t modulus, CMPINT cmp, + uint64_t bound, uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0") + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseCmpSubModNative(result, operand1, n_mod_8, modulus, cmp, bound, + diff); + operand1 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus); + + const __m512i* v_op_ptr = reinterpret_cast(operand1); + __m512i* v_result_ptr = reinterpret_cast<__m512i*>(result); + __m512i v_bound = _mm512_set1_epi64(static_cast(bound)); + __m512i v_diff = _mm512_set1_epi64(static_cast(diff)); + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + + uint64_t mu = MultiplyFactor(1, BitShift, modulus).BarrettFactor(); + __m512i v_mu = _mm512_set1_epi64(static_cast(mu)); + + // Multi-word Barrett reduction precomputation + constexpr int64_t beta = -2; + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(modulus)); + + uint64_t alpha = BitShift - 2; + uint64_t mu_64 = + MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - BitShift), BitShift, + modulus) + .BarrettFactor(); + + if (BitShift == 64) { + // Single-worded Barrett reduction. 
+ mu_64 = MultiplyFactor(1, 64, modulus).BarrettFactor(); + } + + __m512i v_mu_64 = _mm512_set1_epi64(static_cast(mu_64)); + + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op = _mm512_loadu_si512(v_op_ptr); + __mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp)); + + v_op = _mm512_hexl_barrett_reduce64( + v_op, v_modulus, v_mu_64, v_mu, prod_right_shift, v_neg_mod); + + __m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus); + v_to_add = _mm512_sub_epi64(v_to_add, v_diff); + v_to_add = _mm512_mask_set1_epi64(v_to_add, op_le_cmp, 0); + + v_op = _mm512_add_epi64(v_op, v_to_add); + _mm512_storeu_si512(v_result_ptr, v_op); + ++v_op_ptr; + ++v_result_ptr; + } +} +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-cmp-sub-mod-internal.hpp b/hexl_omp/eltwise/eltwise-cmp-sub-mod-internal.hpp new file mode 100644 index 00000000..f988058e --- /dev/null +++ b/hexl_omp/eltwise/eltwise-cmp-sub-mod-internal.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional modular subtraction. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] modulus Modulus to reduce by +/// @param[in] cmp Comparison function +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to subtract by +/// @details Computes \p result[i] = (\p cmp(\p operand1, \p bound)) ? (\p +/// operand1 - \p diff) mod \p modulus : \p operand1 for all i=0, ..., n-1 +void EltwiseCmpSubModNative(uint64_t* result, const uint64_t* operand1, + uint64_t n, uint64_t modulus, CMPINT cmp, + uint64_t bound, uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-cmp-sub-mod.cpp b/hexl_omp/eltwise/eltwise-cmp-sub-mod.cpp new file mode 100644 index 00000000..a33169b2 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-cmp-sub-mod.cpp @@ -0,0 +1,76 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/eltwise/eltwise-cmp-sub-mod.hpp" + +#include + +#include + +#include "eltwise/eltwise-cmp-sub-mod-avx512.hpp" +#include "eltwise/eltwise-cmp-sub-mod-internal.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/util.hpp" +#include "util/cpu-features.hpp" +#include "util/util-internal.hpp" + +namespace intel { +namespace hexl { + +void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, uint64_t bound, + uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + +#ifdef HEXL_HAS_AVX512IFMA + if (has_avx512ifma) { + if (modulus < (1ULL << 52)) { + EltwiseCmpSubModAVX512<52>(result, operand1, n, modulus, cmp, bound, + diff); + return; + } + } +#endif + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseCmpSubModAVX512<64>(result, operand1, n, modulus, cmp, bound, diff); + return; + } +#endif + EltwiseCmpSubModNative(result, operand1, n, modulus, cmp, bound, diff); + return; +} + +void 
EltwiseCmpSubModNative(uint64_t* result, const uint64_t* operand1, + uint64_t n, uint64_t modulus, CMPINT cmp, + uint64_t bound, uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0") + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus); + + +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + uint64_t op = operand1[i]; + bool op_cmp = Compare(cmp, op, bound); + op %= modulus; + if (op_cmp) { + op = SubUIntMod(op, diff, modulus); + } + result[i] = op; + } + +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-fma-mod-avx512.cpp b/hexl_omp/eltwise/eltwise-fma-mod-avx512.cpp new file mode 100644 index 00000000..fa6a5453 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-fma-mod-avx512.cpp @@ -0,0 +1,156 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-fma-mod-avx512.hpp" + +#include + +#include "hexl/eltwise/eltwise-fma-mod.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512IFMA +template void EltwiseFMAModAVX512<52, 1>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +template void EltwiseFMAModAVX512<52, 2>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +template void EltwiseFMAModAVX512<52, 4>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +template void EltwiseFMAModAVX512<52, 8>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +#endif + +#ifdef HEXL_HAS_AVX512DQ +template void EltwiseFMAModAVX512<64, 1>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +template void EltwiseFMAModAVX512<64, 2>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +template void EltwiseFMAModAVX512<64, 4>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +template void EltwiseFMAModAVX512<64, 8>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); + +#endif + +#ifdef HEXL_HAS_AVX512DQ + +/// uses Shoup's modular multiplication. 
See Algorithm 4 of +/// https://arxiv.org/pdf/2012.01968.pdf +template +void EltwiseFMAModAVX512(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus) { + HEXL_CHECK(modulus < MaximumValue(BitShift), + "Modulus " << modulus << " exceeds bit shift bound " + << MaximumValue(BitShift)); + HEXL_CHECK(modulus != 0, "Require modulus != 0"); + + HEXL_CHECK(arg1, "arg1 == nullptr"); + HEXL_CHECK(result, "result == nullptr"); + + HEXL_CHECK_BOUNDS(arg1, n, InputModFactor * modulus, + "arg1 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK_BOUNDS(&arg2, 1, InputModFactor * modulus, + "arg2 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid bitshift " << BitShift << "; need 52 or 64"); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseFMAModNative(result, arg1, arg2, arg3, n_mod_8, + modulus); + arg1 += n_mod_8; + if (arg3 != nullptr) { + arg3 += n_mod_8; + } + result += n_mod_8; + n -= n_mod_8; + } + + uint64_t twice_modulus = 2 * modulus; + uint64_t four_times_modulus = 4 * modulus; + arg2 = ReduceMod(arg2, modulus, &twice_modulus, + &four_times_modulus); + uint64_t arg2_barr = MultiplyFactor(arg2, BitShift, modulus).BarrettFactor(); + + __m512i varg2_barr = _mm512_set1_epi64(static_cast(arg2_barr)); + + __m512i vmodulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i vneg_modulus = _mm512_set1_epi64(-static_cast(modulus)); + __m512i v2_modulus = _mm512_set1_epi64(static_cast(2 * modulus)); + __m512i v4_modulus = _mm512_set1_epi64(static_cast(4 * modulus)); + const __m512i* vp_arg1 = reinterpret_cast(arg1); + __m512i varg2 = _mm512_set1_epi64(static_cast(arg2)); + varg2 = _mm512_hexl_small_mod_epu64(varg2, vmodulus, + &v2_modulus, &v4_modulus); + + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + + if (arg3) { + const __m512i* vp_arg3 = reinterpret_cast(arg3); + HEXL_LOOP_UNROLL_8 + for (size_t i = n / 8; i > 0; --i) { + __m512i varg1 = _mm512_loadu_si512(vp_arg1); + __m512i varg3 = _mm512_loadu_si512(vp_arg3); + + varg1 = _mm512_hexl_small_mod_epu64( + varg1, vmodulus, &v2_modulus, &v4_modulus); + varg3 = _mm512_hexl_small_mod_epu64( + varg3, vmodulus, &v2_modulus, &v4_modulus); + + __m512i va_times_b = _mm512_hexl_mullo_epi(varg1, varg2); + __m512i vq = _mm512_hexl_mulhi_epi(varg1, varg2_barr); + + // Compute vq in [0, 2 * p) where p is the modulus + // a * b - q * p + vq = _mm512_hexl_mullo_add_lo_epi(va_times_b, vq, vneg_modulus); + + // Add arg3, bringing vq to [0, 3 * p) + vq = _mm512_add_epi64(vq, varg3); + // Reduce to [0, p) + vq = _mm512_hexl_small_mod_epu64<4>(vq, vmodulus, &v2_modulus); + + _mm512_storeu_si512(vp_result, vq); + + ++vp_arg1; + ++vp_result; + ++vp_arg3; + } + } else { // arg3 == nullptr + HEXL_LOOP_UNROLL_8 + for (size_t i = n / 8; i > 0; --i) { + __m512i varg1 = _mm512_loadu_si512(vp_arg1); + varg1 = _mm512_hexl_small_mod_epu64( + varg1, vmodulus, &v2_modulus, &v4_modulus); + + __m512i va_times_b = _mm512_hexl_mullo_epi(varg1, varg2); + __m512i vq = _mm512_hexl_mulhi_epi(varg1, varg2_barr); + + // Compute vq in [0, 2 * p) where p is the modulus + // a * b - q * p + vq = _mm512_hexl_mullo_add_lo_epi(va_times_b, vq, vneg_modulus); + // Conditional Barrett subtraction + vq = _mm512_hexl_small_mod_epu64(vq, vmodulus); + _mm512_storeu_si512(vp_result, vq); + + ++vp_arg1; + ++vp_result; + } + } +} + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-fma-mod-avx512.hpp 
b/hexl_omp/eltwise/eltwise-fma-mod-avx512.hpp new file mode 100644 index 00000000..f0750165 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-fma-mod-avx512.hpp @@ -0,0 +1,22 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "eltwise/eltwise-fma-mod-internal.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +template +void EltwiseFMAModAVX512(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus); + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-fma-mod-internal.hpp b/hexl_omp/eltwise/eltwise-fma-mod-internal.hpp new file mode 100644 index 00000000..00d0b4b3 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-fma-mod-internal.hpp @@ -0,0 +1,50 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include + +#include "hexl/number-theory/number-theory.hpp" + +namespace intel { +namespace hexl { + +template +void EltwiseFMAModNative(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus) { + uint64_t twice_modulus = 2 * modulus; + uint64_t four_times_modulus = 4 * modulus; + arg2 = ReduceMod(arg2, modulus, &twice_modulus, + &four_times_modulus); + + MultiplyFactor mf(arg2, 64, modulus); + + + if (arg3) { +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + uint64_t arg1_val = ReduceMod( + arg1[i], modulus, &twice_modulus, &four_times_modulus); + uint64_t arg3_val = ReduceMod( + arg3[i], modulus, &twice_modulus, &four_times_modulus); + + uint64_t result_val = + MultiplyMod(arg1_val, arg2, mf.BarrettFactor(), modulus); + result[i] = AddUIntMod(result_val, arg3_val, modulus); + } + } else { // arg3 == nullptr +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + uint64_t arg1_val = ReduceMod( + arg1[i], modulus, &twice_modulus, &four_times_modulus); + result[i] = MultiplyMod(arg1_val, arg2, mf.BarrettFactor(), modulus); + } + } + +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-fma-mod.cpp b/hexl_omp/eltwise/eltwise-fma-mod.cpp new file mode 100644 index 00000000..03478fc0 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-fma-mod.cpp @@ -0,0 +1,104 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/eltwise/eltwise-fma-mod.hpp" + +#include + +#include "eltwise/eltwise-fma-mod-avx512.hpp" +#include "eltwise/eltwise-fma-mod-internal.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void EltwiseFMAMod(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(arg1 != nullptr, "Require arg1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0") + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 61), "Require modulus < (1ULL << 61)"); + HEXL_CHECK( + input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4 || + input_mod_factor == 8, + "input_mod_factor must be 1, 2, 4, or 8. 
Got " << input_mod_factor); + HEXL_CHECK( + arg2 < input_mod_factor * modulus, + "arg2 " << arg2 << " exceeds bound " << (input_mod_factor * modulus)); + + HEXL_CHECK_BOUNDS(arg1, n, input_mod_factor * modulus, + "arg1 value " << (*std::max_element(arg1, arg1 + n)) + << " in EltwiseFMAMod exceeds bound " + << (input_mod_factor * modulus)); + HEXL_CHECK(arg3 == nullptr || (*std::max_element(arg3, arg3 + n) < + (input_mod_factor * modulus)), + "arg3 value in EltwiseFMAMod exceeds bound " + << (input_mod_factor * modulus)); + +#ifdef HEXL_HAS_AVX512IFMA + if (has_avx512ifma && input_mod_factor * modulus < (1ULL << 51)) { + HEXL_VLOG(3, "Calling 52-bit EltwiseFMAModAVX512"); + + switch (input_mod_factor) { + case 1: + EltwiseFMAModAVX512<52, 1>(result, arg1, arg2, arg3, n, modulus); + break; + case 2: + EltwiseFMAModAVX512<52, 2>(result, arg1, arg2, arg3, n, modulus); + break; + case 4: + EltwiseFMAModAVX512<52, 4>(result, arg1, arg2, arg3, n, modulus); + break; + case 8: + EltwiseFMAModAVX512<52, 8>(result, arg1, arg2, arg3, n, modulus); + break; + } + return; + } +#endif + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + HEXL_VLOG(3, "Calling 64-bit EltwiseFMAModAVX512"); + + switch (input_mod_factor) { + case 1: + EltwiseFMAModAVX512<64, 1>(result, arg1, arg2, arg3, n, modulus); + break; + case 2: + EltwiseFMAModAVX512<64, 2>(result, arg1, arg2, arg3, n, modulus); + break; + case 4: + EltwiseFMAModAVX512<64, 4>(result, arg1, arg2, arg3, n, modulus); + break; + case 8: + EltwiseFMAModAVX512<64, 8>(result, arg1, arg2, arg3, n, modulus); + break; + } + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseFMAModNative"); + switch (input_mod_factor) { + case 1: + EltwiseFMAModNative<1>(result, arg1, arg2, arg3, n, modulus); + break; + case 2: + EltwiseFMAModNative<2>(result, arg1, arg2, arg3, n, modulus); + break; + case 4: + EltwiseFMAModNative<4>(result, arg1, arg2, arg3, n, modulus); + break; + case 8: + EltwiseFMAModNative<8>(result, arg1, arg2, arg3, n, modulus); + break; + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-mult-mod-avx512.hpp b/hexl_omp/eltwise/eltwise-mult-mod-avx512.hpp new file mode 100644 index 00000000..e00aa702 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-mult-mod-avx512.hpp @@ -0,0 +1,81 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "eltwise/eltwise-mult-mod-internal.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/number-theory/number-theory.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. 
+/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +/// @details Barrett's algorithm for vector-vector modular multiplication +/// (Algorithm 1 from https://hal.archives-ouvertes.fr/hal-01215845/document) +/// using AVX512IFMA +template +void EltwiseMultModAVX512IFMAInt(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. +/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +/// @details Barrett's algorithm for vector-vector modular multiplication +/// (Algorithm 1 from https://hal.archives-ouvertes.fr/hal-01215845/document) +/// using AVX512DQ +template +void EltwiseMultModAVX512DQInt(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. 
+/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +/// @details Function 18 on page 19 of https://arxiv.org/pdf/1407.3383.pdf +/// See also Algorithm 2/3 of +/// https://hal.archives-ouvertes.fr/hal-02552673/document +/// Uses floating-point arithmetic +template +void EltwiseMultModAVX512Float(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-mult-mod-avx512dq.cpp b/hexl_omp/eltwise/eltwise-mult-mod-avx512dq.cpp new file mode 100644 index 00000000..cbf0ec0f --- /dev/null +++ b/hexl_omp/eltwise/eltwise-mult-mod-avx512dq.cpp @@ -0,0 +1,838 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include + +#include "eltwise/eltwise-mult-mod-avx512.hpp" +#include "eltwise/eltwise-mult-mod-internal.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" +#include "hexl/util/defines.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +template void EltwiseMultModAVX512Float<1>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); +template void EltwiseMultModAVX512Float<2>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); +template void EltwiseMultModAVX512Float<4>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +template void EltwiseMultModAVX512DQInt<1>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); +template void EltwiseMultModAVX512DQInt<2>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); +template void EltwiseMultModAVX512DQInt<4>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +#endif + +#ifdef HEXL_HAS_AVX512DQ + +template +void EltwiseMultModAVX512DQIntLoopUnroll(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, + __m512i v_twice_mod) { + constexpr size_t manual_unroll_factor = 16; + constexpr size_t avx512_64bit_count = 8; + constexpr size_t loop_count = + CoeffCount / (manual_unroll_factor * avx512_64bit_count); + + static_assert(loop_count > 0, "loop_count too small for unrolling"); + static_assert(CoeffCount % (manual_unroll_factor * avx512_64bit_count) == 0, + "CoeffCount must be a factor of manual_unroll_factor * " + "avx512_64bit_count"); + + HEXL_UNUSED(v_twice_mod); + HEXL_LOOP_UNROLL_4 + for (size_t i = loop_count; i > 0; --i) { + __m512i x1 = _mm512_loadu_si512(vp_operand1++); + __m512i y1 = _mm512_loadu_si512(vp_operand2++); + __m512i x2 = _mm512_loadu_si512(vp_operand1++); + __m512i y2 = _mm512_loadu_si512(vp_operand2++); + __m512i x3 = _mm512_loadu_si512(vp_operand1++); + __m512i y3 = _mm512_loadu_si512(vp_operand2++); + __m512i x4 = _mm512_loadu_si512(vp_operand1++); + __m512i y4 = _mm512_loadu_si512(vp_operand2++); + __m512i x5 = _mm512_loadu_si512(vp_operand1++); + __m512i y5 = _mm512_loadu_si512(vp_operand2++); + __m512i x6 = _mm512_loadu_si512(vp_operand1++); + __m512i y6 = 
_mm512_loadu_si512(vp_operand2++); + __m512i x7 = _mm512_loadu_si512(vp_operand1++); + __m512i y7 = _mm512_loadu_si512(vp_operand2++); + __m512i x8 = _mm512_loadu_si512(vp_operand1++); + __m512i y8 = _mm512_loadu_si512(vp_operand2++); + __m512i x9 = _mm512_loadu_si512(vp_operand1++); + __m512i y9 = _mm512_loadu_si512(vp_operand2++); + __m512i x10 = _mm512_loadu_si512(vp_operand1++); + __m512i y10 = _mm512_loadu_si512(vp_operand2++); + __m512i x11 = _mm512_loadu_si512(vp_operand1++); + __m512i y11 = _mm512_loadu_si512(vp_operand2++); + __m512i x12 = _mm512_loadu_si512(vp_operand1++); + __m512i y12 = _mm512_loadu_si512(vp_operand2++); + __m512i x13 = _mm512_loadu_si512(vp_operand1++); + __m512i y13 = _mm512_loadu_si512(vp_operand2++); + __m512i x14 = _mm512_loadu_si512(vp_operand1++); + __m512i y14 = _mm512_loadu_si512(vp_operand2++); + __m512i x15 = _mm512_loadu_si512(vp_operand1++); + __m512i y15 = _mm512_loadu_si512(vp_operand2++); + __m512i x16 = _mm512_loadu_si512(vp_operand1++); + __m512i y16 = _mm512_loadu_si512(vp_operand2++); + + x1 = _mm512_hexl_small_mod_epu64(x1, v_modulus, + &v_twice_mod); + x2 = _mm512_hexl_small_mod_epu64(x2, v_modulus, + &v_twice_mod); + x3 = _mm512_hexl_small_mod_epu64(x3, v_modulus, + &v_twice_mod); + x4 = _mm512_hexl_small_mod_epu64(x4, v_modulus, + &v_twice_mod); + x5 = _mm512_hexl_small_mod_epu64(x5, v_modulus, + &v_twice_mod); + x6 = _mm512_hexl_small_mod_epu64(x6, v_modulus, + &v_twice_mod); + x7 = _mm512_hexl_small_mod_epu64(x7, v_modulus, + &v_twice_mod); + x8 = _mm512_hexl_small_mod_epu64(x8, v_modulus, + &v_twice_mod); + x9 = _mm512_hexl_small_mod_epu64(x9, v_modulus, + &v_twice_mod); + x10 = _mm512_hexl_small_mod_epu64(x10, v_modulus, + &v_twice_mod); + x11 = _mm512_hexl_small_mod_epu64(x11, v_modulus, + &v_twice_mod); + x12 = _mm512_hexl_small_mod_epu64(x12, v_modulus, + &v_twice_mod); + x13 = _mm512_hexl_small_mod_epu64(x13, v_modulus, + &v_twice_mod); + x14 = _mm512_hexl_small_mod_epu64(x14, v_modulus, + &v_twice_mod); + x15 = _mm512_hexl_small_mod_epu64(x15, v_modulus, + &v_twice_mod); + x16 = _mm512_hexl_small_mod_epu64(x16, v_modulus, + &v_twice_mod); + + y1 = _mm512_hexl_small_mod_epu64(y1, v_modulus, + &v_twice_mod); + y2 = _mm512_hexl_small_mod_epu64(y2, v_modulus, + &v_twice_mod); + y3 = _mm512_hexl_small_mod_epu64(y3, v_modulus, + &v_twice_mod); + y4 = _mm512_hexl_small_mod_epu64(y4, v_modulus, + &v_twice_mod); + y5 = _mm512_hexl_small_mod_epu64(y5, v_modulus, + &v_twice_mod); + y6 = _mm512_hexl_small_mod_epu64(y6, v_modulus, + &v_twice_mod); + y7 = _mm512_hexl_small_mod_epu64(y7, v_modulus, + &v_twice_mod); + y8 = _mm512_hexl_small_mod_epu64(y8, v_modulus, + &v_twice_mod); + y9 = _mm512_hexl_small_mod_epu64(y9, v_modulus, + &v_twice_mod); + y10 = _mm512_hexl_small_mod_epu64(y10, v_modulus, + &v_twice_mod); + y11 = _mm512_hexl_small_mod_epu64(y11, v_modulus, + &v_twice_mod); + y12 = _mm512_hexl_small_mod_epu64(y12, v_modulus, + &v_twice_mod); + y13 = _mm512_hexl_small_mod_epu64(y13, v_modulus, + &v_twice_mod); + y14 = _mm512_hexl_small_mod_epu64(y14, v_modulus, + &v_twice_mod); + y15 = _mm512_hexl_small_mod_epu64(y15, v_modulus, + &v_twice_mod); + y16 = _mm512_hexl_small_mod_epu64(y16, v_modulus, + &v_twice_mod); + + __m512i zhi1 = _mm512_hexl_mulhi_epi<64>(x1, y1); + __m512i zhi2 = _mm512_hexl_mulhi_epi<64>(x2, y2); + __m512i zhi3 = _mm512_hexl_mulhi_epi<64>(x3, y3); + __m512i zhi4 = _mm512_hexl_mulhi_epi<64>(x4, y4); + __m512i zhi5 = _mm512_hexl_mulhi_epi<64>(x5, y5); + __m512i zhi6 = _mm512_hexl_mulhi_epi<64>(x6, y6); + __m512i zhi7 
= _mm512_hexl_mulhi_epi<64>(x7, y7); + __m512i zhi8 = _mm512_hexl_mulhi_epi<64>(x8, y8); + __m512i zhi9 = _mm512_hexl_mulhi_epi<64>(x9, y9); + __m512i zhi10 = _mm512_hexl_mulhi_epi<64>(x10, y10); + __m512i zhi11 = _mm512_hexl_mulhi_epi<64>(x11, y11); + __m512i zhi12 = _mm512_hexl_mulhi_epi<64>(x12, y12); + __m512i zhi13 = _mm512_hexl_mulhi_epi<64>(x13, y13); + __m512i zhi14 = _mm512_hexl_mulhi_epi<64>(x14, y14); + __m512i zhi15 = _mm512_hexl_mulhi_epi<64>(x15, y15); + __m512i zhi16 = _mm512_hexl_mulhi_epi<64>(x16, y16); + + __m512i zlo1 = _mm512_hexl_mullo_epi<64>(x1, y1); + __m512i zlo2 = _mm512_hexl_mullo_epi<64>(x2, y2); + __m512i zlo3 = _mm512_hexl_mullo_epi<64>(x3, y3); + __m512i zlo4 = _mm512_hexl_mullo_epi<64>(x4, y4); + __m512i zlo5 = _mm512_hexl_mullo_epi<64>(x5, y5); + __m512i zlo6 = _mm512_hexl_mullo_epi<64>(x6, y6); + __m512i zlo7 = _mm512_hexl_mullo_epi<64>(x7, y7); + __m512i zlo8 = _mm512_hexl_mullo_epi<64>(x8, y8); + __m512i zlo9 = _mm512_hexl_mullo_epi<64>(x9, y9); + __m512i zlo10 = _mm512_hexl_mullo_epi<64>(x10, y10); + __m512i zlo11 = _mm512_hexl_mullo_epi<64>(x11, y11); + __m512i zlo12 = _mm512_hexl_mullo_epi<64>(x12, y12); + __m512i zlo13 = _mm512_hexl_mullo_epi<64>(x13, y13); + __m512i zlo14 = _mm512_hexl_mullo_epi<64>(x14, y14); + __m512i zlo15 = _mm512_hexl_mullo_epi<64>(x15, y15); + __m512i zlo16 = _mm512_hexl_mullo_epi<64>(x16, y16); + + __m512i c1 = _mm512_hexl_shrdi_epi64(zlo1, zhi1); + __m512i c2 = _mm512_hexl_shrdi_epi64(zlo2, zhi2); + __m512i c3 = _mm512_hexl_shrdi_epi64(zlo3, zhi3); + __m512i c4 = _mm512_hexl_shrdi_epi64(zlo4, zhi4); + __m512i c5 = _mm512_hexl_shrdi_epi64(zlo5, zhi5); + __m512i c6 = _mm512_hexl_shrdi_epi64(zlo6, zhi6); + __m512i c7 = _mm512_hexl_shrdi_epi64(zlo7, zhi7); + __m512i c8 = _mm512_hexl_shrdi_epi64(zlo8, zhi8); + __m512i c9 = _mm512_hexl_shrdi_epi64(zlo9, zhi9); + __m512i c10 = _mm512_hexl_shrdi_epi64(zlo10, zhi10); + __m512i c11 = _mm512_hexl_shrdi_epi64(zlo11, zhi11); + __m512i c12 = _mm512_hexl_shrdi_epi64(zlo12, zhi12); + __m512i c13 = _mm512_hexl_shrdi_epi64(zlo13, zhi13); + __m512i c14 = _mm512_hexl_shrdi_epi64(zlo14, zhi14); + __m512i c15 = _mm512_hexl_shrdi_epi64(zlo15, zhi15); + __m512i c16 = _mm512_hexl_shrdi_epi64(zlo16, zhi16); + + c1 = _mm512_hexl_mulhi_approx_epi<64>(c1, v_barr_lo); + c2 = _mm512_hexl_mulhi_approx_epi<64>(c2, v_barr_lo); + c3 = _mm512_hexl_mulhi_approx_epi<64>(c3, v_barr_lo); + c4 = _mm512_hexl_mulhi_approx_epi<64>(c4, v_barr_lo); + c5 = _mm512_hexl_mulhi_approx_epi<64>(c5, v_barr_lo); + c6 = _mm512_hexl_mulhi_approx_epi<64>(c6, v_barr_lo); + c7 = _mm512_hexl_mulhi_approx_epi<64>(c7, v_barr_lo); + c8 = _mm512_hexl_mulhi_approx_epi<64>(c8, v_barr_lo); + c9 = _mm512_hexl_mulhi_approx_epi<64>(c9, v_barr_lo); + c10 = _mm512_hexl_mulhi_approx_epi<64>(c10, v_barr_lo); + c11 = _mm512_hexl_mulhi_approx_epi<64>(c11, v_barr_lo); + c12 = _mm512_hexl_mulhi_approx_epi<64>(c12, v_barr_lo); + c13 = _mm512_hexl_mulhi_approx_epi<64>(c13, v_barr_lo); + c14 = _mm512_hexl_mulhi_approx_epi<64>(c14, v_barr_lo); + c15 = _mm512_hexl_mulhi_approx_epi<64>(c15, v_barr_lo); + c16 = _mm512_hexl_mulhi_approx_epi<64>(c16, v_barr_lo); + + __m512i vr1 = _mm512_hexl_mullo_epi<64>(c1, v_modulus); + __m512i vr2 = _mm512_hexl_mullo_epi<64>(c2, v_modulus); + __m512i vr3 = _mm512_hexl_mullo_epi<64>(c3, v_modulus); + __m512i vr4 = _mm512_hexl_mullo_epi<64>(c4, v_modulus); + __m512i vr5 = _mm512_hexl_mullo_epi<64>(c5, v_modulus); + __m512i vr6 = _mm512_hexl_mullo_epi<64>(c6, v_modulus); + __m512i vr7 = _mm512_hexl_mullo_epi<64>(c7, 
v_modulus); + __m512i vr8 = _mm512_hexl_mullo_epi<64>(c8, v_modulus); + __m512i vr9 = _mm512_hexl_mullo_epi<64>(c9, v_modulus); + __m512i vr10 = _mm512_hexl_mullo_epi<64>(c10, v_modulus); + __m512i vr11 = _mm512_hexl_mullo_epi<64>(c11, v_modulus); + __m512i vr12 = _mm512_hexl_mullo_epi<64>(c12, v_modulus); + __m512i vr13 = _mm512_hexl_mullo_epi<64>(c13, v_modulus); + __m512i vr14 = _mm512_hexl_mullo_epi<64>(c14, v_modulus); + __m512i vr15 = _mm512_hexl_mullo_epi<64>(c15, v_modulus); + __m512i vr16 = _mm512_hexl_mullo_epi<64>(c16, v_modulus); + + vr1 = _mm512_sub_epi64(zlo1, vr1); + vr2 = _mm512_sub_epi64(zlo2, vr2); + vr3 = _mm512_sub_epi64(zlo3, vr3); + vr4 = _mm512_sub_epi64(zlo4, vr4); + vr5 = _mm512_sub_epi64(zlo5, vr5); + vr6 = _mm512_sub_epi64(zlo6, vr6); + vr7 = _mm512_sub_epi64(zlo7, vr7); + vr8 = _mm512_sub_epi64(zlo8, vr8); + vr9 = _mm512_sub_epi64(zlo9, vr9); + vr10 = _mm512_sub_epi64(zlo10, vr10); + vr11 = _mm512_sub_epi64(zlo11, vr11); + vr12 = _mm512_sub_epi64(zlo12, vr12); + vr13 = _mm512_sub_epi64(zlo13, vr13); + vr14 = _mm512_sub_epi64(zlo14, vr14); + vr15 = _mm512_sub_epi64(zlo15, vr15); + vr16 = _mm512_sub_epi64(zlo16, vr16); + + vr1 = _mm512_hexl_small_mod_epu64<4>(vr1, v_modulus, &v_twice_mod); + vr2 = _mm512_hexl_small_mod_epu64<4>(vr2, v_modulus, &v_twice_mod); + vr3 = _mm512_hexl_small_mod_epu64<4>(vr3, v_modulus, &v_twice_mod); + vr4 = _mm512_hexl_small_mod_epu64<4>(vr4, v_modulus, &v_twice_mod); + vr5 = _mm512_hexl_small_mod_epu64<4>(vr5, v_modulus, &v_twice_mod); + vr6 = _mm512_hexl_small_mod_epu64<4>(vr6, v_modulus, &v_twice_mod); + vr7 = _mm512_hexl_small_mod_epu64<4>(vr7, v_modulus, &v_twice_mod); + vr8 = _mm512_hexl_small_mod_epu64<4>(vr8, v_modulus, &v_twice_mod); + vr9 = _mm512_hexl_small_mod_epu64<4>(vr9, v_modulus, &v_twice_mod); + vr10 = _mm512_hexl_small_mod_epu64<4>(vr10, v_modulus, &v_twice_mod); + vr11 = _mm512_hexl_small_mod_epu64<4>(vr11, v_modulus, &v_twice_mod); + vr12 = _mm512_hexl_small_mod_epu64<4>(vr12, v_modulus, &v_twice_mod); + vr13 = _mm512_hexl_small_mod_epu64<4>(vr13, v_modulus, &v_twice_mod); + vr14 = _mm512_hexl_small_mod_epu64<4>(vr14, v_modulus, &v_twice_mod); + vr15 = _mm512_hexl_small_mod_epu64<4>(vr15, v_modulus, &v_twice_mod); + vr16 = _mm512_hexl_small_mod_epu64<4>(vr16, v_modulus, &v_twice_mod); + + _mm512_storeu_si512(vp_result++, vr1); + _mm512_storeu_si512(vp_result++, vr2); + _mm512_storeu_si512(vp_result++, vr3); + _mm512_storeu_si512(vp_result++, vr4); + _mm512_storeu_si512(vp_result++, vr5); + _mm512_storeu_si512(vp_result++, vr6); + _mm512_storeu_si512(vp_result++, vr7); + _mm512_storeu_si512(vp_result++, vr8); + _mm512_storeu_si512(vp_result++, vr9); + _mm512_storeu_si512(vp_result++, vr10); + _mm512_storeu_si512(vp_result++, vr11); + _mm512_storeu_si512(vp_result++, vr12); + _mm512_storeu_si512(vp_result++, vr13); + _mm512_storeu_si512(vp_result++, vr14); + _mm512_storeu_si512(vp_result++, vr15); + _mm512_storeu_si512(vp_result++, vr16); + } +} + +/// @brief Algorithm 2 from +/// https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModAVX512DQIntLoopDefault(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, + __m512i v_twice_mod, uint64_t n) { + HEXL_UNUSED(v_twice_mod); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op1 = _mm512_loadu_si512(vp_operand1); + __m512i v_op2 = _mm512_loadu_si512(vp_operand2); + + v_op1 = _mm512_hexl_small_mod_epu64(v_op1, v_modulus, + &v_twice_mod); + + v_op2 
= _mm512_hexl_small_mod_epu64(v_op2, v_modulus, + &v_twice_mod); + + // Compute product U + __m512i v_prod_hi = _mm512_hexl_mulhi_epi<64>(v_op1, v_op2); + __m512i v_prod_lo = _mm512_hexl_mullo_epi<64>(v_op1, v_op2); + + __m512i c1 = _mm512_hexl_shrdi_epi64(v_prod_lo, v_prod_hi); + // alpha - beta == 64, so we only need high 64 bits + // Perform approximate computation of high bits, as described on page + // 7 of https://arxiv.org/pdf/2003.04510.pdf + __m512i q_hat = _mm512_hexl_mulhi_approx_epi<64>(c1, v_barr_lo); + __m512i v_result = _mm512_hexl_mullo_epi<64>(q_hat, v_modulus); + // Computes result in [0, 4q) + v_result = _mm512_sub_epi64(v_prod_lo, v_result); + + // Reduce result to [0, q) + v_result = + _mm512_hexl_small_mod_epu64<4>(v_result, v_modulus, &v_twice_mod); + _mm512_storeu_si512(vp_result, v_result); + + ++vp_operand1; + ++vp_operand2; + ++vp_result; + } +} + +/// @brief Algorithm 2 from +/// https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModAVX512DQIntLoopDefault(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, + __m512i v_twice_mod, uint64_t n, + uint64_t prod_right_shift) { + HEXL_UNUSED(v_twice_mod); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op1 = _mm512_loadu_si512(vp_operand1); + __m512i v_op2 = _mm512_loadu_si512(vp_operand2); + + v_op1 = _mm512_hexl_small_mod_epu64(v_op1, v_modulus, + &v_twice_mod); + + v_op2 = _mm512_hexl_small_mod_epu64(v_op2, v_modulus, + &v_twice_mod); + + __m512i v_prod_hi = _mm512_hexl_mulhi_epi<64>(v_op1, v_op2); + __m512i v_prod_lo = _mm512_hexl_mullo_epi<64>(v_op1, v_op2); + + // c1 = floor(U / 2^{n + beta}) + __m512i c1 = _mm512_hexl_shrdi_epi64( + v_prod_lo, v_prod_hi, static_cast(prod_right_shift)); + + // alpha - beta == 64, so we only need high 64 bits + // Perform approximate computation of high bits, as described on page + // 7 of https://arxiv.org/pdf/2003.04510.pdf + __m512i q_hat = _mm512_hexl_mulhi_approx_epi<64>(c1, v_barr_lo); + __m512i v_result = _mm512_hexl_mullo_epi<64>(q_hat, v_modulus); + // Computes result in [0, 4q) + v_result = _mm512_sub_epi64(v_prod_lo, v_result); + + // Reduce result to [0, q) + v_result = + _mm512_hexl_small_mod_epu64<4>(v_result, v_modulus, &v_twice_mod); + _mm512_storeu_si512(vp_result, v_result); + + ++vp_operand1; + ++vp_operand2; + ++vp_result; + } +} + +template +void EltwiseMultModAVX512DQIntLoop(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, + __m512i v_twice_mod, uint64_t n) { + switch (n) { + case 1024: + EltwiseMultModAVX512DQIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_twice_mod); + break; + + case 2048: + EltwiseMultModAVX512DQIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_twice_mod); + break; + + case 4096: + EltwiseMultModAVX512DQIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_twice_mod); + break; + + case 8192: + EltwiseMultModAVX512DQIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_twice_mod); + break; + + case 16384: + EltwiseMultModAVX512DQIntLoopUnroll(vp_result, vp_operand1, + vp_operand2, v_barr_lo, + v_modulus, v_twice_mod); + break; + + case 32768: + EltwiseMultModAVX512DQIntLoopUnroll(vp_result, vp_operand1, + vp_operand2, v_barr_lo, + v_modulus, v_twice_mod); + break; + + default: + EltwiseMultModAVX512DQIntLoopDefault( + vp_result, 
vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_twice_mod, n); + } +} + +#define ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(ProdRightShift, \ + InputModFactor) \ + case (ProdRightShift): { \ + EltwiseMultModAVX512DQIntLoop<(ProdRightShift), (InputModFactor)>( \ + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, \ + v_twice_mod, n); \ + break; \ + } + +// Algorithm 2 from https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModAVX512DQInt(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || InputModFactor == 4, + "Require InputModFactor = 1, 2, or 4") + HEXL_CHECK(InputModFactor * modulus > (1ULL << 50), + "Require InputModFactor * modulus > (1ULL << 50)") + HEXL_CHECK(InputModFactor * modulus < (1ULL << 63), + "Require InputModFactor * modulus < (1ULL << 63)"); + HEXL_CHECK(modulus < (1ULL << 62), "Require modulus < (1ULL << 62)"); + HEXL_CHECK_BOUNDS(operand1, n, InputModFactor * modulus, + "operand1 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK_BOUNDS(operand2, n, InputModFactor * modulus, + "operand2 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseMultModNative(result, operand1, operand2, n_mod_8, + modulus); + operand1 += n_mod_8; + operand2 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + constexpr int64_t beta = -2; + HEXL_CHECK(beta <= -2, "beta must be <= -2 for correctness"); + constexpr int64_t alpha = 62; // ensures alpha - beta = 64 + uint64_t gamma = Log2(InputModFactor); + HEXL_UNUSED(gamma); + HEXL_CHECK(alpha >= gamma + 1, "alpha must be >= gamma + 1 for correctness"); + + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + + // Barrett factor "mu" + // TODO(fboemer): Allow MultiplyFactor to take bit shifts != 64 + HEXL_CHECK(ceil_log_mod + alpha >= 64, "ceil_log_mod + alpha < 64"); + uint64_t barr_lo = + MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - 64), 64, modulus) + .BarrettFactor(); + + __m512i v_barr_lo = _mm512_set1_epi64(static_cast(barr_lo)); + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(2 * modulus)); + const __m512i* vp_operand1 = reinterpret_cast(operand1); + const __m512i* vp_operand2 = reinterpret_cast(operand2); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + + // Let d be the product operand1 * operand2. + // To ensure d >> prod_right_shift < (1ULL << 64), we need + // (input_mod_factor * modulus)^2 >> (prod_right_shift) < (1ULL << 64) + // This happens when 2*log_2(input_mod_factor) + prod_right_shift - beta < 63 + // If not, we need to reduce the inputs to be less than modulus for + // correctness. This is less efficient, so we avoid it when possible. + bool reduce_mod = 2 * Log2(InputModFactor) + prod_right_shift - beta >= 63; + + if (reduce_mod) { + // Here, we assume beta = -2 + HEXL_CHECK(beta == -2, "beta != -2 may skip some cases"); + // This reduce_mod case happens only when + // prod_right_shift >= 63 - 2 * log2(input_mod_factor) >= 57. + // Additionally, modulus < (1ULL << 62) implies + // prod_right_shift <= 61. So N == 57, 58, 59, 60, 61 are the + // only cases here. 
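As an aside, the parameter derivation described in the comments above (alpha = 62, beta = -2, prod_right_shift = ceil(log2(modulus)) + beta, and the Barrett factor barr_lo) can be followed in plain scalar code. The sketch below is not part of the patch; it assumes a GCC/Clang-style `unsigned __int128` and `__builtin_clzll`, and takes inputs already reduced below the modulus (InputModFactor = 1). The AVX-512 kernels vectorize this same flow, splitting the 128-bit product across `mulhi`/`mullo` lanes.

```cpp
#include <cstdint>

// Scalar Barrett multiply-mod with alpha = 62, beta = -2, for x, y < modulus.
// One conditional subtraction suffices here because the full (non-approximate)
// high product is used, so z lands in [0, 2 * modulus).
uint64_t BarrettMultModScalar(uint64_t x, uint64_t y, uint64_t modulus) {
  const uint64_t ceil_log_mod = 64 - __builtin_clzll(modulus);  // Log2(modulus) + 1
  const uint64_t prod_right_shift = ceil_log_mod - 2;           // ceil_log_mod + beta
  // barr_lo = floor(2^(ceil_log_mod + alpha) / modulus), with alpha = 62
  const uint64_t barr_lo = static_cast<uint64_t>(
      ((unsigned __int128)1 << (ceil_log_mod + 62)) / modulus);

  unsigned __int128 prod = (unsigned __int128)x * y;
  uint64_t prod_lo = static_cast<uint64_t>(prod);
  uint64_t c1 = static_cast<uint64_t>(prod >> prod_right_shift);
  uint64_t q_hat =
      static_cast<uint64_t>(((unsigned __int128)c1 * barr_lo) >> 64);
  uint64_t z = prod_lo - q_hat * modulus;  // wraps modulo 2^64, as in the kernels
  return (z >= modulus) ? z - modulus : z;
}
```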
+ switch (prod_right_shift) { + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(57, InputModFactor) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(58, InputModFactor) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(59, InputModFactor) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(60, InputModFactor) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(61, InputModFactor) + default: { + HEXL_CHECK(false, + "Bad value for prod_right_shift: " << prod_right_shift); + } + } + } else { // Input mod reduction not required; pass InputModFactor == 1. + // The template arguments are required for use of _mm512_hexl_shrdi_epi64, + // which requires a compile-time constant for the shift. + switch (prod_right_shift) { + // For prod_right_shift < 50, we should prefer EltwiseMultModAVX512Float + // or EltwiseMultModAVX512IFMAInt, so we don't generate those special + // cases here + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(50, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(51, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(52, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(53, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(54, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(55, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(56, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(57, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(58, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(59, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(60, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(61, 1) + default: { + HEXL_VLOG(2, "calling EltwiseMultModAVX512DQIntLoopDefault"); + EltwiseMultModAVX512DQIntLoopDefault<1>( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_twice_mod, n, prod_right_shift); + } + } + } + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +// From Function 18, page 19 of https://arxiv.org/pdf/1407.3383.pdf +// See also Algorithm 2/3 of +// https://hal.archives-ouvertes.fr/hal-02552673/document +template +inline void EltwiseMultModAVX512FloatLoopDefault( + __m512i* vp_result, const __m512i* vp_operand1, const __m512i* vp_operand2, + __m512d v_u, __m512d v_p, __m512i v_modulus, __m512i v_twice_mod, + uint64_t n) { + HEXL_UNUSED(v_twice_mod); + + constexpr int round_mode = (_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op1 = _mm512_loadu_si512(vp_operand1); + v_op1 = _mm512_hexl_small_mod_epu64(v_op1, v_modulus, + &v_twice_mod); + + __m512i v_op2 = _mm512_loadu_si512(vp_operand2); + v_op2 = _mm512_hexl_small_mod_epu64(v_op2, v_modulus, + &v_twice_mod); + + __m512d v_x = _mm512_cvt_roundepu64_pd(v_op1, round_mode); + __m512d v_y = _mm512_cvt_roundepu64_pd(v_op2, round_mode); + + __m512d v_h = _mm512_mul_pd(v_x, v_y); + __m512d v_l = + _mm512_fmsub_pd(v_x, v_y, v_h); // rounding error; h + l == x * y + __m512d v_b = _mm512_mul_pd(v_h, v_u); // ~ (x * y) / p + __m512d v_c = _mm512_floor_pd(v_b); // ~ floor(x * y / p) + __m512d v_d = _mm512_fnmadd_pd(v_c, v_p, v_h); + __m512d v_g = _mm512_add_pd(v_d, v_l); + __mmask8 m = _mm512_cmp_pd_mask(v_g, _mm512_setzero_pd(), _CMP_LT_OQ); + v_g = _mm512_mask_add_pd(v_g, m, v_g, v_p); + + __m512i v_result = _mm512_cvt_roundpd_epu64(v_g, round_mode); + + _mm512_storeu_si512(vp_result, v_result); + + ++vp_operand1; + ++vp_operand2; + ++vp_result; + } +} + +template +inline void 
EltwiseMultModAVX512FloatLoopUnroll( + __m512i* vp_result, const __m512i* vp_operand1, const __m512i* vp_operand2, + __m512d v_u, __m512d v_p, __m512i v_modulus, __m512i v_twice_mod) { + constexpr size_t manual_unroll_factor = 4; + constexpr size_t avx512_64bit_count = 8; + constexpr size_t loop_count = + CoeffCount / (manual_unroll_factor * avx512_64bit_count); + + static_assert(loop_count > 0, "loop_count too small for unrolling"); + static_assert(CoeffCount % (manual_unroll_factor * avx512_64bit_count) == 0, + "CoeffCount must be a factor of manual_unroll_factor * " + "avx512_64bit_count"); + + constexpr int round_mode = (_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); + + HEXL_LOOP_UNROLL_4 + for (size_t i = loop_count; i > 0; --i) { + __m512i op1_1 = _mm512_loadu_si512(vp_operand1++); + __m512i op1_2 = _mm512_loadu_si512(vp_operand1++); + __m512i op1_3 = _mm512_loadu_si512(vp_operand1++); + __m512i op1_4 = _mm512_loadu_si512(vp_operand1++); + + __m512i op2_1 = _mm512_loadu_si512(vp_operand2++); + __m512i op2_2 = _mm512_loadu_si512(vp_operand2++); + __m512i op2_3 = _mm512_loadu_si512(vp_operand2++); + __m512i op2_4 = _mm512_loadu_si512(vp_operand2++); + + op1_1 = _mm512_hexl_small_mod_epu64(op1_1, v_modulus, + &v_twice_mod); + op1_2 = _mm512_hexl_small_mod_epu64(op1_2, v_modulus, + &v_twice_mod); + op1_3 = _mm512_hexl_small_mod_epu64(op1_3, v_modulus, + &v_twice_mod); + op1_4 = _mm512_hexl_small_mod_epu64(op1_4, v_modulus, + &v_twice_mod); + + op2_1 = _mm512_hexl_small_mod_epu64(op2_1, v_modulus, + &v_twice_mod); + op2_2 = _mm512_hexl_small_mod_epu64(op2_2, v_modulus, + &v_twice_mod); + op2_3 = _mm512_hexl_small_mod_epu64(op2_3, v_modulus, + &v_twice_mod); + op2_4 = _mm512_hexl_small_mod_epu64(op2_4, v_modulus, + &v_twice_mod); + + __m512d v_x_1 = _mm512_cvt_roundepu64_pd(op1_1, round_mode); + __m512d v_x_2 = _mm512_cvt_roundepu64_pd(op1_2, round_mode); + __m512d v_x_3 = _mm512_cvt_roundepu64_pd(op1_3, round_mode); + __m512d v_x_4 = _mm512_cvt_roundepu64_pd(op1_4, round_mode); + + __m512d v_y_1 = _mm512_cvt_roundepu64_pd(op2_1, round_mode); + __m512d v_y_2 = _mm512_cvt_roundepu64_pd(op2_2, round_mode); + __m512d v_y_3 = _mm512_cvt_roundepu64_pd(op2_3, round_mode); + __m512d v_y_4 = _mm512_cvt_roundepu64_pd(op2_4, round_mode); + + __m512d v_h_1 = _mm512_mul_pd(v_x_1, v_y_1); + __m512d v_h_2 = _mm512_mul_pd(v_x_2, v_y_2); + __m512d v_h_3 = _mm512_mul_pd(v_x_3, v_y_3); + __m512d v_h_4 = _mm512_mul_pd(v_x_4, v_y_4); + + // ~ (x * y) / p + __m512d v_b_1 = _mm512_mul_pd(v_h_1, v_u); + __m512d v_b_2 = _mm512_mul_pd(v_h_2, v_u); + __m512d v_b_3 = _mm512_mul_pd(v_h_3, v_u); + __m512d v_b_4 = _mm512_mul_pd(v_h_4, v_u); + + // rounding_ error; h + l == x * y + __m512d v_l_1 = _mm512_fmsub_pd(v_x_1, v_y_1, v_h_1); + __m512d v_l_2 = _mm512_fmsub_pd(v_x_2, v_y_2, v_h_2); + __m512d v_l_3 = _mm512_fmsub_pd(v_x_3, v_y_3, v_h_3); + __m512d v_l_4 = _mm512_fmsub_pd(v_x_4, v_y_4, v_h_4); + + // ~ floor(_x * y / p) + __m512d v_c_1 = _mm512_floor_pd(v_b_1); + __m512d v_c_2 = _mm512_floor_pd(v_b_2); + __m512d v_c_3 = _mm512_floor_pd(v_b_3); + __m512d v_c_4 = _mm512_floor_pd(v_b_4); + + __m512d v_d_1 = _mm512_fnmadd_pd(v_c_1, v_p, v_h_1); + __m512d v_d_2 = _mm512_fnmadd_pd(v_c_2, v_p, v_h_2); + __m512d v_d_3 = _mm512_fnmadd_pd(v_c_3, v_p, v_h_3); + __m512d v_d_4 = _mm512_fnmadd_pd(v_c_4, v_p, v_h_4); + + __m512d v_g_1 = _mm512_add_pd(v_d_1, v_l_1); + __m512d v_g_2 = _mm512_add_pd(v_d_2, v_l_2); + __m512d v_g_3 = _mm512_add_pd(v_d_3, v_l_3); + __m512d v_g_4 = _mm512_add_pd(v_d_4, v_l_4); + + __mmask8 m_1 = 
_mm512_cmp_pd_mask(v_g_1, _mm512_setzero_pd(), _CMP_LT_OQ); + __mmask8 m_2 = _mm512_cmp_pd_mask(v_g_2, _mm512_setzero_pd(), _CMP_LT_OQ); + __mmask8 m_3 = _mm512_cmp_pd_mask(v_g_3, _mm512_setzero_pd(), _CMP_LT_OQ); + __mmask8 m_4 = _mm512_cmp_pd_mask(v_g_4, _mm512_setzero_pd(), _CMP_LT_OQ); + + v_g_1 = _mm512_mask_add_pd(v_g_1, m_1, v_g_1, v_p); + v_g_2 = _mm512_mask_add_pd(v_g_2, m_2, v_g_2, v_p); + v_g_3 = _mm512_mask_add_pd(v_g_3, m_3, v_g_3, v_p); + v_g_4 = _mm512_mask_add_pd(v_g_4, m_4, v_g_4, v_p); + + __m512i v_out_1 = _mm512_cvt_roundpd_epu64(v_g_1, round_mode); + __m512i v_out_2 = _mm512_cvt_roundpd_epu64(v_g_2, round_mode); + __m512i v_out_3 = _mm512_cvt_roundpd_epu64(v_g_3, round_mode); + __m512i v_out_4 = _mm512_cvt_roundpd_epu64(v_g_4, round_mode); + + _mm512_storeu_si512(vp_result++, v_out_1); + _mm512_storeu_si512(vp_result++, v_out_2); + _mm512_storeu_si512(vp_result++, v_out_3); + _mm512_storeu_si512(vp_result++, v_out_4); + } +} + +template +inline void EltwiseMultModAVX512FloatLoop(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512d v_u, __m512d v_p, + __m512i v_modulus, + __m512i v_twice_mod, uint64_t n) { + switch (n) { + case 1024: + EltwiseMultModAVX512FloatLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, + v_twice_mod); + break; + + case 2048: + EltwiseMultModAVX512FloatLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, + v_twice_mod); + break; + + case 4096: + EltwiseMultModAVX512FloatLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, + v_twice_mod); + break; + + case 8192: + EltwiseMultModAVX512FloatLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, + v_twice_mod); + break; + + case 16384: + EltwiseMultModAVX512FloatLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, + v_twice_mod); + break; + + case 32768: + EltwiseMultModAVX512FloatLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, + v_twice_mod); + break; + + default: + EltwiseMultModAVX512FloatLoopDefault( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, v_twice_mod, + n); + } +} + +// From Function 18, page 19 of https://arxiv.org/pdf/1407.3383.pdf +// See also Algorithm 2/3 of +// https://hal.archives-ouvertes.fr/hal-02552673/document +template +void EltwiseMultModAVX512Float(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(modulus < MaximumValue(50), + " modulus " << modulus << " exceeds bound " << MaximumValue(50)); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + + HEXL_CHECK_BOUNDS(operand1, n, InputModFactor * modulus, + "operand1 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK_BOUNDS(operand2, n, InputModFactor * modulus, + "operand2 exceeds bound " << (InputModFactor * modulus)); + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseMultModNative(result, operand1, operand2, n_mod_8, + modulus); + operand1 += n_mod_8; + operand2 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + __m512d v_p = _mm512_set1_pd(static_cast(modulus)); + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(modulus * 2)); + + // Add epsilon to ensure u * p >= 1.0 + // See Proposition 13 of https://arxiv.org/pdf/1407.3383.pdf + double u_bar = (1.0 + std::numeric_limits::epsilon()) / + static_cast(modulus); + __m512d v_u = _mm512_set1_pd(u_bar); + + const __m512i* vp_operand1 = 
reinterpret_cast(operand1); + const __m512i* vp_operand2 = reinterpret_cast(operand2); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + + // The implementation without modular reduction of the operands is correct + // as long as (InputModFactor * modulus)^2 < 2^50 * modulus, i.e. + // InputModFactor^2 * modulus < 2^50. + // See function 16 of https://arxiv.org/pdf/1407.3383.pdf. + bool no_input_reduce_mod = + (InputModFactor * InputModFactor * modulus) < (1ULL << 50); + if (no_input_reduce_mod) { + EltwiseMultModAVX512FloatLoop<1>(vp_result, vp_operand1, vp_operand2, v_u, + v_p, v_modulus, v_twice_mod, n); + } else { + EltwiseMultModAVX512FloatLoop(vp_result, vp_operand1, + vp_operand2, v_u, v_p, + v_modulus, v_twice_mod, n); + } + + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-mult-mod-avx512ifma.cpp b/hexl_omp/eltwise/eltwise-mult-mod-avx512ifma.cpp new file mode 100644 index 00000000..54374cff --- /dev/null +++ b/hexl_omp/eltwise/eltwise-mult-mod-avx512ifma.cpp @@ -0,0 +1,615 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include + +#include "eltwise/eltwise-mult-mod-avx512.hpp" +#include "eltwise/eltwise-mult-mod-internal.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" +#include "hexl/util/defines.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512IFMA + +template void EltwiseMultModAVX512IFMAInt<1>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, + uint64_t n, uint64_t modulus); +template void EltwiseMultModAVX512IFMAInt<2>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, + uint64_t n, uint64_t modulus); +template void EltwiseMultModAVX512IFMAInt<4>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, + uint64_t n, uint64_t modulus); + +template +void EltwiseMultModAVX512IFMAIntLoopUnroll(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, + __m512i v_neg_mod, + __m512i v_twice_mod) { + constexpr size_t manual_unroll_factor = 16; + constexpr size_t avx512_64bit_count = 8; + constexpr size_t loop_count = + CoeffCount / (manual_unroll_factor * avx512_64bit_count); + + static_assert(loop_count > 0, "loop_count too small for unrolling"); + static_assert(CoeffCount % (manual_unroll_factor * avx512_64bit_count) == 0, + "CoeffCount must be a factor of manual_unroll_factor * " + "avx512_64bit_count"); + + constexpr unsigned int HiShift = + static_cast(52 - ProdRightShift); + + HEXL_UNUSED(v_twice_mod); + HEXL_LOOP_UNROLL_4 + for (size_t i = loop_count; i > 0; --i) { + __m512i v_op1_1 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_1 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_2 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_2 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_3 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_3 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_4 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_4 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_5 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_5 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_6 = _mm512_loadu_si512(vp_operand1++); + __m512i 
v_op2_6 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_7 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_7 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_8 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_8 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_9 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_9 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_10 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_10 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_11 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_11 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_12 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_12 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_13 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_13 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_14 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_14 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_15 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_15 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_16 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_16 = _mm512_loadu_si512(vp_operand2++); + + v_op1_1 = _mm512_hexl_small_mod_epu64(v_op1_1, v_modulus, + &v_twice_mod); + v_op1_2 = _mm512_hexl_small_mod_epu64(v_op1_2, v_modulus, + &v_twice_mod); + v_op1_3 = _mm512_hexl_small_mod_epu64(v_op1_3, v_modulus, + &v_twice_mod); + v_op1_4 = _mm512_hexl_small_mod_epu64(v_op1_4, v_modulus, + &v_twice_mod); + v_op1_5 = _mm512_hexl_small_mod_epu64(v_op1_5, v_modulus, + &v_twice_mod); + v_op1_6 = _mm512_hexl_small_mod_epu64(v_op1_6, v_modulus, + &v_twice_mod); + v_op1_7 = _mm512_hexl_small_mod_epu64(v_op1_7, v_modulus, + &v_twice_mod); + v_op1_8 = _mm512_hexl_small_mod_epu64(v_op1_8, v_modulus, + &v_twice_mod); + v_op1_9 = _mm512_hexl_small_mod_epu64(v_op1_9, v_modulus, + &v_twice_mod); + v_op1_10 = _mm512_hexl_small_mod_epu64(v_op1_10, v_modulus, + &v_twice_mod); + v_op1_11 = _mm512_hexl_small_mod_epu64(v_op1_11, v_modulus, + &v_twice_mod); + v_op1_12 = _mm512_hexl_small_mod_epu64(v_op1_12, v_modulus, + &v_twice_mod); + v_op1_13 = _mm512_hexl_small_mod_epu64(v_op1_13, v_modulus, + &v_twice_mod); + v_op1_14 = _mm512_hexl_small_mod_epu64(v_op1_14, v_modulus, + &v_twice_mod); + v_op1_15 = _mm512_hexl_small_mod_epu64(v_op1_15, v_modulus, + &v_twice_mod); + v_op1_16 = _mm512_hexl_small_mod_epu64(v_op1_16, v_modulus, + &v_twice_mod); + + v_op2_1 = _mm512_hexl_small_mod_epu64(v_op2_1, v_modulus, + &v_twice_mod); + v_op2_2 = _mm512_hexl_small_mod_epu64(v_op2_2, v_modulus, + &v_twice_mod); + v_op2_3 = _mm512_hexl_small_mod_epu64(v_op2_3, v_modulus, + &v_twice_mod); + v_op2_4 = _mm512_hexl_small_mod_epu64(v_op2_4, v_modulus, + &v_twice_mod); + v_op2_5 = _mm512_hexl_small_mod_epu64(v_op2_5, v_modulus, + &v_twice_mod); + v_op2_6 = _mm512_hexl_small_mod_epu64(v_op2_6, v_modulus, + &v_twice_mod); + v_op2_7 = _mm512_hexl_small_mod_epu64(v_op2_7, v_modulus, + &v_twice_mod); + v_op2_8 = _mm512_hexl_small_mod_epu64(v_op2_8, v_modulus, + &v_twice_mod); + v_op2_9 = _mm512_hexl_small_mod_epu64(v_op2_9, v_modulus, + &v_twice_mod); + v_op2_10 = _mm512_hexl_small_mod_epu64(v_op2_10, v_modulus, + &v_twice_mod); + v_op2_11 = _mm512_hexl_small_mod_epu64(v_op2_11, v_modulus, + &v_twice_mod); + v_op2_12 = _mm512_hexl_small_mod_epu64(v_op2_12, v_modulus, + &v_twice_mod); + v_op2_13 = _mm512_hexl_small_mod_epu64(v_op2_13, v_modulus, + &v_twice_mod); + v_op2_14 = _mm512_hexl_small_mod_epu64(v_op2_14, v_modulus, + &v_twice_mod); + v_op2_15 = _mm512_hexl_small_mod_epu64(v_op2_15, 
v_modulus, + &v_twice_mod); + v_op2_16 = _mm512_hexl_small_mod_epu64(v_op2_16, v_modulus, + &v_twice_mod); + + __m512i v_prod_hi_1 = _mm512_hexl_mulhi_epi<52>(v_op1_1, v_op2_1); + __m512i v_prod_hi_2 = _mm512_hexl_mulhi_epi<52>(v_op1_2, v_op2_2); + __m512i v_prod_hi_3 = _mm512_hexl_mulhi_epi<52>(v_op1_3, v_op2_3); + __m512i v_prod_hi_4 = _mm512_hexl_mulhi_epi<52>(v_op1_4, v_op2_4); + __m512i v_prod_hi_5 = _mm512_hexl_mulhi_epi<52>(v_op1_5, v_op2_5); + __m512i v_prod_hi_6 = _mm512_hexl_mulhi_epi<52>(v_op1_6, v_op2_6); + __m512i v_prod_hi_7 = _mm512_hexl_mulhi_epi<52>(v_op1_7, v_op2_7); + __m512i v_prod_hi_8 = _mm512_hexl_mulhi_epi<52>(v_op1_8, v_op2_8); + __m512i v_prod_hi_9 = _mm512_hexl_mulhi_epi<52>(v_op1_9, v_op2_9); + __m512i v_prod_hi_10 = _mm512_hexl_mulhi_epi<52>(v_op1_10, v_op2_10); + __m512i v_prod_hi_11 = _mm512_hexl_mulhi_epi<52>(v_op1_11, v_op2_11); + __m512i v_prod_hi_12 = _mm512_hexl_mulhi_epi<52>(v_op1_12, v_op2_12); + __m512i v_prod_hi_13 = _mm512_hexl_mulhi_epi<52>(v_op1_13, v_op2_13); + __m512i v_prod_hi_14 = _mm512_hexl_mulhi_epi<52>(v_op1_14, v_op2_14); + __m512i v_prod_hi_15 = _mm512_hexl_mulhi_epi<52>(v_op1_15, v_op2_15); + __m512i v_prod_hi_16 = _mm512_hexl_mulhi_epi<52>(v_op1_16, v_op2_16); + + __m512i v_prod_lo_1 = _mm512_hexl_mullo_epi<52>(v_op1_1, v_op2_1); + __m512i v_prod_lo_2 = _mm512_hexl_mullo_epi<52>(v_op1_2, v_op2_2); + __m512i v_prod_lo_3 = _mm512_hexl_mullo_epi<52>(v_op1_3, v_op2_3); + __m512i v_prod_lo_4 = _mm512_hexl_mullo_epi<52>(v_op1_4, v_op2_4); + __m512i v_prod_lo_5 = _mm512_hexl_mullo_epi<52>(v_op1_5, v_op2_5); + __m512i v_prod_lo_6 = _mm512_hexl_mullo_epi<52>(v_op1_6, v_op2_6); + __m512i v_prod_lo_7 = _mm512_hexl_mullo_epi<52>(v_op1_7, v_op2_7); + __m512i v_prod_lo_8 = _mm512_hexl_mullo_epi<52>(v_op1_8, v_op2_8); + __m512i v_prod_lo_9 = _mm512_hexl_mullo_epi<52>(v_op1_9, v_op2_9); + __m512i v_prod_lo_10 = _mm512_hexl_mullo_epi<52>(v_op1_10, v_op2_10); + __m512i v_prod_lo_11 = _mm512_hexl_mullo_epi<52>(v_op1_11, v_op2_11); + __m512i v_prod_lo_12 = _mm512_hexl_mullo_epi<52>(v_op1_12, v_op2_12); + __m512i v_prod_lo_13 = _mm512_hexl_mullo_epi<52>(v_op1_13, v_op2_13); + __m512i v_prod_lo_14 = _mm512_hexl_mullo_epi<52>(v_op1_14, v_op2_14); + __m512i v_prod_lo_15 = _mm512_hexl_mullo_epi<52>(v_op1_15, v_op2_15); + __m512i v_prod_lo_16 = _mm512_hexl_mullo_epi<52>(v_op1_16, v_op2_16); + + __m512i c1_lo_1 = _mm512_srli_epi64(v_prod_lo_1, ProdRightShift); + __m512i c1_lo_2 = _mm512_srli_epi64(v_prod_lo_2, ProdRightShift); + __m512i c1_lo_3 = _mm512_srli_epi64(v_prod_lo_3, ProdRightShift); + __m512i c1_lo_4 = _mm512_srli_epi64(v_prod_lo_4, ProdRightShift); + __m512i c1_lo_5 = _mm512_srli_epi64(v_prod_lo_5, ProdRightShift); + __m512i c1_lo_6 = _mm512_srli_epi64(v_prod_lo_6, ProdRightShift); + __m512i c1_lo_7 = _mm512_srli_epi64(v_prod_lo_7, ProdRightShift); + __m512i c1_lo_8 = _mm512_srli_epi64(v_prod_lo_8, ProdRightShift); + __m512i c1_lo_9 = _mm512_srli_epi64(v_prod_lo_9, ProdRightShift); + __m512i c1_lo_10 = _mm512_srli_epi64(v_prod_lo_10, ProdRightShift); + __m512i c1_lo_11 = _mm512_srli_epi64(v_prod_lo_11, ProdRightShift); + __m512i c1_lo_12 = _mm512_srli_epi64(v_prod_lo_12, ProdRightShift); + __m512i c1_lo_13 = _mm512_srli_epi64(v_prod_lo_13, ProdRightShift); + __m512i c1_lo_14 = _mm512_srli_epi64(v_prod_lo_14, ProdRightShift); + __m512i c1_lo_15 = _mm512_srli_epi64(v_prod_lo_15, ProdRightShift); + __m512i c1_lo_16 = _mm512_srli_epi64(v_prod_lo_16, ProdRightShift); + + __m512i c1_hi_1 = _mm512_slli_epi64(v_prod_hi_1, HiShift); + __m512i c1_hi_2 = 
_mm512_slli_epi64(v_prod_hi_2, HiShift); + __m512i c1_hi_3 = _mm512_slli_epi64(v_prod_hi_3, HiShift); + __m512i c1_hi_4 = _mm512_slli_epi64(v_prod_hi_4, HiShift); + __m512i c1_hi_5 = _mm512_slli_epi64(v_prod_hi_5, HiShift); + __m512i c1_hi_6 = _mm512_slli_epi64(v_prod_hi_6, HiShift); + __m512i c1_hi_7 = _mm512_slli_epi64(v_prod_hi_7, HiShift); + __m512i c1_hi_8 = _mm512_slli_epi64(v_prod_hi_8, HiShift); + __m512i c1_hi_9 = _mm512_slli_epi64(v_prod_hi_9, HiShift); + __m512i c1_hi_10 = _mm512_slli_epi64(v_prod_hi_10, HiShift); + __m512i c1_hi_11 = _mm512_slli_epi64(v_prod_hi_11, HiShift); + __m512i c1_hi_12 = _mm512_slli_epi64(v_prod_hi_12, HiShift); + __m512i c1_hi_13 = _mm512_slli_epi64(v_prod_hi_13, HiShift); + __m512i c1_hi_14 = _mm512_slli_epi64(v_prod_hi_14, HiShift); + __m512i c1_hi_15 = _mm512_slli_epi64(v_prod_hi_15, HiShift); + __m512i c1_hi_16 = _mm512_slli_epi64(v_prod_hi_16, HiShift); + + __m512i c1_1 = _mm512_or_epi64(c1_lo_1, c1_hi_1); + __m512i c1_2 = _mm512_or_epi64(c1_lo_2, c1_hi_2); + __m512i c1_3 = _mm512_or_epi64(c1_lo_3, c1_hi_3); + __m512i c1_4 = _mm512_or_epi64(c1_lo_4, c1_hi_4); + __m512i c1_5 = _mm512_or_epi64(c1_lo_5, c1_hi_5); + __m512i c1_6 = _mm512_or_epi64(c1_lo_6, c1_hi_6); + __m512i c1_7 = _mm512_or_epi64(c1_lo_7, c1_hi_7); + __m512i c1_8 = _mm512_or_epi64(c1_lo_8, c1_hi_8); + __m512i c1_9 = _mm512_or_epi64(c1_lo_9, c1_hi_9); + __m512i c1_10 = _mm512_or_epi64(c1_lo_10, c1_hi_10); + __m512i c1_11 = _mm512_or_epi64(c1_lo_11, c1_hi_11); + __m512i c1_12 = _mm512_or_epi64(c1_lo_12, c1_hi_12); + __m512i c1_13 = _mm512_or_epi64(c1_lo_13, c1_hi_13); + __m512i c1_14 = _mm512_or_epi64(c1_lo_14, c1_hi_14); + __m512i c1_15 = _mm512_or_epi64(c1_lo_15, c1_hi_15); + __m512i c1_16 = _mm512_or_epi64(c1_lo_16, c1_hi_16); + + __m512i q_hat_1 = _mm512_hexl_mulhi_epi<52>(c1_1, v_barr_lo); + __m512i q_hat_2 = _mm512_hexl_mulhi_epi<52>(c1_2, v_barr_lo); + __m512i q_hat_3 = _mm512_hexl_mulhi_epi<52>(c1_3, v_barr_lo); + __m512i q_hat_4 = _mm512_hexl_mulhi_epi<52>(c1_4, v_barr_lo); + __m512i q_hat_5 = _mm512_hexl_mulhi_epi<52>(c1_5, v_barr_lo); + __m512i q_hat_6 = _mm512_hexl_mulhi_epi<52>(c1_6, v_barr_lo); + __m512i q_hat_7 = _mm512_hexl_mulhi_epi<52>(c1_7, v_barr_lo); + __m512i q_hat_8 = _mm512_hexl_mulhi_epi<52>(c1_8, v_barr_lo); + __m512i q_hat_9 = _mm512_hexl_mulhi_epi<52>(c1_9, v_barr_lo); + __m512i q_hat_10 = _mm512_hexl_mulhi_epi<52>(c1_10, v_barr_lo); + __m512i q_hat_11 = _mm512_hexl_mulhi_epi<52>(c1_11, v_barr_lo); + __m512i q_hat_12 = _mm512_hexl_mulhi_epi<52>(c1_12, v_barr_lo); + __m512i q_hat_13 = _mm512_hexl_mulhi_epi<52>(c1_13, v_barr_lo); + __m512i q_hat_14 = _mm512_hexl_mulhi_epi<52>(c1_14, v_barr_lo); + __m512i q_hat_15 = _mm512_hexl_mulhi_epi<52>(c1_15, v_barr_lo); + __m512i q_hat_16 = _mm512_hexl_mulhi_epi<52>(c1_16, v_barr_lo); + + __m512i z_1 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_1, q_hat_1, v_neg_mod); + __m512i z_2 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_2, q_hat_2, v_neg_mod); + __m512i z_3 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_3, q_hat_3, v_neg_mod); + __m512i z_4 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_4, q_hat_4, v_neg_mod); + __m512i z_5 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_5, q_hat_5, v_neg_mod); + __m512i z_6 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_6, q_hat_6, v_neg_mod); + __m512i z_7 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_7, q_hat_7, v_neg_mod); + __m512i z_8 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_8, q_hat_8, v_neg_mod); + __m512i z_9 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_9, 
q_hat_9, v_neg_mod); + __m512i z_10 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_10, q_hat_10, v_neg_mod); + __m512i z_11 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_11, q_hat_11, v_neg_mod); + __m512i z_12 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_12, q_hat_12, v_neg_mod); + __m512i z_13 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_13, q_hat_13, v_neg_mod); + __m512i z_14 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_14, q_hat_14, v_neg_mod); + __m512i z_15 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_15, q_hat_15, v_neg_mod); + __m512i z_16 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_16, q_hat_16, v_neg_mod); + + __m512i v_result_1 = _mm512_hexl_small_mod_epu64<2>(z_1, v_modulus); + __m512i v_result_2 = _mm512_hexl_small_mod_epu64<2>(z_2, v_modulus); + __m512i v_result_3 = _mm512_hexl_small_mod_epu64<2>(z_3, v_modulus); + __m512i v_result_4 = _mm512_hexl_small_mod_epu64<2>(z_4, v_modulus); + __m512i v_result_5 = _mm512_hexl_small_mod_epu64<2>(z_5, v_modulus); + __m512i v_result_6 = _mm512_hexl_small_mod_epu64<2>(z_6, v_modulus); + __m512i v_result_7 = _mm512_hexl_small_mod_epu64<2>(z_7, v_modulus); + __m512i v_result_8 = _mm512_hexl_small_mod_epu64<2>(z_8, v_modulus); + __m512i v_result_9 = _mm512_hexl_small_mod_epu64<2>(z_9, v_modulus); + __m512i v_result_10 = _mm512_hexl_small_mod_epu64<2>(z_10, v_modulus); + __m512i v_result_11 = _mm512_hexl_small_mod_epu64<2>(z_11, v_modulus); + __m512i v_result_12 = _mm512_hexl_small_mod_epu64<2>(z_12, v_modulus); + __m512i v_result_13 = _mm512_hexl_small_mod_epu64<2>(z_13, v_modulus); + __m512i v_result_14 = _mm512_hexl_small_mod_epu64<2>(z_14, v_modulus); + __m512i v_result_15 = _mm512_hexl_small_mod_epu64<2>(z_15, v_modulus); + __m512i v_result_16 = _mm512_hexl_small_mod_epu64<2>(z_16, v_modulus); + + _mm512_storeu_si512(vp_result++, v_result_1); + _mm512_storeu_si512(vp_result++, v_result_2); + _mm512_storeu_si512(vp_result++, v_result_3); + _mm512_storeu_si512(vp_result++, v_result_4); + _mm512_storeu_si512(vp_result++, v_result_5); + _mm512_storeu_si512(vp_result++, v_result_6); + _mm512_storeu_si512(vp_result++, v_result_7); + _mm512_storeu_si512(vp_result++, v_result_8); + _mm512_storeu_si512(vp_result++, v_result_9); + _mm512_storeu_si512(vp_result++, v_result_10); + _mm512_storeu_si512(vp_result++, v_result_11); + _mm512_storeu_si512(vp_result++, v_result_12); + _mm512_storeu_si512(vp_result++, v_result_13); + _mm512_storeu_si512(vp_result++, v_result_14); + _mm512_storeu_si512(vp_result++, v_result_15); + _mm512_storeu_si512(vp_result++, v_result_16); + } +} + +// Algorithm 2 from https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModAVX512IFMAIntLoopDefault( + __m512i* vp_result, const __m512i* vp_operand1, const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, __m512i v_neg_mod, + __m512i v_twice_mod, uint64_t n) { + HEXL_UNUSED(v_twice_mod); + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op1 = _mm512_loadu_si512(vp_operand1); + v_op1 = _mm512_hexl_small_mod_epu64(v_op1, v_modulus, + &v_twice_mod); + + __m512i v_op2 = _mm512_loadu_si512(vp_operand2); + v_op2 = _mm512_hexl_small_mod_epu64(v_op2, v_modulus, + &v_twice_mod); + + // Compute product U + __m512i v_prod_hi = _mm512_hexl_mulhi_epi<52>(v_op1, v_op2); + __m512i v_prod_lo = _mm512_hexl_mullo_epi<52>(v_op1, v_op2); + + // c1 = floor(U / 2^{n + beta}) + __m512i c1_lo = + _mm512_srli_epi64(v_prod_lo, static_cast(ProdRightShift)); + __m512i c1_hi = _mm512_slli_epi64( + v_prod_hi, 
static_cast(52ULL - (ProdRightShift))); + __m512i c1 = _mm512_or_epi64(c1_lo, c1_hi); + + // alpha - beta == 52, so we only need high 52 bits + __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, v_barr_lo); + + // Z = prod_lo - (p * q_hat)_lo + __m512i v_result = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo, q_hat, v_neg_mod); + + // Reduce result to [0, q) + v_result = _mm512_hexl_small_mod_epu64<2>(v_result, v_modulus); + _mm512_storeu_si512(vp_result, v_result); + + ++vp_operand1; + ++vp_operand2; + ++vp_result; + } +} + +// Algorithm 2 from https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModAVX512IFMAIntLoopDefault( + __m512i* vp_result, const __m512i* vp_operand1, const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, __m512i v_neg_mod, + __m512i v_twice_mod, uint64_t n, uint64_t prod_right_shift) { + unsigned int low_shift = static_cast(prod_right_shift); + unsigned int high_shift = static_cast(52 - prod_right_shift); + + HEXL_UNUSED(v_twice_mod); + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op1 = _mm512_loadu_si512(vp_operand1); + v_op1 = _mm512_hexl_small_mod_epu64(v_op1, v_modulus, + &v_twice_mod); + + __m512i v_op2 = _mm512_loadu_si512(vp_operand2); + v_op2 = _mm512_hexl_small_mod_epu64(v_op2, v_modulus, + &v_twice_mod); + + // Compute product + __m512i v_prod_hi = _mm512_hexl_mulhi_epi<52>(v_op1, v_op2); + __m512i v_prod_lo = _mm512_hexl_mullo_epi<52>(v_op1, v_op2); + + __m512i c1_lo = _mm512_srli_epi64(v_prod_lo, low_shift); + __m512i c1_hi = _mm512_slli_epi64(v_prod_hi, high_shift); + __m512i c1 = _mm512_or_epi64(c1_lo, c1_hi); + + // alpha - beta == 52, so we only need high 52 bits + __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, v_barr_lo); + + // z = prod_lo - (p * q_hat)_lo + __m512i v_result = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo, q_hat, v_neg_mod); + + // Reduce result to [0, q) + v_result = _mm512_hexl_small_mod_epu64<2>(v_result, v_modulus); + + _mm512_storeu_si512(vp_result, v_result); + + ++vp_operand1; + ++vp_operand2; + ++vp_result; + } +} + +template +void EltwiseMultModAVX512IFMAIntLoop(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, + __m512i v_neg_mod, __m512i v_twice_mod, + uint64_t n) { + switch (n) { + case 1024: { + EltwiseMultModAVX512IFMAIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod); + break; + } + case 2048: { + EltwiseMultModAVX512IFMAIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod); + break; + } + case 4096: { + EltwiseMultModAVX512IFMAIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod); + break; + } + case 8192: { + EltwiseMultModAVX512IFMAIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod); + break; + } + case 16384: { + EltwiseMultModAVX512IFMAIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod); + break; + } + case 32768: { + EltwiseMultModAVX512IFMAIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod); + break; + } + default: + EltwiseMultModAVX512IFMAIntLoopDefault( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod, n); + } +} + +#define ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(ProdRightShift, \ + InputModFactor) \ + case (ProdRightShift): { \ + 
EltwiseMultModAVX512IFMAIntLoop<(ProdRightShift), (InputModFactor)>( \ + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, \ + v_twice_mod, n); \ + break; \ + } + +// Algorithm 2 from https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModAVX512IFMAInt(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || InputModFactor == 4, + "Require InputModFactor = 1, 2, or 4") + HEXL_CHECK(modulus < (1ULL << 50), "Require modulus < (1ULL << 50)"); + HEXL_CHECK_BOUNDS(operand1, n, InputModFactor * modulus, + "operand1 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK_BOUNDS(operand2, n, InputModFactor * modulus, + "operand2 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseMultModNative(result, operand1, operand2, n_mod_8, + modulus); + operand1 += n_mod_8; + operand2 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + constexpr int64_t beta = -2; + HEXL_CHECK(beta <= -2, "beta must be <= -2 for correctness"); + constexpr int64_t alpha = 50; // ensures alpha - beta = 52 + uint64_t gamma = Log2(InputModFactor); + HEXL_UNUSED(gamma); + HEXL_CHECK(alpha >= gamma + 1, "alpha must be >= gamma + 1 for correctness"); + + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + + // Barrett factor "mu" + // TODO(fboemer): Allow MultiplyFactor to take bit shifts != 52 + HEXL_CHECK(ceil_log_mod + alpha >= 52, "ceil_log_mod + alpha < 52"); + uint64_t barr_lo = + MultiplyFactor((1ULL << (ceil_log_mod + alpha - 52)), 52, modulus) + .BarrettFactor(); + + __m512i v_barr_lo = _mm512_set1_epi64(static_cast(barr_lo)); + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(2 * modulus)); + __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(modulus)); + const __m512i* vp_operand1 = reinterpret_cast(operand1); + const __m512i* vp_operand2 = reinterpret_cast(operand2); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + + // Let d be the product operand1 * operand2. + // To ensure d >> prod_right_shift < (1ULL << 52), we need + // (input_mod_factor * modulus)^2 >> (prod_right_shift) < (1ULL << 52) + // This happens when 2*log_2(input_mod_factor) + ceil_log_mod - beta < 51 + // If not, we need to reduce the inputs to be less than modulus for + // correctness. This is less efficient, so we avoid it when possible. + bool reduce_mod = 2 * Log2(InputModFactor) + prod_right_shift - beta >= 51; + + if (reduce_mod) { + // Here, we assume beta = -2 + HEXL_CHECK(beta == -2, "beta != -2 may skip some cases"); + // This reduce_mod case happens only when + // prod_right_shift >= 51 - 2 * log2(input_mod_factor) >= 45. + // Additionally, modulus < (1ULL << 50) implies + // prod_right_shift <= 49. So N == 45, 46, 47, 48, 49 are the + // only cases here. + switch (prod_right_shift) { + // The template arguments are required for use of _mm512_hexl_shrdi_epi64, + // which requires a compile-time constant for the shift. 
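The comment above points at why the `ELTWISE_MULT_MOD_*` macros enumerate one case per shift value: the shift must be a compile-time constant for the underlying intrinsic. The generic sketch below (not part of the patch; `KernelWithImmediateShift` and `Dispatch` are hypothetical names) illustrates that runtime-to-template dispatch idiom.

```cpp
#include <cstdint>

// Kernel whose shift amount must be known at compile time, standing in for an
// intrinsic that requires an immediate operand.
template <int Shift>
void KernelWithImmediateShift(const uint64_t* in, uint64_t* out, uint64_t n) {
  for (uint64_t i = 0; i < n; ++i) {
    out[i] = in[i] >> Shift;
  }
}

// Maps a runtime shift onto a compile-time template parameter, one case per
// supported value, with a runtime-shift fallback for everything else.
void Dispatch(const uint64_t* in, uint64_t* out, uint64_t n, uint64_t shift) {
  switch (shift) {
    case 45: KernelWithImmediateShift<45>(in, out, n); break;
    case 46: KernelWithImmediateShift<46>(in, out, n); break;
    // ... one case per supported shift, as the macro expansion does ...
    default:
      for (uint64_t i = 0; i < n; ++i) out[i] = in[i] >> shift;
  }
}
```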
+ ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(45, InputModFactor) + ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(46, InputModFactor) + ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(47, InputModFactor) + ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(48, InputModFactor) + ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(49, InputModFactor) + default: { + HEXL_CHECK(false, + "Bad value for prod_right_shift: " << prod_right_shift); + } + } + } else { + switch (prod_right_shift) { + // Smaller shifts are uncommon. + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(15, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(16, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(17, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(18, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(19, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(20, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(21, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(22, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(23, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(24, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(25, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(26, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(27, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(28, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(29, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(31, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(32, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(33, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(34, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(35, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(36, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(37, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(38, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(39, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(40, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(41, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(42, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(43, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(44, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(45, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(46, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(47, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(48, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(49, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(50, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(51, 1) + default: { + EltwiseMultModAVX512IFMAIntLoopDefault<1>( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_neg_mod, v_twice_mod, n, prod_right_shift); + } + } + } + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-mult-mod-internal.hpp b/hexl_omp/eltwise/eltwise-mult-mod-internal.hpp new file mode 100644 index 00000000..5eec1e7b --- /dev/null +++ b/hexl_omp/eltwise/eltwise-mult-mod-internal.hpp @@ -0,0 +1,106 @@ +// Copyright (C) 2020 Intel 
Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include +#include + +#include "eltwise/eltwise-mult-mod-internal.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" + +namespace intel { +namespace hexl { + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. +/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +/// @details Algorithm 2 from +/// https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModNative(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || InputModFactor == 4, + "Require InputModFactor = 1, 2, or 4") + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 62), "Require modulus < (1ULL << 62)"); + HEXL_CHECK_BOUNDS(operand1, n, InputModFactor * modulus, + "operand1 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK_BOUNDS(operand2, n, InputModFactor * modulus, + "operand2 exceeds bound " << (InputModFactor * modulus)); + + constexpr int64_t beta = -2; + HEXL_CHECK(beta <= -2, "beta must be <= -2 for correctness"); + + constexpr int64_t alpha = 62; // ensures alpha - beta = 64 + + uint64_t gamma = Log2(InputModFactor); + HEXL_UNUSED(gamma); + HEXL_CHECK(alpha >= gamma + 1, "alpha must be >= gamma + 1 for correctness"); + + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + + // Barrett factor "mu" + // TODO(fboemer): Allow MultiplyFactor to take bit shifts != 64 + HEXL_CHECK(ceil_log_mod + alpha >= 64, "ceil_log_mod + alpha < 64"); + uint64_t barr_lo = + MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - 64), 64, modulus) + .BarrettFactor(); + + const uint64_t twice_modulus = 2 * modulus; + + +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + uint64_t prod_hi, prod_lo, c2_hi, c2_lo, Z; + + uint64_t x = + ReduceMod(operand1[i], modulus, &twice_modulus); + uint64_t y = + ReduceMod(operand2[i], modulus, &twice_modulus); + + // Multiply inputs + MultiplyUInt64(x, y, &prod_hi, &prod_lo); + + // floor(U / 2^{n + beta}) + uint64_t c1 = (prod_lo >> (prod_right_shift)) + + (prod_hi << (64 - (prod_right_shift))); + + // c2 = floor(U / 2^{n + beta}) * mu + MultiplyUInt64(c1, barr_lo, &c2_hi, &c2_lo); + + // alpha - beta == 64, so we only need high 64 bits + uint64_t q_hat = c2_hi; + + // only compute low bits, since we know high bits will be 0 + Z = prod_lo - q_hat * modulus; + + // Conditional subtraction + result[i] = (Z >= modulus) ? 
(Z - modulus) : Z; + } + +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-mult-mod.cpp b/hexl_omp/eltwise/eltwise-mult-mod.cpp new file mode 100644 index 00000000..c4a423d2 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-mult-mod.cpp @@ -0,0 +1,85 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/eltwise/eltwise-mult-mod.hpp" + +#include "eltwise/eltwise-mult-mod-avx512.hpp" +#include "eltwise/eltwise-mult-mod-internal.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void EltwiseMultMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(input_mod_factor * modulus < (1ULL << 63), + "Require input_mod_factor * modulus < (1ULL << 63)"); + HEXL_CHECK( + input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4, + "Require input_mod_factor = 1, 2, or 4") + HEXL_CHECK_BOUNDS(operand1, n, input_mod_factor * modulus, + "operand1 exceeds bound " << (input_mod_factor * modulus)) + HEXL_CHECK_BOUNDS(operand2, n, input_mod_factor * modulus, + "operand2 exceeds bound " << (input_mod_factor * modulus)) + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + if (modulus < (1ULL << 50)) { + // EltwiseMultModAVX512IFMA has similar performance to + // EltwiseMultModAVX512Float, but requires the AVX512IFMA instruction set, + // so we prefer to use EltwiseMultModAVX512Float. 
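For reference, a caller reaches the kernels in this file through the public `EltwiseMultMod` entry point shown here. The usage sketch below is not part of the patch; the 61-bit Mersenne prime and vector contents are example values only, and `omp_set_num_threads` simply bounds the threads available to OpenMP regions such as the `#pragma omp parallel for` loop in `EltwiseMultModNative`.

```cpp
#include <omp.h>

#include <cstdint>
#include <vector>

#include "hexl/eltwise/eltwise-mult-mod.hpp"

int main() {
  const uint64_t n = 4096;
  const uint64_t modulus = 2305843009213693951ULL;  // 2^61 - 1, a Mersenne prime

  std::vector<uint64_t> op1(n, 2), op2(n, 3), result(n, 0);

  omp_set_num_threads(8);  // threads used by the library's OpenMP regions
  intel::hexl::EltwiseMultMod(result.data(), op1.data(), op2.data(), n,
                              modulus, /*input_mod_factor=*/1);
  // Each result[i] now holds (2 * 3) mod modulus == 6.
  return 0;
}
```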
+ switch (input_mod_factor) { + case 1: + EltwiseMultModAVX512Float<1>(result, operand1, operand2, n, modulus); + break; + case 2: + EltwiseMultModAVX512Float<2>(result, operand1, operand2, n, modulus); + break; + case 4: + EltwiseMultModAVX512Float<4>(result, operand1, operand2, n, modulus); + break; + } + } else { + switch (input_mod_factor) { + case 1: + EltwiseMultModAVX512DQInt<1>(result, operand1, operand2, n, modulus); + break; + case 2: + EltwiseMultModAVX512DQInt<2>(result, operand1, operand2, n, modulus); + break; + case 4: + EltwiseMultModAVX512DQInt<4>(result, operand1, operand2, n, modulus); + break; + } + } + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseMultModNative"); + switch (input_mod_factor) { + case 1: + EltwiseMultModNative<1>(result, operand1, operand2, n, modulus); + break; + case 2: + EltwiseMultModNative<2>(result, operand1, operand2, n, modulus); + break; + case 4: + EltwiseMultModNative<4>(result, operand1, operand2, n, modulus); + break; + } + return; +} +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-reduce-mod-avx512.cpp b/hexl_omp/eltwise/eltwise-reduce-mod-avx512.cpp new file mode 100644 index 00000000..144e070b --- /dev/null +++ b/hexl_omp/eltwise/eltwise-reduce-mod-avx512.cpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-reduce-mod-avx512.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +template void EltwiseReduceModAVX512<64>(uint64_t* result, + const uint64_t* operand, uint64_t n, + uint64_t modulus, + uint64_t input_mod_factor, + uint64_t output_mod_factor); +#endif + +#ifdef HEXL_HAS_AVX512IFMA +template void EltwiseReduceModAVX512<52>(uint64_t* result, + const uint64_t* operand, uint64_t n, + uint64_t modulus, + uint64_t input_mod_factor, + uint64_t output_mod_factor); +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-reduce-mod-avx512.hpp b/hexl_omp/eltwise/eltwise-reduce-mod-avx512.hpp new file mode 100644 index 00000000..5374c9c8 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-reduce-mod-avx512.hpp @@ -0,0 +1,378 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "eltwise/eltwise-reduce-mod-avx512.hpp" +#include "eltwise/eltwise-reduce-mod-internal.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ +template +void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, + uint64_t n, uint64_t modulus, + uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(operand != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(input_mod_factor == modulus || input_mod_factor == 2 || + input_mod_factor == 4, + "input_mod_factor must be modulus or 2 or 4" << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2 " << output_mod_factor); + HEXL_CHECK(input_mod_factor != output_mod_factor, + "input_mod_factor must not be equal to output_mod_factor "); + + uint64_t n_tmp = n; + + // Multi-word Barrett reduction precomputation + constexpr int64_t alpha = BitShift - 2; + constexpr int64_t beta = 
-2; + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(modulus)); + + uint64_t barrett_factor = + MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - BitShift), BitShift, + modulus) + .BarrettFactor(); + + uint64_t barrett_factor_52 = MultiplyFactor(1, 52, modulus).BarrettFactor(); + + if (BitShift == 64) { + // Single-worded Barrett reduction. + barrett_factor = MultiplyFactor(1, 64, modulus).BarrettFactor(); + } + + __m512i v_bf = _mm512_set1_epi64(static_cast(barrett_factor)); + __m512i v_bf_52 = _mm512_set1_epi64(static_cast(barrett_factor_52)); + + // Deals with n not divisible by 8 + uint64_t n_mod_8 = n_tmp % 8; + if (n_mod_8 != 0) { + EltwiseReduceModNative(result, operand, n_mod_8, modulus, input_mod_factor, + output_mod_factor); + operand += n_mod_8; + result += n_mod_8; + n_tmp -= n_mod_8; + } + + uint64_t twice_mod = modulus << 1; + const __m512i* v_operand = reinterpret_cast(operand); + __m512i* v_result = reinterpret_cast<__m512i*>(result); + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(twice_mod)); + + if (input_mod_factor == modulus) { + if (output_mod_factor == 2) { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_barrett_reduce64( + v_op, v_modulus, v_bf, v_bf_52, prod_right_shift, v_neg_mod); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } else { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_barrett_reduce64( + v_op, v_modulus, v_bf, v_bf_52, prod_right_shift, v_neg_mod); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } + } + + if (input_mod_factor == 2) { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_small_mod_epu64(v_op, v_modulus); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } + + if (input_mod_factor == 4) { + if (output_mod_factor == 1) { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_small_mod_epu64(v_op, v_twice_mod); + v_op = _mm512_hexl_small_mod_epu64(v_op, v_modulus); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } + if (output_mod_factor == 2) { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_small_mod_epu64(v_op, v_twice_mod); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, twice_mod, + "v_op exceeds bound " << twice_mod); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } + } +} + +/// @brief Returns Montgomery form of modular product ab mod q, computed via the +/// REDC algorithm, also known as Montgomery reduction. +/// @tparam BitShift denotes the operational length, in bits, of the operands +/// and result values. +/// @tparam r defines the value of R, being R = 2^r. R > modulus. +/// @param[in] a input vector. 
T = ab in the range [0, Rq − 1]. +/// @param[in] b input vector. +/// @param[in] modulus such that gcd(R, modulus) = 1. +/// @param[in] neg_inv_mod in [0, R − 1] such that q*neg_inv_mod ≡ −1 mod R, +/// @param[in] n number of elements in input vector. +/// @param[out] result unsigned long int vector in the range [0, q − 1] such +/// that S ≡ TR^−1 mod q +template +void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a, + const uint64_t* b, uint64_t n, uint64_t modulus, + uint64_t neg_inv_mod) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(a != nullptr, "Require operand a != nullptr"); + HEXL_CHECK(b != nullptr, "Require operand b != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + + uint64_t R = (1ULL << r); + HEXL_CHECK(std::gcd(modulus, R) == 1, "gcd(modulus, R) != 1"); + HEXL_CHECK(R > modulus, "Needs R bigger than q."); + + // mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones + uint64_t mod_R_mask = R - 1; + uint64_t prod_rs; + if (BitShift == 64) { + HEXL_CHECK(r <= 62, "With r > 62 internal ops might overflow"); + prod_rs = (1ULL << 63) - 1; + } else { + prod_rs = (1ULL << (52 - r)); + } + uint64_t n_tmp = n; + + // Deals with n not divisible by 8 + uint64_t n_mod_8 = n_tmp % 8; + if (n_mod_8 != 0) { + for (size_t i = 0; i < n_mod_8; ++i) { + uint64_t T_hi; + uint64_t T_lo; + MultiplyUInt64(a[i], b[i], &T_hi, &T_lo); + result[i] = MontgomeryReduce(T_hi, T_lo, modulus, r, mod_R_mask, + neg_inv_mod); + } + a += n_mod_8; + b += n_mod_8; + result += n_mod_8; + n_tmp -= n_mod_8; + } + + const __m512i* v_a = reinterpret_cast(a); + const __m512i* v_b = reinterpret_cast(b); + __m512i* v_result = reinterpret_cast<__m512i*>(result); + __m512i v_modulus = _mm512_set1_epi64(modulus); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); + __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); + + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_a_op = _mm512_loadu_si512(v_a); + __m512i v_b_op = _mm512_loadu_si512(v_b); + __m512i v_T_hi = _mm512_hexl_mulhi_epi(v_a_op, v_b_op); + __m512i v_T_lo = _mm512_hexl_mullo_epi(v_a_op, v_b_op); + + // Convert to 63 bits to save intermediate carry + if (BitShift == 64) { + v_T_hi = _mm512_slli_epi64(v_T_hi, 1); + __m512i tmp = _mm512_srli_epi64(v_T_lo, 63); + v_T_hi = _mm512_add_epi64(v_T_hi, tmp); + v_T_lo = _mm512_and_epi64(v_T_lo, v_prod_rs); + } + + __m512i v_c = _mm512_hexl_montgomery_reduce( + v_T_hi, v_T_lo, v_modulus, v_neg_inv_mod, v_prod_rs); + HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_c); + ++v_a; + ++v_b; + ++v_result; + } +} + +/// @brief Returns Montgomery form of a mod q, computed via the REDC algorithm, +/// also known as Montgomery reduction. +/// @tparam BitShift denotes the operational length, in bits, of the operands +/// and result values. +/// @tparam r defines the value of R, being R = 2^r. R > modulus. +/// @param[in] a input vector. T = a(R^2 mod q) in the range [0, Rq − 1]. +/// @param[in] R2_mod_q R^2 mod q. +/// @param[in] modulus such that gcd(R, modulus) = 1. +/// @param[in] neg_inv_mod in [0, R − 1] such that q*neg_inv_mod ≡ −1 mod R, +/// @param[in] n number of elements in input vector. 
+/// @param[out] result unsigned long int vector in the range [0, q − 1] such +/// that S ≡ TR^−1 mod q +template +void EltwiseMontgomeryFormInAVX512(uint64_t* result, const uint64_t* a, + uint64_t R2_mod_q, uint64_t n, + uint64_t modulus, uint64_t neg_inv_mod) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(a != nullptr, "Require operand a != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + + uint64_t R = (1ULL << r); + HEXL_CHECK(std::gcd(modulus, R) == 1, "gcd(modulus, R) != 1"); + HEXL_CHECK(R > modulus, "Needs R bigger than q."); + + // mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones + uint64_t mod_R_mask = R - 1; + uint64_t prod_rs; + if (BitShift == 64) { + HEXL_CHECK(r <= 62, "With r > 62 internal ops might overflow"); + prod_rs = (1ULL << 63) - 1; + } else { + prod_rs = (1ULL << (52 - r)); + } + uint64_t n_tmp = n; + + // Deals with n not divisible by 8 + uint64_t n_mod_8 = n_tmp % 8; + if (n_mod_8 != 0) { + for (size_t i = 0; i < n_mod_8; ++i) { + uint64_t T_hi; + uint64_t T_lo; + MultiplyUInt64(a[i], R2_mod_q, &T_hi, &T_lo); + result[i] = MontgomeryReduce(T_hi, T_lo, modulus, r, mod_R_mask, + neg_inv_mod); + } + a += n_mod_8; + result += n_mod_8; + n_tmp -= n_mod_8; + } + + const __m512i* v_a = reinterpret_cast(a); + __m512i* v_result = reinterpret_cast<__m512i*>(result); + __m512i v_b = _mm512_set1_epi64(R2_mod_q); + __m512i v_modulus = _mm512_set1_epi64(modulus); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); + __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); + + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_a_op = _mm512_loadu_si512(v_a); + __m512i v_T_hi = _mm512_hexl_mulhi_epi(v_a_op, v_b); + __m512i v_T_lo = _mm512_hexl_mullo_epi(v_a_op, v_b); + + // Convert to 63 bits to save intermediate carry + if (BitShift == 64) { + v_T_hi = _mm512_slli_epi64(v_T_hi, 1); + __m512i tmp = _mm512_srli_epi64(v_T_lo, 63); + v_T_hi = _mm512_add_epi64(v_T_hi, tmp); + v_T_lo = _mm512_and_epi64(v_T_lo, v_prod_rs); + } + + __m512i v_c = _mm512_hexl_montgomery_reduce( + v_T_hi, v_T_lo, v_modulus, v_neg_inv_mod, v_prod_rs); + HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_c); + ++v_a; + ++v_result; + } +} + +/// @brief Convert out of the Montgomery Form computed via the REDC algorithm, +/// also known as Montgomery reduction. +/// @tparam BitShift denotes the operational length, in bits, of the operands +/// and result values. +/// @tparam r defines the value of R, being R = 2^r. R > modulus. +/// @param[in] a input vector in Montgomery Form. +/// @param[in] modulus such that gcd(R, modulus) = 1. +/// @param[in] neg_inv_mod in [0, R − 1] such that q*neg_inv_mod ≡ −1 mod R, +/// @param[in] n number of elements in input vector. 
+/// @param[out] result unsigned long int vector in the range [0, q − 1] such +/// that S ≡ TR^−1 mod q +template +void EltwiseMontgomeryFormOutAVX512(uint64_t* result, const uint64_t* a, + uint64_t n, uint64_t modulus, + uint64_t neg_inv_mod) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(a != nullptr, "Require operand a != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + + uint64_t R = (1ULL << r); + HEXL_CHECK(std::gcd(modulus, R) == 1, "gcd(modulus, R) != 1"); + HEXL_CHECK(R > modulus, "Needs R bigger than q."); + + // mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones + uint64_t mod_R_mask = R - 1; + uint64_t prod_rs; + if (BitShift == 64) { + HEXL_CHECK(r <= 62, "With r > 62 internal ops might overflow"); + prod_rs = (1ULL << 63) - 1; + } else { + prod_rs = (1ULL << (52 - r)); + } + uint64_t n_tmp = n; + + // Deals with n not divisible by 8 + uint64_t n_mod_8 = n_tmp % 8; + if (n_mod_8 != 0) { + for (size_t i = 0; i < n_mod_8; ++i) { + result[i] = MontgomeryReduce(0, a[i], modulus, r, mod_R_mask, + neg_inv_mod); + } + a += n_mod_8; + result += n_mod_8; + n_tmp -= n_mod_8; + } + + const __m512i* v_a = reinterpret_cast(a); + __m512i* v_result = reinterpret_cast<__m512i*>(result); + __m512i v_modulus = _mm512_set1_epi64(modulus); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); + __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); + __m512i v_T_hi = _mm512_set1_epi64(0); + + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_T_lo = _mm512_loadu_si512(v_a); + __m512i v_c = _mm512_hexl_montgomery_reduce( + v_T_hi, v_T_lo, v_modulus, v_neg_inv_mod, v_prod_rs); + HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_c); + ++v_a; + ++v_result; + } +} + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-reduce-mod-internal.hpp b/hexl_omp/eltwise/eltwise-reduce-mod-internal.hpp new file mode 100644 index 00000000..ce50f5e8 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-reduce-mod-internal.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +// @brief Performs elementwise modular reduction +// @param[out] result Stores result +// @param[in] operand Vector of elements +// @param[in] n Number of elements in operand +// @param[in] modulus Modulus with which to perform modular reduction +// @param[in] input_mod_factor Assumes input elements are in [0, +// input_mod_factor * p) Must be modulus, 2 or 4. input_mod_factor=modulus +// means, input range is [0, p * p]. Barrett reduction will be used in this case +// input_mod_factor > output_mod_factor +// @param[in] output_mod_factor output elements will be in [0, output_mod_factor +// * p) Must be 1 or 2. for input_mod_factor=0, output_mod_factor will be set +// to 1. 
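// For example, with modulus p and input_mod_factor = 4 (inputs in [0, 4*p)),
// output_mod_factor = 1 reduces each element into [0, p), while
// output_mod_factor = 2 only reduces into [0, 2*p).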
+void EltwiseReduceModNative(uint64_t* result, const uint64_t* operand, + uint64_t n, uint64_t modulus, + uint64_t input_mod_factor, + uint64_t output_mod_factor); +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-reduce-mod.cpp b/hexl_omp/eltwise/eltwise-reduce-mod.cpp new file mode 100644 index 00000000..e68a3337 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-reduce-mod.cpp @@ -0,0 +1,140 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/eltwise/eltwise-reduce-mod.hpp" + +#include + +#include + +#include "eltwise/eltwise-reduce-mod-avx512.hpp" +#include "eltwise/eltwise-reduce-mod-internal.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void EltwiseReduceModNative(uint64_t* result, const uint64_t* operand, + uint64_t n, uint64_t modulus, + uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(operand != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(input_mod_factor == modulus || input_mod_factor == 2 || + input_mod_factor == 4, + "input_mod_factor must be modulus or 2 or 4" << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2 " << output_mod_factor); + HEXL_CHECK(input_mod_factor != output_mod_factor, + "input_mod_factor must not be equal to output_mod_factor "); + + uint64_t barrett_factor = MultiplyFactor(1, 64, modulus).BarrettFactor(); + + uint64_t twice_modulus = modulus << 1; + + + if (input_mod_factor == modulus) { + if (output_mod_factor == 2) { +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + if (operand[i] >= modulus) { + result[i] = BarrettReduce64<2>(operand[i], modulus, barrett_factor); + } else { + result[i] = operand[i]; + } + } + } else { +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + if (operand[i] >= modulus) { + result[i] = BarrettReduce64<1>(operand[i], modulus, barrett_factor); + } else { + result[i] = operand[i]; + } + } + + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); + } + } + + if (input_mod_factor == 2) { +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<2>(operand[i], modulus); + } + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); + } + + if (input_mod_factor == 4) { + if (output_mod_factor == 1) { +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<4>(operand[i], modulus, &twice_modulus); + } + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); + } + if (output_mod_factor == 2) { +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<2>(operand[i], twice_modulus); + } + HEXL_CHECK_BOUNDS(result, n, twice_modulus, + "result exceeds bound " << twice_modulus); + } + } + +} + +void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, + uint64_t modulus, uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(operand != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(input_mod_factor == modulus || 
input_mod_factor == 2 || + input_mod_factor == 4, + "input_mod_factor must be modulus or 2 or 4" << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2 " << output_mod_factor); + + if (input_mod_factor == output_mod_factor && (operand != result)) { + +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + result[i] = operand[i]; + } + return; + + } + +#ifdef HEXL_HAS_AVX512IFMA + // Modulus can be 52 bits only if input mod factors <= 4 + // otherwise modulus should be 51 bits max to give correct results + if ((has_avx512ifma && modulus < (1ULL << 51)) || + (modulus < (1ULL << 52) && input_mod_factor <= 4)) { + EltwiseReduceModAVX512<52>(result, operand, n, modulus, input_mod_factor, + output_mod_factor); + return; + } +#endif + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseReduceModAVX512<64>(result, operand, n, modulus, input_mod_factor, + output_mod_factor); + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseReduceModNative"); + EltwiseReduceModNative(result, operand, n, modulus, input_mod_factor, + output_mod_factor); +} +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-sub-mod-avx512.cpp b/hexl_omp/eltwise/eltwise-sub-mod-avx512.cpp new file mode 100644 index 00000000..2039c917 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-sub-mod-avx512.cpp @@ -0,0 +1,108 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-sub-mod-avx512.hpp" + +#include +#include + +#include "eltwise/eltwise-sub-mod-internal.hpp" +#include "hexl/eltwise/eltwise-sub-mod.hpp" +#include "hexl/util/check.hpp" +#include "util/avx512-util.hpp" + +#ifdef HEXL_HAS_AVX512DQ + +namespace intel { +namespace hexl { + +void EltwiseSubModAVX512(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-sub value in operand1 exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(operand2, n, modulus, + "pre-sub value in operand2 exceeds bound " << modulus); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseSubModNative(result, operand1, operand2, n_mod_8, modulus); + operand1 += n_mod_8; + operand2 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + const __m512i* vp_operand1 = reinterpret_cast(operand1); + const __m512i* vp_operand2 = reinterpret_cast(operand2); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_operand1 = _mm512_loadu_si512(vp_operand1); + __m512i v_operand2 = _mm512_loadu_si512(vp_operand2); + + __m512i v_result = + _mm512_hexl_small_sub_mod_epi64(v_operand1, v_operand2, v_modulus); + + _mm512_storeu_si512(vp_result, v_result); + + ++vp_result; + ++vp_operand1; + ++vp_operand2; + } + + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +void EltwiseSubModAVX512(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + 
HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-sub value in operand1 exceeds bound " << modulus); + HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseSubModNative(result, operand1, operand2, n_mod_8, modulus); + operand1 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + const __m512i* vp_operand1 = reinterpret_cast(operand1); + __m512i v_operand2 = _mm512_set1_epi64(static_cast(operand2)); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_operand1 = _mm512_loadu_si512(vp_operand1); + + __m512i v_result = + _mm512_hexl_small_sub_mod_epi64(v_operand1, v_operand2, v_modulus); + + _mm512_storeu_si512(vp_result, v_result); + + ++vp_result; + ++vp_operand1; + } + + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +} // namespace hexl +} // namespace intel + +#endif diff --git a/hexl_omp/eltwise/eltwise-sub-mod-avx512.hpp b/hexl_omp/eltwise/eltwise-sub-mod-avx512.hpp new file mode 100644 index 00000000..eab9772e --- /dev/null +++ b/hexl_omp/eltwise/eltwise-sub-mod-avx512.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +void EltwiseSubModAVX512(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +void EltwiseSubModAVX512(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-sub-mod-internal.hpp b/hexl_omp/eltwise/eltwise-sub-mod-internal.hpp new file mode 100644 index 00000000..7c05dfe9 --- /dev/null +++ b/hexl_omp/eltwise/eltwise-sub-mod-internal.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace intel { +namespace hexl { + +/// @brief Subtracts two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from +/// @param[in] operand2 Vector of elements to subtract +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseSubModNative(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +/// @brief Subtracts a scalar from a vector elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. 
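/// For example, with modulus = 7, operand1[i] = 3 and operand2 = 5, the
/// result is (3 - 5) mod 7 = 5.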
+void EltwiseSubModNative(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/eltwise/eltwise-sub-mod.cpp b/hexl_omp/eltwise/eltwise-sub-mod.cpp new file mode 100644 index 00000000..7abe813a --- /dev/null +++ b/hexl_omp/eltwise/eltwise-sub-mod.cpp @@ -0,0 +1,124 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include + +#include "eltwise/eltwise-sub-mod-avx512.hpp" +#include "eltwise/eltwise-sub-mod-internal.hpp" +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void EltwiseSubModNative(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-sub value in operand1 exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(operand2, n, modulus, + "pre-sub value in operand2 exceeds bound " << modulus); + + // HEXL_LOOP_UNROLL_4 + // for (size_t i = 0; i < n; ++i) { + // if (*operand1 >= *operand2) { + // *result = *operand1 - *operand2; + // } else { + // *result = *operand1 + modulus - *operand2; + // } + + // ++operand1; + // ++operand2; + // ++result; + // } + +#pragma omp parallel for + for (size_t i = 0; i < n; ++i) { + if (operand1[i] >= operand2[i]) { + result[i] = operand1[i] - operand2[i]; + } else { + result[i] = operand1[i] + modulus - operand2[i]; + } + } +} + +void EltwiseSubModNative(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-sub value in operand1 exceeds bound " << modulus); + HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + + +#pragma omp parallel for +for (size_t i = 0; i < n; ++i) { + if (operand1[i] >= operand2) { + result[i] = operand1[i] - operand2; + } else { + result[i] = operand1[i] + modulus - operand2; + } +} + +} + +void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-sub value in operand1 exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(operand2, n, modulus, + "pre-sub value in operand2 exceeds bound " << modulus); + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseSubModAVX512(result, operand1, operand2, n, modulus); + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseSubModNative"); + 
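  // Fall back to the native path, which uses the OpenMP-parallelized
  // element-wise loop defined above.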
EltwiseSubModNative(result, operand1, operand2, n, modulus); +} + +void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-sub value in operand1 exceeds bound " << modulus); + HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseSubModAVX512(result, operand1, operand2, n, modulus); + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseSubModNative"); + EltwiseSubModNative(result, operand1, operand2, n, modulus); +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/experimental/fft-like/fft-like-native.cpp b/hexl_omp/experimental/fft-like/fft-like-native.cpp new file mode 100644 index 00000000..0c84699d --- /dev/null +++ b/hexl_omp/experimental/fft-like/fft-like-native.cpp @@ -0,0 +1,423 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/experimental/fft-like/fft-like-native.hpp" + +#include + +#include "hexl/logging/logging.hpp" +#include "util/util-internal.hpp" + +namespace intel { +namespace hexl { + +inline void ComplexFwdButterflyRadix2(std::complex* X_r, + std::complex* Y_r, + const std::complex* X_op, + const std::complex* Y_op, + const std::complex W) { + HEXL_VLOG(5, "ComplexFwdButterflyRadix2"); + HEXL_VLOG(5, "Inputs: X_op " << *X_op << ", Y_op " << *Y_op << ", W " << W); + std::complex U = *X_op; + std::complex V = *Y_op * W; + *X_r = U + V; + *Y_r = U - V; + HEXL_VLOG(5, "Output X " << *X_r << ", Y " << *Y_r); +} + +inline void ComplexInvButterflyRadix2(std::complex* X_r, + std::complex* Y_r, + const std::complex* X_op, + const std::complex* Y_op, + const std::complex W) { + HEXL_VLOG(5, "ComplexInvButterflyRadix2"); + HEXL_VLOG(5, "Inputs: X_op " << *X_op << ", Y_op " << *Y_op << ", W " << W); + std::complex U = *X_op; + *X_r = U + *Y_op; + *Y_r = (U - *Y_op) * W; + HEXL_VLOG(5, "Output X " << *X_r << ", Y " << *Y_r); +} + +inline void ScaledComplexInvButterflyRadix2(std::complex* X_r, + std::complex* Y_r, + const std::complex* X_op, + const std::complex* Y_op, + const std::complex W, + const double* scalar) { + HEXL_VLOG(5, "ScaledComplexInvButterflyRadix2"); + HEXL_VLOG(5, "Inputs: X_op " << *X_op << ", Y_op " << *Y_op << ", W " << W + << ", scalar " << *scalar); + std::complex U = *X_op; + *X_r = (U + *Y_op) * (*scalar); + *Y_r = (U - *Y_op) * W; + HEXL_VLOG(5, "Output X " << *X_r << ", Y " << *Y_r); +} + +void Forward_FFTLike_ToBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* root_of_unity_powers, const uint64_t n, + const double* scalar) { + HEXL_CHECK(IsPowerOfTwo(n), "degree " << n << " is not a power of 2"); + HEXL_CHECK(root_of_unity_powers != nullptr, + "root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + HEXL_CHECK(result != nullptr, "result == nullptr"); + + size_t gap = (n >> 1); + + // In case of out-of-place operation do first pass and convert to in-place + { + const std::complex W = root_of_unity_powers[1]; + std::complex* X_r = result; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = operand; + const std::complex* Y_op = X_op + gap; + + // First pass for out-of-order case + switch (gap) { + case 8: { + 
ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + break; + } + case 4: { + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + break; + } + case 2: { + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + break; + } + case 1: { + std::complex scaled_W = W; + if (scalar != nullptr) scaled_W = W * *scalar; + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + break; + } + default: { + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < gap; j += 8) { + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + } + } + } + gap >>= 1; + } + + // Continue with in-place operation + for (size_t m = 2; m < n; m <<= 1) { + size_t offset = 0; + switch (gap) { + case 8: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = root_of_unity_powers[m + i]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + case 4: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = root_of_unity_powers[m + i]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + case 2: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = root_of_unity_powers[m + i]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + case 1: { + if (scalar 
== nullptr) { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = root_of_unity_powers[m + i]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + } else { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = + *scalar * root_of_unity_powers[m + i]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + *X_r = (*scalar) * (*X_r); + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + } + break; + } + default: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = root_of_unity_powers[m + i]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < gap; j += 8) { + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + } + } + } + } + gap >>= 1; + } +} + +void Inverse_FFTLike_FromBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* inv_root_of_unity_powers, const uint64_t n, + const double* scalar) { + HEXL_CHECK(IsPowerOfTwo(n), "degree " << n << " is not a power of 2"); + HEXL_CHECK(inv_root_of_unity_powers != nullptr, + "inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + HEXL_CHECK(result != nullptr, "result == nullptr"); + + uint64_t n_div_2 = (n >> 1); + size_t gap = 1; + size_t root_index = 1; + + size_t stop_loop = (scalar == nullptr) ? 
0 : 1; + size_t m = n_div_2; + for (; m > stop_loop; m >>= 1) { + size_t offset = 0; + + switch (gap) { + case 1: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = inv_root_of_unity_powers[root_index]; + + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = operand + offset; + const std::complex* Y_op = X_op + gap; + ComplexInvButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + case 2: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = inv_root_of_unity_powers[root_index]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + case 4: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = inv_root_of_unity_powers[root_index]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + case 8: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = inv_root_of_unity_powers[root_index]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + default: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = inv_root_of_unity_powers[root_index]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < gap; j += 8) { + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + } + } + } + } + gap <<= 1; + } + + if (m > 0) { + const std::complex W = + *scalar * inv_root_of_unity_powers[root_index]; + std::complex* X_r = result; + std::complex* Y_r = X_r + gap; + const std::complex* X_o = X_r; + const std::complex* Y_o = Y_r; + + switch (gap) { + case 1: { + ScaledComplexInvButterflyRadix2(X_r, Y_r, X_o, Y_o, W, scalar); + break; + 
} + case 2: { + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r, Y_r, X_o, Y_o, W, scalar); + break; + } + case 4: { + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r, Y_r, X_o, Y_o, W, scalar); + break; + } + case 8: { + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r, Y_r, X_o, Y_o, W, scalar); + break; + } + default: { + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < gap; j += 8) { + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + } + } + } + } + + // When M is too short it only needs the final stage butterfly. Copying here + // in the case of out-of-place. 
+ if (result != operand && n == 2) { + std::memcpy(result, operand, n * sizeof(std::complex)); + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/experimental/fft-like/fft-like.cpp b/hexl_omp/experimental/fft-like/fft-like.cpp new file mode 100644 index 00000000..b6a7759d --- /dev/null +++ b/hexl_omp/experimental/fft-like/fft-like.cpp @@ -0,0 +1,161 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/experimental/fft-like/fft-like.hpp" + +#include "hexl/experimental/fft-like/fft-like-native.hpp" +#include "hexl/logging/logging.hpp" + +namespace intel { +namespace hexl { + +FFTLike::FFTLike(uint64_t degree, double* in_scalar, + std::shared_ptr alloc_ptr) + : m_degree(degree), + scalar(in_scalar), + m_alloc(alloc_ptr), + m_aligned_alloc(AlignedAllocator(m_alloc)), + m_complex_roots_of_unity(m_aligned_alloc) { + HEXL_CHECK(IsPowerOfTwo(degree), + "degree " << degree << " is not a power of 2"); + HEXL_CHECK(degree > 8, "degree should be bigger than 8"); + + m_degree_bits = Log2(m_degree); + ComputeComplexRootsOfUnity(); + + if (scalar != nullptr) { + scale = *scalar / static_cast(degree); + inv_scale = 1.0 / *scalar; + } +} + +inline std::complex swap_real_imag(std::complex c) { + return std::complex(c.imag(), c.real()); +} + +void FFTLike::ComputeComplexRootsOfUnity() { + AlignedVector64> roots_of_unity(m_degree, 0, + m_aligned_alloc); + AlignedVector64> roots_in_bit_reverse(m_degree, 0, + m_aligned_alloc); + AlignedVector64> inv_roots_in_bit_reverse( + m_degree, 0, m_aligned_alloc); + uint64_t roots_degree = static_cast(m_degree) << 1; // degree > 2 + + // PI value used to calculate the roots of unity + static constexpr double PI_ = 3.1415926535897932384626433832795028842; + + // Generate 1/8 of all roots first. 
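  // The remaining roots follow from symmetry on the unit circle: for the
  // second eighth, swap the real and imaginary parts of the mirrored root
  // (e^{i(pi/2 - t)} -> e^{i t}); for the second quarter, take the negated
  // conjugate of the mirrored root (-conj(e^{i(pi - t)}) = e^{i t}).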
+ size_t i = 0; + for (; i <= roots_degree / 8; i++) { + roots_of_unity[i] = + std::polar(1.0, 2 * PI_ * static_cast(i) / + static_cast(roots_degree)); + } + // Complete first 4th + for (; i <= roots_degree / 4; i++) { + roots_of_unity[i] = swap_real_imag(roots_of_unity[roots_degree / 4 - i]); + } + // Get second 4th + for (; i < roots_degree / 2; i++) { + roots_of_unity[i] = -std::conj(roots_of_unity[roots_degree / 2 - i]); + } + // Put in bit reverse and get inv roots + for (i = 1; i < m_degree; i++) { + roots_in_bit_reverse[i] = roots_of_unity[ReverseBits(i, m_degree_bits)]; + inv_roots_in_bit_reverse[i] = + std::conj(roots_of_unity[ReverseBits(i - 1, m_degree_bits) + 1]); + } + m_complex_roots_of_unity = roots_in_bit_reverse; + m_inv_complex_roots_of_unity = inv_roots_in_bit_reverse; +} + +void FFTLike::ComputeForwardFFTLike(std::complex* result, + const std::complex* operand, + const double* in_scale) { + HEXL_CHECK(result != nullptr, "result == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + + const double* out_scale = nullptr; + if (scalar != nullptr) { + out_scale = &inv_scale; + } else if (in_scale != nullptr) { + out_scale = in_scale; + } + +#ifdef HEXL_HAS_AVX512DQ + HEXL_VLOG(3, "Calling 64-bit AVX512-DQ FwdFFTLike"); + + Forward_FFTLike_ToBitReverseAVX512( + &(reinterpret_cast(result[0]))[0], + &(reinterpret_cast(operand[0]))[0], + &(reinterpret_cast(m_complex_roots_of_unity[0]))[0], + m_degree, out_scale); + return; +#else + HEXL_VLOG(3, "Calling Native FwdFFTLike"); + Forward_FFTLike_ToBitReverseRadix2( + result, operand, m_complex_roots_of_unity.data(), m_degree, out_scale); + return; +#endif +} + +void FFTLike::ComputeInverseFFTLike(std::complex* result, + const std::complex* operand, + const double* in_scale) { + HEXL_CHECK(result != nullptr, "result==nullptr"); + HEXL_CHECK(operand != nullptr, "operand==nullptr"); + + const double* out_scale = nullptr; + if (scalar != nullptr) { + out_scale = &scale; + } else if (in_scale != nullptr) { + out_scale = in_scale; + } + +#ifdef HEXL_HAS_AVX512DQ + HEXL_VLOG(3, "Calling 64-bit AVX512-DQ InvFFTLike"); + + Inverse_FFTLike_FromBitReverseAVX512( + &(reinterpret_cast(result[0]))[0], + &(reinterpret_cast(operand[0]))[0], + &(reinterpret_cast( + m_inv_complex_roots_of_unity[0]))[0], + m_degree, out_scale); + + return; +#else + HEXL_VLOG(3, "Calling Native InvFFTLike"); + Inverse_FFTLike_FromBitReverseRadix2(result, operand, + m_inv_complex_roots_of_unity.data(), + m_degree, out_scale); + return; +#endif +} + +void FFTLike::BuildFloatingPoints(std::complex* res, + const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double in_inv_scale, size_t mod_size, + size_t coeff_count) { + HEXL_UNUSED(res); + HEXL_UNUSED(plain); + HEXL_UNUSED(threshold); + HEXL_UNUSED(decryption_modulus); + HEXL_UNUSED(in_inv_scale); + HEXL_UNUSED(mod_size); + HEXL_UNUSED(coeff_count); + +#ifdef HEXL_HAS_AVX512DQ + HEXL_VLOG(3, "Calling 64-bit AVX512-DQ BuildFloatingPoints"); + + BuildFloatingPointsAVX512(&(reinterpret_cast(res[0]))[0], plain, + threshold, decryption_modulus, in_inv_scale, + mod_size, coeff_count); + return; +#endif +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/experimental/fft-like/fwd-fft-like-avx512.cpp b/hexl_omp/experimental/fft-like/fwd-fft-like-avx512.cpp new file mode 100644 index 00000000..50db1147 --- /dev/null +++ b/hexl_omp/experimental/fft-like/fwd-fft-like-avx512.cpp @@ -0,0 +1,482 @@ +// Copyright (C) 2022 Intel Corporation +// 
SPDX-License-Identifier: Apache-2.0 + +#include "hexl/experimental/fft-like/fwd-fft-like-avx512.hpp" + +#include "hexl/experimental/fft-like/fft-like-avx512-util.hpp" +#include "hexl/logging/logging.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief Final butterfly step for the Forward FFT like. +/// @param[in,out] X_real Double precision (DP) values in SIMD form representing +/// the real part of 8 complex numbers. +/// @param[in,out] X_imag DP values in SIMD form representing the +/// imaginary part of the forementioned complex numbers. +/// @param[in,out] Y_real DP values in SIMD form representing the +/// real part of 8 complex numbers. +/// @param[in,out] Y_imag DP values in SIMD form representing the +/// imaginary part of the forementioned complex numbers. +/// @param[in] W_real DP values in SIMD form representing the real part of the +/// Complex Roots of unity. +/// @param[in] W_imag DP values in SIMD form representing the imaginary part of +/// the Complex Roots of unity. +void ComplexFwdButterfly(__m512d* X_real, __m512d* X_imag, __m512d* Y_real, + __m512d* Y_imag, __m512d W_real, __m512d W_imag) { + // U = X + __m512d U_real = *X_real; + __m512d U_imag = *X_imag; + + // V = Y*W. Complex multiplication: + // (y_r + iy_b)*(w_a + iw_b) = (y_a*w_a - y_b*w_b) + i(y_a*w_b + y_b*w_a) + __m512d V_real = _mm512_mul_pd(*Y_real, W_real); + __m512d tmp = _mm512_mul_pd(*Y_imag, W_imag); + V_real = _mm512_sub_pd(V_real, tmp); + + __m512d V_imag = _mm512_mul_pd(*Y_real, W_imag); + tmp = _mm512_mul_pd(*Y_imag, W_real); + V_imag = _mm512_add_pd(V_imag, tmp); + + // X = U + V + *X_real = _mm512_add_pd(U_real, V_real); + *X_imag = _mm512_add_pd(U_imag, V_imag); + // Y = U - V + *Y_real = _mm512_sub_pd(U_real, V_real); + *Y_imag = _mm512_sub_pd(U_imag, V_imag); +} + +// Takes operand as 8 complex interleaved: This is 8 real parts followed by +// its 8 imaginary parts. +// Returns operand as 1 complex interleaved: One real part followed by its +// imaginary part. 
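// For example, for complex values z0..z7 the 8-complex-interleaved layout is
// r0 r1 ... r7 i0 i1 ... i7 in memory, whereas the 1-complex-interleaved
// layout is r0 i0 r1 i1 ... r7 i7.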
+void ComplexFwdT1(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t m, const double* scalar = nullptr) { + size_t offset = 0; + + __m512d v_scalar; + if (scalar != nullptr) { + v_scalar = _mm512_set1_pd(*scalar); + } + + // 8 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_8 + for (size_t i = 0; i < (m >> 1); i += 8) { + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + __m512d* v_out_pt = reinterpret_cast<__m512d*>(X_real); + + __m512d v_X_real; + __m512d v_X_imag; + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadFwdInterleavedT1(X_real, &v_X_real, &v_Y_real); + ComplexLoadFwdInterleavedT1(X_imag, &v_X_imag, &v_Y_imag); + + // Weights + __m512d v_W_real = _mm512_set_pd( + W_1C_intrlvd[14], W_1C_intrlvd[12], W_1C_intrlvd[10], W_1C_intrlvd[8], + W_1C_intrlvd[6], W_1C_intrlvd[4], W_1C_intrlvd[2], W_1C_intrlvd[0]); + __m512d v_W_imag = _mm512_set_pd( + W_1C_intrlvd[15], W_1C_intrlvd[13], W_1C_intrlvd[11], W_1C_intrlvd[9], + W_1C_intrlvd[7], W_1C_intrlvd[5], W_1C_intrlvd[3], W_1C_intrlvd[1]); + W_1C_intrlvd += 16; + + if (scalar != nullptr) { + v_W_real = _mm512_mul_pd(v_W_real, v_scalar); + v_W_imag = _mm512_mul_pd(v_W_imag, v_scalar); + v_X_real = _mm512_mul_pd(v_X_real, v_scalar); + v_X_imag = _mm512_mul_pd(v_X_imag, v_scalar); + } + + ComplexFwdButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + ComplexWriteFwdInterleavedT1(v_X_real, v_Y_real, v_X_imag, v_Y_imag, + v_out_pt); + + offset += 32; + } +} + +void ComplexFwdT2(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t m) { + size_t offset = 0; + + // 4 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i += 4) { + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d v_X_real; + __m512d v_X_imag; + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadFwdInterleavedT2(X_real, &v_X_real, &v_Y_real); + ComplexLoadFwdInterleavedT2(X_imag, &v_X_imag, &v_Y_imag); + + // Weights + __m512d v_W_real = _mm512_set_pd( + W_1C_intrlvd[6], W_1C_intrlvd[6], W_1C_intrlvd[4], W_1C_intrlvd[4], + W_1C_intrlvd[2], W_1C_intrlvd[2], W_1C_intrlvd[0], W_1C_intrlvd[0]); + __m512d v_W_imag = _mm512_set_pd( + W_1C_intrlvd[7], W_1C_intrlvd[7], W_1C_intrlvd[5], W_1C_intrlvd[5], + W_1C_intrlvd[3], W_1C_intrlvd[3], W_1C_intrlvd[1], W_1C_intrlvd[1]); + W_1C_intrlvd += 8; + + ComplexFwdButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_pt_real, v_X_real); + _mm512_storeu_pd(v_X_pt_imag, v_X_imag); + v_X_pt_real += 2; + v_X_pt_imag += 2; + _mm512_storeu_pd(v_X_pt_real, v_Y_real); + _mm512_storeu_pd(v_X_pt_imag, v_Y_imag); + + offset += 32; + } +} + +void ComplexFwdT4(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t m) { + size_t offset = 0; + + // 2 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i += 2) { + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d v_X_real; + __m512d v_X_imag; + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadFwdInterleavedT4(X_real, &v_X_real, &v_Y_real); + ComplexLoadFwdInterleavedT4(X_imag, &v_X_imag, &v_Y_imag); + + // Weights + // x = (11, 10, 9, 8, 3, 2, 1, 0) + // y = (15, 14, 13, 12, 7, 6, 
5, 4) + __m512d v_W_real = _mm512_set_pd( + W_1C_intrlvd[2], W_1C_intrlvd[2], W_1C_intrlvd[2], W_1C_intrlvd[2], + W_1C_intrlvd[0], W_1C_intrlvd[0], W_1C_intrlvd[0], W_1C_intrlvd[0]); + __m512d v_W_imag = _mm512_set_pd( + W_1C_intrlvd[3], W_1C_intrlvd[3], W_1C_intrlvd[3], W_1C_intrlvd[3], + W_1C_intrlvd[1], W_1C_intrlvd[1], W_1C_intrlvd[1], W_1C_intrlvd[1]); + + W_1C_intrlvd += 4; + + ComplexFwdButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_pt_real, v_X_real); + _mm512_storeu_pd(v_X_pt_imag, v_X_imag); + v_X_pt_real += 2; + v_X_pt_imag += 2; + _mm512_storeu_pd(v_X_pt_real, v_Y_real); + _mm512_storeu_pd(v_X_pt_imag, v_Y_imag); + + offset += 32; + } +} + +void ComplexFwdT8(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t gap, uint64_t m) { + size_t offset = 0; + + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i++) { + // Referencing operand + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + double* Y_real = X_real + gap; + double* Y_imag = X_imag + gap; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d* v_Y_pt_real = reinterpret_cast<__m512d*>(Y_real); + __m512d* v_Y_pt_imag = reinterpret_cast<__m512d*>(Y_imag); + + // Weights + __m512d v_W_real = _mm512_set1_pd(*W_1C_intrlvd++); + __m512d v_W_imag = _mm512_set1_pd(*W_1C_intrlvd++); + + // assume 8 | t + for (size_t j = 0; j < gap; j += 16) { + __m512d v_X_real = _mm512_loadu_pd(v_X_pt_real); + __m512d v_X_imag = _mm512_loadu_pd(v_X_pt_imag); + + __m512d v_Y_real = _mm512_loadu_pd(v_Y_pt_real); + __m512d v_Y_imag = _mm512_loadu_pd(v_Y_pt_imag); + + ComplexFwdButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_pt_real, v_X_real); + _mm512_storeu_pd(v_X_pt_imag, v_X_imag); + + _mm512_storeu_pd(v_Y_pt_real, v_Y_real); + _mm512_storeu_pd(v_Y_pt_imag, v_Y_imag); + + // Increase pointers + v_X_pt_real += 2; + v_X_pt_imag += 2; + v_Y_pt_real += 2; + v_Y_pt_imag += 2; + } + offset += (gap << 1); + } +} + +void ComplexStartFwdT8(double* result_8C_intrlvd, + const double* operand_1C_intrlvd, + const double* W_1C_intrlvd, uint64_t gap, uint64_t m) { + size_t offset = 0; + + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i++) { + // Referencing operand + const double* X_op = operand_1C_intrlvd + offset; + const double* Y_op = X_op + gap; + const __m512d* v_X_op_pt = reinterpret_cast(X_op); + const __m512d* v_Y_op_pt = reinterpret_cast(Y_op); + + // Referencing result + double* X_r_real = result_8C_intrlvd + offset; + double* X_r_imag = X_r_real + 8; + double* Y_r_real = X_r_real + gap; + double* Y_r_imag = X_r_imag + gap; + __m512d* v_X_r_pt_real = reinterpret_cast<__m512d*>(X_r_real); + __m512d* v_X_r_pt_imag = reinterpret_cast<__m512d*>(X_r_imag); + __m512d* v_Y_r_pt_real = reinterpret_cast<__m512d*>(Y_r_real); + __m512d* v_Y_r_pt_imag = reinterpret_cast<__m512d*>(Y_r_imag); + + // Weights + __m512d v_W_real = _mm512_set1_pd(*W_1C_intrlvd++); + __m512d v_W_imag = _mm512_set1_pd(*W_1C_intrlvd++); + + // assume 8 | t + for (size_t j = 0; j < gap; j += 16) { + __m512d v_X_real; + __m512d v_X_imag; + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadFwdInterleavedT8(v_X_op_pt, v_Y_op_pt, &v_X_real, &v_X_imag, + &v_Y_real, &v_Y_imag); + + ComplexFwdButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_r_pt_real, v_X_real); + _mm512_storeu_pd(v_X_r_pt_imag, 
v_X_imag); + + _mm512_storeu_pd(v_Y_r_pt_real, v_Y_real); + _mm512_storeu_pd(v_Y_r_pt_imag, v_Y_imag); + + // Increase operand & result pointers + v_X_op_pt += 2; + v_Y_op_pt += 2; + v_X_r_pt_real += 2; + v_X_r_pt_imag += 2; + v_Y_r_pt_real += 2; + v_Y_r_pt_imag += 2; + } + offset += (gap << 1); + } +} + +void Forward_FFTLike_ToBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* root_of_unity_powers_cmplx_intrlvd, const uint64_t n, + const double* scale, uint64_t recursion_depth, uint64_t recursion_half) { + HEXL_CHECK(IsPowerOfTwo(n), "n " << n << " is not a power of 2"); + HEXL_CHECK(n >= 16, + "Don't support small transforms. Need n >= 16, got n = " << n); + HEXL_VLOG(5, "root_of_unity_powers_cmplx_intrlvd " + << std::vector>( + root_of_unity_powers_cmplx_intrlvd, + root_of_unity_powers_cmplx_intrlvd + 2 * n)); + HEXL_VLOG(5, "operand_cmplx_intrlvd " << std::vector>( + operand_cmplx_intrlvd, operand_cmplx_intrlvd + 2 * n)); + + static const size_t base_fft_like_size = 1024; + + if (n <= base_fft_like_size) { // Perform breadth-first FFT like + size_t gap = n; // (2*n >> 1) Interleaved complex numbers + size_t m = 2; // require twice the size + size_t W_idx = (m << recursion_depth) + (recursion_half * m); + + // First pass in case of out of place + if (recursion_depth == 0 && gap >= 16) { + const double* W_cmplx_intrlvd = + &root_of_unity_powers_cmplx_intrlvd[W_idx]; + ComplexStartFwdT8(result_cmplx_intrlvd, operand_cmplx_intrlvd, + W_cmplx_intrlvd, gap, m); + m <<= 1; + W_idx <<= 1; + gap >>= 1; + } + + for (; gap >= 16; gap >>= 1) { + const double* W_cmplx_intrlvd = + &root_of_unity_powers_cmplx_intrlvd[W_idx]; + ComplexFwdT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, m); + m <<= 1; + W_idx <<= 1; + } + + { + // T4 + const double* W_cmplx_intrlvd = + &root_of_unity_powers_cmplx_intrlvd[W_idx]; + ComplexFwdT4(result_cmplx_intrlvd, W_cmplx_intrlvd, m); + m <<= 1; + W_idx <<= 1; + + // T2 + W_cmplx_intrlvd = &root_of_unity_powers_cmplx_intrlvd[W_idx]; + ComplexFwdT2(result_cmplx_intrlvd, W_cmplx_intrlvd, m); + m <<= 1; + W_idx <<= 1; + + // T1 + W_cmplx_intrlvd = &root_of_unity_powers_cmplx_intrlvd[W_idx]; + ComplexFwdT1(result_cmplx_intrlvd, W_cmplx_intrlvd, m, scale); + m <<= 1; + W_idx <<= 1; + } + } else { + // Perform depth-first FFT like via recursive call + size_t gap = n; + size_t W_idx = (2ULL << recursion_depth) + (recursion_half << 1); + const double* W_cmplx_intrlvd = &root_of_unity_powers_cmplx_intrlvd[W_idx]; + + if (recursion_depth == 0) { + ComplexStartFwdT8(result_cmplx_intrlvd, operand_cmplx_intrlvd, + W_cmplx_intrlvd, gap, 2); + } else { + ComplexFwdT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, 2); + } + + Forward_FFTLike_ToBitReverseAVX512( + result_cmplx_intrlvd, result_cmplx_intrlvd, + root_of_unity_powers_cmplx_intrlvd, n / 2, scale, recursion_depth + 1, + recursion_half * 2); + + Forward_FFTLike_ToBitReverseAVX512( + &result_cmplx_intrlvd[n], &result_cmplx_intrlvd[n], + root_of_unity_powers_cmplx_intrlvd, n / 2, scale, recursion_depth + 1, + recursion_half * 2 + 1); + } + if (recursion_depth == 0) { + HEXL_VLOG(5, + "AVX512 returning FWD FFT like result " + << std::vector>( + result_cmplx_intrlvd, result_cmplx_intrlvd + 2 * n)); + } +} + +void BuildFloatingPointsAVX512(double* res_cmplx_intrlvd, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, const size_t mod_size, + const size_t coeff_count) { + const __m512i v_perm = _mm512_set_epi64(7, 3, 6, 
2, 5, 1, 4, 0); + __m512d v_res_imag = _mm512_setzero_pd(); + __m512d* v_res_pt = reinterpret_cast<__m512d*>(res_cmplx_intrlvd); + double two_pow_64 = std::pow(2.0, 64); + + for (size_t i = 0; i < coeff_count; i += 8) { + __mmask8 zeros = 0xff; + __mmask8 cond_lt_thr = 0; + + for (int32_t j = static_cast(mod_size) - 1; zeros && (j >= 0); + j--) { + const uint64_t* base = plain + j; + __m512i v_thrld = _mm512_set1_epi64(*(threshold + j)); + __m512i v_plain = _mm512_set_epi64( + *(base + (i + 7) * mod_size), *(base + (i + 6) * mod_size), + *(base + (i + 5) * mod_size), *(base + (i + 4) * mod_size), + *(base + (i + 3) * mod_size), *(base + (i + 2) * mod_size), + *(base + (i + 1) * mod_size), *(base + (i + 0) * mod_size)); + + cond_lt_thr = static_cast(cond_lt_thr) | + static_cast( + _mm512_mask_cmplt_epu64_mask(zeros, v_plain, v_thrld)); + zeros = _mm512_mask_cmpeq_epu64_mask(zeros, v_plain, v_thrld); + } + + __mmask8 cond_ge_thr = static_cast(~cond_lt_thr); + double scaled_two_pow_64 = inv_scale; + __m512d v_zeros = _mm512_setzero_pd(); + __m512d v_res_real = _mm512_setzero_pd(); + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < mod_size; j++, scaled_two_pow_64 *= two_pow_64) { + const uint64_t* base = plain + j; + __m512d v_scaled_p64 = _mm512_set1_pd(scaled_two_pow_64); + __m512i v_dec_moduli = _mm512_set1_epi64(*(decryption_modulus + j)); + __m512i v_curr_coeff = _mm512_set_epi64( + *(base + (i + 7) * mod_size), *(base + (i + 6) * mod_size), + *(base + (i + 5) * mod_size), *(base + (i + 4) * mod_size), + *(base + (i + 3) * mod_size), *(base + (i + 2) * mod_size), + *(base + (i + 1) * mod_size), *(base + (i + 0) * mod_size)); + + __mmask8 cond_gt_dec_mod = + _mm512_mask_cmpgt_epu64_mask(cond_ge_thr, v_curr_coeff, v_dec_moduli); + __mmask8 cond_le_dec_mod = cond_gt_dec_mod ^ cond_ge_thr; + + __m512i v_diff = _mm512_mask_sub_epi64(v_curr_coeff, cond_gt_dec_mod, + v_curr_coeff, v_dec_moduli); + v_diff = _mm512_mask_sub_epi64(v_diff, cond_le_dec_mod, v_dec_moduli, + v_curr_coeff); + + // __m512d v_scaled_diff = _mm512_castsi512_pd(v_diff); does not work + uint64_t tmp_v_ui[8]; + __m512i* tmp_v_ui_pt = reinterpret_cast<__m512i*>(tmp_v_ui); + double tmp_v_pd[8]; + _mm512_storeu_si512(tmp_v_ui_pt, v_diff); + HEXL_LOOP_UNROLL_8 + for (size_t t = 0; t < 8; t++) { + tmp_v_pd[t] = static_cast(tmp_v_ui[t]); + } + + __m512d v_casted_diff = _mm512_loadu_pd(tmp_v_pd); + // This mask avoids multiplying by inf when diff is already zero + __mmask8 cond_no_zero = _mm512_cmpneq_pd_mask(v_casted_diff, v_zeros); + __m512d v_scaled_diff = _mm512_mask_mul_pd(v_casted_diff, cond_no_zero, + v_casted_diff, v_scaled_p64); + v_res_real = _mm512_mask_add_pd(v_res_real, cond_gt_dec_mod | cond_lt_thr, + v_res_real, v_scaled_diff); + v_res_real = _mm512_mask_sub_pd(v_res_real, cond_le_dec_mod, v_res_real, + v_scaled_diff); + } + + // Make res 1 complex interleaved + v_res_real = _mm512_permutexvar_pd(v_perm, v_res_real); + __m512d v_res1 = _mm512_shuffle_pd(v_res_real, v_res_imag, 0x00); + __m512d v_res2 = _mm512_shuffle_pd(v_res_real, v_res_imag, 0xff); + _mm512_storeu_pd(v_res_pt++, v_res1); + _mm512_storeu_pd(v_res_pt++, v_res2); + } +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/experimental/fft-like/inv-fft-like-avx512.cpp b/hexl_omp/experimental/fft-like/inv-fft-like-avx512.cpp new file mode 100644 index 00000000..feb353a2 --- /dev/null +++ b/hexl_omp/experimental/fft-like/inv-fft-like-avx512.cpp @@ -0,0 +1,411 @@ +// Copyright (C) 2022 Intel Corporation +// 
SPDX-License-Identifier: Apache-2.0 +#include "hexl/experimental/fft-like/inv-fft-like-avx512.hpp" + +#include "hexl/experimental/fft-like/fft-like-avx512-util.hpp" +#include "hexl/logging/logging.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief Final butterfly step for the Inverse FFT like. +/// @param[in,out] X_real Double precision (DP) values in SIMD form representing +/// the real part of 8 complex numbers. +/// @param[in,out] X_imag DP values in SIMD form representing the +/// imaginary part of the forementioned complex numbers. +/// @param[in,out] Y_real DP values in SIMD form representing the +/// real part of 8 complex numbers. +/// @param[in,out] Y_imag DP values in SIMD form representing the +/// imaginary part of the forementioned complex numbers. +/// @param[in] W_real DP values in SIMD form representing the real part of the +/// Inverse Complex Roots of unity. +/// @param[in] W_imag DP values in SIMD form representing the imaginary part of +/// the Inverse Complex Roots of unity. +void ComplexInvButterfly(__m512d* X_real, __m512d* X_imag, __m512d* Y_real, + __m512d* Y_imag, __m512d W_real, __m512d W_imag, + const double* scalar = nullptr) { + // U = X, + __m512d U_real = *X_real; + __m512d U_imag = *X_imag; + + // X = U + Y + *X_real = _mm512_add_pd(U_real, *Y_real); + *X_imag = _mm512_add_pd(U_imag, *Y_imag); + + if (scalar != nullptr) { + __m512d v_scalar = _mm512_set1_pd(*scalar); + *X_real = _mm512_mul_pd(*X_real, v_scalar); + *X_imag = _mm512_mul_pd(*X_imag, v_scalar); + } + + // V = U - Y + __m512d V_real = _mm512_sub_pd(U_real, *Y_real); + __m512d V_imag = _mm512_sub_pd(U_imag, *Y_imag); + + // Y = V*W. Complex multiplication: + // (v_r + iv_b)*(w_a + iw_b) = (v_a*w_a - v_b*w_b) + i(v_a*w_b + v_b*w_a) + *Y_real = _mm512_mul_pd(V_real, W_real); + __m512d tmp = _mm512_mul_pd(V_imag, W_imag); + *Y_real = _mm512_sub_pd(*Y_real, tmp); + + *Y_imag = _mm512_mul_pd(V_real, W_imag); + tmp = _mm512_mul_pd(V_imag, W_real); + *Y_imag = _mm512_add_pd(*Y_imag, tmp); +} + +void ComplexInvT1(double* result_8C_intrlvd, const double* operand_1C_intrlvd, + const double* W_1C_intrlvd, uint64_t m) { + size_t offset = 0; + + // 8 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_8 + for (size_t i = 0; i < (m >> 1); i += 8) { + // Referencing operand + const double* X_op_real = operand_1C_intrlvd + offset; + + // Referencing result + double* X_r_real = result_8C_intrlvd + offset; + double* X_r_imag = X_r_real + 8; + __m512d* v_X_r_pt_real = reinterpret_cast<__m512d*>(X_r_real); + __m512d* v_X_r_pt_imag = reinterpret_cast<__m512d*>(X_r_imag); + + __m512d v_X_real; + __m512d v_X_imag; + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadInvInterleavedT1(X_op_real, &v_X_real, &v_X_imag, &v_Y_real, + &v_Y_imag); + + // Weights + // x = (14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r); + // y = (15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r); + __m512d v_W_real = _mm512_set_pd( + W_1C_intrlvd[14], W_1C_intrlvd[10], W_1C_intrlvd[6], W_1C_intrlvd[2], + W_1C_intrlvd[12], W_1C_intrlvd[8], W_1C_intrlvd[4], W_1C_intrlvd[0]); + __m512d v_W_imag = _mm512_set_pd( + W_1C_intrlvd[15], W_1C_intrlvd[11], W_1C_intrlvd[7], W_1C_intrlvd[3], + W_1C_intrlvd[13], W_1C_intrlvd[9], W_1C_intrlvd[5], W_1C_intrlvd[1]); + W_1C_intrlvd += 16; + + ComplexInvButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_r_pt_real, v_X_real); + _mm512_storeu_pd(v_X_r_pt_imag, v_X_imag); + v_X_r_pt_real += 2; + v_X_r_pt_imag += 2; + _mm512_storeu_pd(v_X_r_pt_real, v_Y_real); + 
_mm512_storeu_pd(v_X_r_pt_imag, v_Y_imag); + + offset += 32; + } +} + +void ComplexInvT2(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t m) { + size_t offset = 0; + + // 4 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i += 4) { + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d v_X_real; + __m512d v_X_imag; + + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadInvInterleavedT2(X_real, &v_X_real, &v_Y_real); + ComplexLoadInvInterleavedT2(X_imag, &v_X_imag, &v_Y_imag); + + // Weights + // x = (13, 9, 5, 1, 12, 8, 4, 0) + // y = (15, 11, 7, 3, 14, 10, 6, 2) + __m512d v_W_real = _mm512_set_pd( + W_1C_intrlvd[6], W_1C_intrlvd[4], W_1C_intrlvd[2], W_1C_intrlvd[0], + W_1C_intrlvd[6], W_1C_intrlvd[4], W_1C_intrlvd[2], W_1C_intrlvd[0]); + __m512d v_W_imag = _mm512_set_pd( + W_1C_intrlvd[7], W_1C_intrlvd[5], W_1C_intrlvd[3], W_1C_intrlvd[1], + W_1C_intrlvd[7], W_1C_intrlvd[5], W_1C_intrlvd[3], W_1C_intrlvd[1]); + W_1C_intrlvd += 8; + + ComplexInvButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_pt_real, v_X_real); + _mm512_storeu_pd(v_X_pt_imag, v_X_imag); + v_X_pt_real += 2; + v_X_pt_imag += 2; + _mm512_storeu_pd(v_X_pt_real, v_Y_real); + _mm512_storeu_pd(v_X_pt_imag, v_Y_imag); + + offset += 32; + } +} + +void ComplexInvT4(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t m) { + size_t offset = 0; + + // 2 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i += 2) { + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d v_X_real; + __m512d v_X_imag; + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadInvInterleavedT4(X_real, &v_X_real, &v_Y_real); + ComplexLoadInvInterleavedT4(X_imag, &v_X_imag, &v_Y_imag); + + // Weights + // x = (11, 9, 3, 1, 10, 8, 2, 0) + // y = (15, 13, 7, 5, 14, 12, 6, 4) + __m512d v_W_real = _mm512_set_pd( + W_1C_intrlvd[2], W_1C_intrlvd[2], W_1C_intrlvd[0], W_1C_intrlvd[0], + W_1C_intrlvd[2], W_1C_intrlvd[2], W_1C_intrlvd[0], W_1C_intrlvd[0]); + __m512d v_W_imag = _mm512_set_pd( + W_1C_intrlvd[3], W_1C_intrlvd[3], W_1C_intrlvd[1], W_1C_intrlvd[1], + W_1C_intrlvd[3], W_1C_intrlvd[3], W_1C_intrlvd[1], W_1C_intrlvd[1]); + + W_1C_intrlvd += 4; + + ComplexInvButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + ComplexWriteInvInterleavedT4(v_X_real, v_Y_real, v_X_pt_real); + ComplexWriteInvInterleavedT4(v_X_imag, v_Y_imag, v_X_pt_imag); + + offset += 32; + } +} + +void ComplexInvT8(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t gap, uint64_t m) { + size_t offset = 0; + + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i++) { + // Referencing operand + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + double* Y_real = X_real + gap; + double* Y_imag = X_imag + gap; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d* v_Y_pt_real = reinterpret_cast<__m512d*>(Y_real); + __m512d* v_Y_pt_imag = reinterpret_cast<__m512d*>(Y_imag); + + // Weights + __m512d v_W_real = _mm512_set1_pd(*W_1C_intrlvd++); + __m512d v_W_imag = 
_mm512_set1_pd(*W_1C_intrlvd++); + + // assume 8 | t + for (size_t j = 0; j < gap; j += 16) { + __m512d v_X_real = _mm512_loadu_pd(v_X_pt_real); + __m512d v_X_imag = _mm512_loadu_pd(v_X_pt_imag); + + __m512d v_Y_real = _mm512_loadu_pd(v_Y_pt_real); + __m512d v_Y_imag = _mm512_loadu_pd(v_Y_pt_imag); + + ComplexInvButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_pt_real, v_X_real); + _mm512_storeu_pd(v_X_pt_imag, v_X_imag); + + _mm512_storeu_pd(v_Y_pt_real, v_Y_real); + _mm512_storeu_pd(v_Y_pt_imag, v_Y_imag); + + // Increase operand & result pointers + v_X_pt_real += 2; + v_X_pt_imag += 2; + v_Y_pt_real += 2; + v_Y_pt_imag += 2; + } + offset += (gap << 1); + } +} + +// Takes operand as 8 complex interleaved: This is 8 real parts followed by +// its 8 imaginary parts. +// Returns operand as 1 complex interleaved: One real part followed by its +// imaginary part. +void ComplexFinalInvT8(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t gap, uint64_t m, + const double* scalar = nullptr) { + size_t offset = 0; + + __m512d v_scalar; + if (scalar != nullptr) { + v_scalar = _mm512_set1_pd(*scalar); + } + + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i++, offset += (gap << 1)) { + // Referencing operand + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + double* Y_real = X_real + gap; + double* Y_imag = X_imag + gap; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d* v_Y_pt_real = reinterpret_cast<__m512d*>(Y_real); + __m512d* v_Y_pt_imag = reinterpret_cast<__m512d*>(Y_imag); + + // Weights + __m512d v_W_real = _mm512_set1_pd(*W_1C_intrlvd++); + __m512d v_W_imag = _mm512_set1_pd(*W_1C_intrlvd++); + + if (scalar != nullptr) { + v_W_real = _mm512_mul_pd(v_W_real, v_scalar); + v_W_imag = _mm512_mul_pd(v_W_imag, v_scalar); + } + + // assume 8 | t + for (size_t j = 0; j < gap; j += 16) { + __m512d v_X_real = _mm512_loadu_pd(v_X_pt_real); + __m512d v_X_imag = _mm512_loadu_pd(v_X_pt_imag); + __m512d v_Y_real = _mm512_loadu_pd(v_Y_pt_real); + __m512d v_Y_imag = _mm512_loadu_pd(v_Y_pt_imag); + + ComplexInvButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag, scalar); + + ComplexWriteInvInterleavedT8(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, + v_X_pt_real, v_Y_pt_real); + + // Increase operand & result pointers + v_X_pt_real += 2; + v_X_pt_imag += 2; + v_Y_pt_real += 2; + v_Y_pt_imag += 2; + } + } +} + +void Inverse_FFTLike_FromBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* inv_root_of_unity_cmplx_intrlvd, const uint64_t n, + const double* scale, uint64_t recursion_depth, uint64_t recursion_half) { + HEXL_CHECK(IsPowerOfTwo(n), "n " << n << " is not a power of 2"); + HEXL_CHECK(n >= 16, + "Don't support small transforms. 
Need n >= 16, got n = " << n); + HEXL_VLOG(5, "inv_root_of_unity_cmplx_intrlvd " + << std::vector>( + inv_root_of_unity_cmplx_intrlvd, + inv_root_of_unity_cmplx_intrlvd + 2 * n)); + HEXL_VLOG(5, "operand_cmplx_intrlvd " << std::vector>( + operand_cmplx_intrlvd, operand_cmplx_intrlvd + 2 * n)); + size_t gap = 2; // Interleaved complex values requires twice the size + size_t m = n; // (2*n >> 1); + size_t W_idx = 2 + m * recursion_half; // 2*1 + + static const size_t base_fft_like_size = 1024; + + if (n <= base_fft_like_size) { // Perform breadth-first InvFFT like + // T1 + const double* W_cmplx_intrlvd = &inv_root_of_unity_cmplx_intrlvd[W_idx]; + ComplexInvT1(result_cmplx_intrlvd, operand_cmplx_intrlvd, W_cmplx_intrlvd, + m); + gap <<= 1; + m >>= 1; + uint64_t W_idx_delta = + m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_idx_delta; + + // T2 + W_cmplx_intrlvd = &inv_root_of_unity_cmplx_intrlvd[W_idx]; + ComplexInvT2(result_cmplx_intrlvd, W_cmplx_intrlvd, m); + gap <<= 1; + m >>= 1; + W_idx_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_idx_delta; + + // T4 + W_cmplx_intrlvd = &inv_root_of_unity_cmplx_intrlvd[W_idx]; + ComplexInvT4(result_cmplx_intrlvd, W_cmplx_intrlvd, m); + gap <<= 1; + m >>= 1; + W_idx_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_idx_delta; + + while (m > 2) { + W_cmplx_intrlvd = &inv_root_of_unity_cmplx_intrlvd[W_idx]; + ComplexInvT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, m); + gap <<= 1; + m >>= 1; + W_idx_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_idx_delta; + } + + W_cmplx_intrlvd = &inv_root_of_unity_cmplx_intrlvd[W_idx]; + if (recursion_depth == 0) { + ComplexFinalInvT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, m, scale); + HEXL_VLOG(5, + "AVX512 returning INV FFT like result " + << std::vector>( + result_cmplx_intrlvd, result_cmplx_intrlvd + 2 * n)); + } else { + ComplexInvT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, m); + } + gap <<= 1; + m >>= 1; + W_idx_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_idx_delta; + } else { + Inverse_FFTLike_FromBitReverseAVX512( + result_cmplx_intrlvd, operand_cmplx_intrlvd, + inv_root_of_unity_cmplx_intrlvd, n / 2, scale, recursion_depth + 1, + 2 * recursion_half); + Inverse_FFTLike_FromBitReverseAVX512( + &result_cmplx_intrlvd[n], &operand_cmplx_intrlvd[n], + inv_root_of_unity_cmplx_intrlvd, n / 2, scale, recursion_depth + 1, + 2 * recursion_half + 1); + uint64_t W_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); + for (; m > 2; m >>= 1) { + gap <<= 1; + W_delta >>= 1; + W_idx += W_delta; + } + const double* W_cmplx_intrlvd = &inv_root_of_unity_cmplx_intrlvd[W_idx]; + if (recursion_depth == 0) { + ComplexFinalInvT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, m, scale); + HEXL_VLOG(5, + "AVX512 returning INV FFT like result " + << std::vector>( + result_cmplx_intrlvd, result_cmplx_intrlvd + 2 * n)); + } else { + ComplexInvT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, m); + } + gap <<= 1; + m >>= 1; + W_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_delta; + } +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/experimental/misc/lr-mat-vec-mult.cpp b/hexl_omp/experimental/misc/lr-mat-vec-mult.cpp new file mode 100644 index 00000000..729ae160 --- /dev/null +++ b/hexl_omp/experimental/misc/lr-mat-vec-mult.cpp @@ -0,0 +1,149 @@ +// Copyright (C) 2020 Intel Corporation +// 
SPDX-License-Identifier: Apache-2.0 + +#include "hexl/experimental/misc/lr-mat-vec-mult.hpp" + +#include + +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +// operand1: num_weights x 2 x n x num_moduli +// operand2: num_weights x 2 x n x num_moduli +// +// results: num_weights x 3 x n x num_moduli +// [num_weights x {x[0].*y[0], x[0].*y[1]+x[1].*y[0], x[1].*y[1]} x num_moduli]. +// TODO(@fdiasmor): Ideally, the size of results can be optimized to [3 x n x +// num_moduli]. +void LinRegMatrixVectorMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli, + uint64_t num_weights) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(moduli != nullptr, "Require moduli != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(num_weights != 0, "Require n != 0"); + + // pointer increment to switch to a next polynomial + size_t poly_size = n * num_moduli; + + // ciphertext increment to switch to the next ciphertext + size_t cipher_size = 2 * poly_size; + + // ciphertext output increment to switch to the next output + size_t output_size = 3 * poly_size; + + AlignedVector64 temp(n, 0); + + for (size_t r = 0; r < num_weights; r++) { + size_t next_output = r * output_size; + size_t next_poly_pair = r * cipher_size; + uint64_t* cipher2 = result + next_output; + const uint64_t* cipher0 = operand1 + next_poly_pair; + const uint64_t* cipher1 = operand2 + next_poly_pair; + + for (size_t i = 0; i < num_moduli; i++) { + size_t i_times_n = i * n; + size_t poly0_offset = i_times_n; + size_t poly1_offset = poly0_offset + poly_size; + size_t poly2_offset = poly0_offset + 2 * poly_size; + + // Output ciphertext has 3 polynomials, where x, y are the input + // ciphertexts: (x[0] * y[0], x[0] * y[1] + x[1] * y[0], x[1] * y[1]) + + // Compute third output polynomial + // Output written directly to result rather than temporary buffer + // result[2] = x[1] * y[1] + intel::hexl::EltwiseMultMod(cipher2 + poly2_offset, + cipher0 + poly1_offset, + cipher1 + poly1_offset, n, moduli[i], 1); + + // Compute second output polynomial + // result[1] = x[1] * y[0] + intel::hexl::EltwiseMultMod(cipher2 + poly1_offset, + cipher0 + poly1_offset, + cipher1 + poly0_offset, n, moduli[i], 1); + + // result[1] = x[0] * y[1] + intel::hexl::EltwiseMultMod(temp.data(), cipher0 + poly0_offset, + cipher1 + poly1_offset, n, moduli[i], 1); + // result[1] += temp_poly + intel::hexl::EltwiseAddMod(cipher2 + poly1_offset, cipher2 + poly1_offset, + temp.data(), n, moduli[i]); + + // Compute first output polynomial + // result[0] = x[0] * y[0] + intel::hexl::EltwiseMultMod(cipher2 + poly0_offset, + cipher0 + poly0_offset, + cipher1 + poly0_offset, n, moduli[i], 1); + } + } + + const bool USE_ADDER_TREE = true; + if (USE_ADDER_TREE) { + // Accumulate with the adder-tree algorithm in O(logn) + for (size_t dist = 1; dist < num_weights; dist += dist) { + size_t step = dist * 2; + size_t neighbor_cipher_incr = dist * output_size; + // This loop can leverage parallelism using #pragma unroll + for (size_t s = 0; s < num_weights; s += step) { + size_t next_cipher_pair_incr = s 
* output_size; + uint64_t* left_cipher = result + next_cipher_pair_incr; + uint64_t* right_cipher = left_cipher + neighbor_cipher_incr; + + // This loop can leverage parallelism using #pragma unroll + for (size_t i = 0; i < num_moduli; i++) { + size_t i_times_n = i * n; + size_t poly0_offset = i_times_n; + size_t poly1_offset = poly0_offset + poly_size; + size_t poly2_offset = poly0_offset + 2 * poly_size; + + // All EltwiseAddMod below can run in parallel + intel::hexl::EltwiseAddMod(left_cipher + poly0_offset, + right_cipher + poly0_offset, + left_cipher + poly0_offset, n, moduli[i]); + intel::hexl::EltwiseAddMod(left_cipher + poly1_offset, + right_cipher + poly1_offset, + left_cipher + poly1_offset, n, moduli[i]); + intel::hexl::EltwiseAddMod(left_cipher + poly2_offset, + right_cipher + poly2_offset, + left_cipher + poly2_offset, n, moduli[i]); + } + } + } + } else { + // Accumulate all rows in sequence + uint64_t* acc = result; + for (size_t r = 1; r < num_weights; r++) { + size_t next_cipher = r * output_size; + acc += next_cipher; + for (size_t i = 0; i < num_moduli; i++) { + size_t i_times_n = i * n; + size_t poly0_offset = i_times_n; + size_t poly1_offset = poly0_offset + poly_size; + size_t poly2_offset = poly0_offset + 2 * poly_size; + + // All EltwiseAddMod below can run in parallel + + intel::hexl::EltwiseAddMod(result + poly0_offset, result + poly0_offset, + acc + poly0_offset, n, moduli[i]); + intel::hexl::EltwiseAddMod(result + poly1_offset, result + poly1_offset, + acc + poly1_offset, n, moduli[i]); + intel::hexl::EltwiseAddMod(result + poly2_offset, result + poly2_offset, + acc + poly2_offset, n, moduli[i]); + } + } + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/experimental/seal/dyadic-multiply-internal.cpp b/hexl_omp/experimental/seal/dyadic-multiply-internal.cpp new file mode 100644 index 00000000..e321ec4a --- /dev/null +++ b/hexl_omp/experimental/seal/dyadic-multiply-internal.cpp @@ -0,0 +1,78 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/experimental/seal/dyadic-multiply-internal.hpp" + +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { +namespace internal { + +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(moduli != nullptr, "Require moduli != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + + // pointer increment to switch to a next polynomial + size_t poly_size = n * num_moduli; + + // Output ciphertext has 3 polynomials, where x, y are the input + // ciphertexts: (x[0] * y[0], x[0] * y[1] + x[1] * y[0], x[1] * y[1]) + + // TODO(fboemer): Determine based on cpu cache size + size_t tile_size = std::min(n, uint64_t(512)); + size_t num_tiles = n / tile_size; + + AlignedVector64 temp(tile_size, 0); + + // Modulus by modulus + for (size_t i = 0; i < num_moduli; i++) { + // Split by tiles for better caching + size_t i_times_n = i * n; + for (size_t tile = 0; tile < num_tiles; ++tile) { + size_t poly0_offset = i_times_n + tile_size * tile; + 
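+      // Polynomials within a ciphertext sit poly_size (= n * num_moduli)
+      // elements apart; poly1/poly2 below index the second and third
+      // output polynomials for this modulus and tile.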
size_t poly1_offset = poly0_offset + poly_size; + size_t poly2_offset = poly0_offset + 2 * poly_size; + + // Compute third output polynomial + // Output written directly to result rather than temporary buffer + // result[2] = x[1] * y[1] + intel::hexl::EltwiseMultMod( + &result[poly2_offset], operand1 + poly1_offset, + operand2 + poly1_offset, tile_size, moduli[i], 1); + + // Compute second output polynomial + // result[1] = x[1] * y[0] + intel::hexl::EltwiseMultMod(temp.data(), operand1 + poly1_offset, + operand2 + poly0_offset, tile_size, moduli[i], + 1); + // result[1] = x[0] * y[1] + intel::hexl::EltwiseMultMod( + &result[poly1_offset], operand1 + poly0_offset, + operand2 + poly1_offset, tile_size, moduli[i], 1); + // result[1] += temp_poly + intel::hexl::EltwiseAddMod(&result[poly1_offset], temp.data(), + &result[poly1_offset], tile_size, moduli[i]); + + // Compute first output polynomial + // result[0] = x[0] * y[0] + intel::hexl::EltwiseMultMod( + &result[poly0_offset], operand1 + poly0_offset, + operand2 + poly0_offset, tile_size, moduli[i], 1); + } + } +} + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/experimental/seal/dyadic-multiply.cpp b/hexl_omp/experimental/seal/dyadic-multiply.cpp new file mode 100644 index 00000000..e3306530 --- /dev/null +++ b/hexl_omp/experimental/seal/dyadic-multiply.cpp @@ -0,0 +1,22 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#ifndef HEXL_FPGA_COMPATIBLE_DYADIC_MULTIPLY + +#include "hexl/experimental/seal/dyadic-multiply.hpp" + +#include "hexl/experimental/seal/dyadic-multiply-internal.hpp" + +namespace intel { +namespace hexl { + +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli) { + intel::hexl::internal::DyadicMultiply(result, operand1, operand2, n, moduli, + num_moduli); +} + +} // namespace hexl +} // namespace intel +#endif diff --git a/hexl_omp/experimental/seal/key-switch-internal.cpp b/hexl_omp/experimental/seal/key-switch-internal.cpp new file mode 100644 index 00000000..15edb9a8 --- /dev/null +++ b/hexl_omp/experimental/seal/key-switch-internal.cpp @@ -0,0 +1,205 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/experimental/seal/key-switch-internal.hpp" + +#include +#include +#include + +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/eltwise/eltwise-fma-mod.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/experimental/seal/ntt-cache.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { +namespace internal { + +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr) { + if (root_of_unity_powers_ptr != nullptr) { + throw std::invalid_argument( + "Parameter root_of_unity_powers_ptr is not supported yet."); + } + + uint64_t coeff_count = n; + + // Create a copy of target_iter + std::vector t_target( + t_target_iter_ptr, + t_target_iter_ptr + (coeff_count * 
decomp_modulus_size)); + uint64_t* t_target_ptr = t_target.data(); + + // Simplified implementation, where we assume no modular reduction is required + // for intermediate additions + std::vector t_ntt(coeff_count, 0); + uint64_t* t_ntt_ptr = t_ntt.data(); + + // In CKKS t_target is in NTT form; switch + // back to normal form + for (size_t j = 0; j < decomp_modulus_size; ++j) { + GetNTT(n, moduli[j]) + .ComputeInverse(&t_target_ptr[j * coeff_count], + &t_target_ptr[j * coeff_count], 2, 1); + } + + std::vector t_poly_prod( + key_component_count * coeff_count * rns_modulus_size, 0); + + for (size_t i = 0; i < rns_modulus_size; ++i) { + size_t key_index = (i == decomp_modulus_size ? key_modulus_size - 1 : i); + + // Allocate memory for a lazy accumulator (128-bit coefficients) + std::vector t_poly_lazy(key_component_count * coeff_count * 2, 0); + uint64_t* t_poly_lazy_ptr = &t_poly_lazy[0]; + uint64_t* accumulator_ptr = &t_poly_lazy[0]; + + for (size_t j = 0; j < decomp_modulus_size; ++j) { + const uint64_t* t_operand; + // assume scheme == scheme_type::ckks + if (i == j) { + t_operand = &t_target_iter_ptr[j * coeff_count]; + } else { + // Perform RNS-NTT conversion + // No need to perform RNS conversion (modular reduction) + if (moduli[j] <= moduli[key_index]) { + for (size_t l = 0; l < coeff_count; ++l) { + t_ntt_ptr[l] = t_target_ptr[j * coeff_count + l]; + } + } else { + // Perform RNS conversion (modular reduction) + intel::hexl::EltwiseReduceMod( + t_ntt_ptr, &t_target_ptr[j * coeff_count], coeff_count, + moduli[key_index], moduli[key_index], 1); + } + + // NTT conversion lazy outputs in [0, 4q) + GetNTT(n, moduli[key_index]).ComputeForward(t_ntt_ptr, t_ntt_ptr, 4, 4); + t_operand = t_ntt_ptr; + } + + // Multiply with keys and modular accumulate products in a lazy fashion + for (size_t k = 0; k < key_component_count; ++k) { + // No reduction used; assume intermediate results don't overflow + for (size_t l = 0; l < coeff_count; ++l) { + uint64_t t_poly_idx = 2 * (k * coeff_count + l); + + uint64_t mult_op2_idx = + coeff_count * key_index + k * key_modulus_size * coeff_count + l; + + uint128_t prod = + MultiplyUInt64(t_operand[l], k_switch_keys[j][mult_op2_idx]); + + // TODO(fboemer): add uint128 + uint128_t low = t_poly_lazy_ptr[t_poly_idx]; + uint128_t hi = t_poly_lazy_ptr[t_poly_idx + 1]; + uint128_t x = (hi << 64) + low; + uint128_t sum = prod + x; + uint64_t sum_hi = static_cast(sum >> 64); + uint64_t sum_lo = static_cast(sum); + t_poly_lazy_ptr[t_poly_idx] = sum_lo; + t_poly_lazy_ptr[t_poly_idx + 1] = sum_hi; + } + } + } + + // PolyIter pointing to the destination t_poly_prod, shifted to the + // appropriate modulus + uint64_t* t_poly_prod_iter_ptr = &t_poly_prod[i * coeff_count]; + + // Final modular reduction + for (size_t k = 0; k < key_component_count; ++k) { + for (size_t l = 0; l < coeff_count; ++l) { + uint64_t accumulator_idx = 2 * coeff_count * k + 2 * l; + uint64_t poly_iter_idx = coeff_count * rns_modulus_size * k + l; + + t_poly_prod_iter_ptr[poly_iter_idx] = BarrettReduce128( + accumulator_ptr[accumulator_idx + 1], + accumulator_ptr[accumulator_idx], moduli[key_index]); + } + } + } + + uint64_t* data_array = result; + for (size_t key_component = 0; key_component < key_component_count; + ++key_component) { + uint64_t* t_poly_prod_it = + &t_poly_prod[key_component * coeff_count * rns_modulus_size]; + uint64_t* t_last = &t_poly_prod_it[decomp_modulus_size * coeff_count]; + + GetNTT(n, moduli[key_modulus_size - 1]) + .ComputeInverse(t_last, t_last, 2, 2); + + uint64_t qk = 
moduli[key_modulus_size - 1]; + uint64_t qk_half = qk >> 1; + + for (size_t i = 0; i < coeff_count; ++i) { + uint64_t barrett_factor = + MultiplyFactor(1, 64, moduli[key_modulus_size - 1]).BarrettFactor(); + t_last[i] = BarrettReduce64(t_last[i] + qk_half, + moduli[key_modulus_size - 1], barrett_factor); + } + + for (size_t i = 0; i < decomp_modulus_size; ++i) { + // (ct mod 4qk) mod qi + uint64_t qi = moduli[i]; + + // TODO(fboemer): Use input_mod_factor != 0 when qk / qi < 4 + // TODO(fboemer): Use output_mod_factor == 4? + uint64_t input_mod_factor = (qk > qi) ? moduli[i] : 2; + if (qk > qi) { + intel::hexl::EltwiseReduceMod(t_ntt_ptr, t_last, coeff_count, moduli[i], + input_mod_factor, 1); + } else { + for (size_t coeff_idx = 0; coeff_idx < coeff_count; ++coeff_idx) { + t_ntt_ptr[coeff_idx] = t_last[coeff_idx]; + } + } + + // Lazy subtraction, results in [0, 2*qi), since fix is in [0, qi]. + uint64_t barrett_factor = + MultiplyFactor(1, 64, moduli[i]).BarrettFactor(); + uint64_t fix = qi - BarrettReduce64(qk_half, moduli[i], barrett_factor); + for (size_t l = 0; l < coeff_count; ++l) { + t_ntt_ptr[l] += fix; + } + + uint64_t qi_lazy = qi << 1; // some multiples of qi + GetNTT(n, moduli[i]).ComputeForward(t_ntt_ptr, t_ntt_ptr, 4, 4); + // Since SEAL uses at most 60bit moduli, 8*qi < 2^63. + qi_lazy = qi << 2; + + // ((ct mod qi) - (ct mod qk)) mod qi + uint64_t* t_ith_poly = &t_poly_prod_it[i * coeff_count]; + for (size_t k = 0; k < coeff_count; ++k) { + t_ith_poly[k] = t_ith_poly[k] + qi_lazy - t_ntt[k]; + } + + // qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi + intel::hexl::EltwiseFMAMod(t_ith_poly, t_ith_poly, modswitch_factors[i], + nullptr, coeff_count, moduli[i], 8); + uint64_t data_ptr_offset = + coeff_count * (decomp_modulus_size * key_component + i); + + uint64_t* data_ptr = &data_array[data_ptr_offset]; + intel::hexl::EltwiseAddMod(data_ptr, data_ptr, t_ith_poly, coeff_count, + moduli[i]); + } + } + return; +} + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/experimental/seal/key-switch.cpp b/hexl_omp/experimental/seal/key-switch.cpp new file mode 100644 index 00000000..f006a47a --- /dev/null +++ b/hexl_omp/experimental/seal/key-switch.cpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#ifndef HEXL_FPGA_COMPATIBLE_KEYSWITCH + +#include "hexl/experimental/seal/key-switch.hpp" + +#include "hexl/experimental/seal/key-switch-internal.hpp" + +namespace intel { +namespace hexl { + +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr) { + intel::hexl::internal::KeySwitch( + result, t_target_iter_ptr, n, decomp_modulus_size, key_modulus_size, + rns_modulus_size, key_component_count, moduli, k_switch_keys, + modswitch_factors, root_of_unity_powers_ptr); +} + +} // namespace hexl +} // namespace intel +#endif diff --git a/hexl_omp/include/hexl/eltwise/eltwise-add-mod.hpp b/hexl_omp/include/hexl/eltwise/eltwise-add-mod.hpp new file mode 100644 index 00000000..cb2df110 --- /dev/null +++ b/hexl_omp/include/hexl/eltwise/eltwise-add-mod.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief 
Adds two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] operand2 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus); + +/// @brief Adds a vector and scalar elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] operand2 Scalar to add. Must be less +/// than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/eltwise/eltwise-cmp-add.hpp b/hexl_omp/include/hexl/eltwise/eltwise-cmp-add.hpp new file mode 100644 index 00000000..27e514ff --- /dev/null +++ b/hexl_omp/include/hexl/eltwise/eltwise-cmp-add.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional addition. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare; stores result +/// @param[in] n Number of elements in \p operand1 +/// @param[in] cmp Comparison operation +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to conditionally add +/// @details Computes result[i] = cmp(operand1[i], bound) ? operand1[i] + +/// diff : operand1[i] for all \f$i=0, ..., n-1\f$. +void EltwiseCmpAdd(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp b/hexl_omp/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp new file mode 100644 index 00000000..07ba3d23 --- /dev/null +++ b/hexl_omp/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional modular subtraction. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] modulus Modulus to reduce by +/// @param[in] cmp Comparison function +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to subtract by +/// @details Computes \p operand1[i] = (\p cmp(\p operand1, \p bound)) ? 
(\p +/// operand1 - \p diff) mod \p modulus : \p operand1 mod \p modulus for all i=0, +/// ..., n-1 +void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, uint64_t bound, + uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/eltwise/eltwise-fma-mod.hpp b/hexl_omp/include/hexl/eltwise/eltwise-fma-mod.hpp new file mode 100644 index 00000000..03651a42 --- /dev/null +++ b/hexl_omp/include/hexl/eltwise/eltwise-fma-mod.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes fused multiply-add (\p arg1 * \p arg2 + \p arg3) mod \p +/// modulus element-wise, broadcasting scalars to vectors. +/// @param[out] result Stores the result +/// @param[in] arg1 Vector to multiply +/// @param[in] arg2 Scalar to multiply +/// @param[in] arg3 Vector to add. Will not add if \p arg3 == nullptr +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$ [2, 2^{61} - 1]\f$ +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * modulus). Must be 1, 2, 4, or 8. +void EltwiseFMAMod(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/eltwise/eltwise-mult-mod.hpp b/hexl_omp/include/hexl/eltwise/eltwise-mult-mod.hpp new file mode 100644 index 00000000..e4d2dbd7 --- /dev/null +++ b/hexl_omp/include/hexl/eltwise/eltwise-mult-mod.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. 
+/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +void EltwiseMultMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/eltwise/eltwise-reduce-mod.hpp b/hexl_omp/include/hexl/eltwise/eltwise-reduce-mod.hpp new file mode 100644 index 00000000..c23abde2 --- /dev/null +++ b/hexl_omp/include/hexl/eltwise/eltwise-reduce-mod.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Performs elementwise modular reduction +/// @param[out] result Stores the result +/// @param[in] operand Data on which to compute the elementwise modular +/// reduction +/// @param[in] n Number of elements in operand +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be modulus, 1, 2 or 4. input_mod_factor=modulus +/// means, input range is [0, p * p]. Barrett reduction will be used in this +/// case. input_mod_factor > output_mod_factor +/// @param[in] output_mod_factor output elements will be in [0, +/// output_mod_factor * modulus) Must be 1 or 2. For input_mod_factor=0, +/// output_mod_factor will be set to 1. +void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, + uint64_t modulus, uint64_t input_mod_factor, + uint64_t output_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/eltwise/eltwise-sub-mod.hpp b/hexl_omp/include/hexl/eltwise/eltwise-sub-mod.hpp new file mode 100644 index 00000000..bd286e47 --- /dev/null +++ b/hexl_omp/include/hexl/eltwise/eltwise-sub-mod.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Subtracts two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Vector of elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus); + +/// @brief Subtracts a scalar from a vector elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. 
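+/// Example (minimal usage sketch with illustrative values):
+/// \code
+///   std::vector<uint64_t> op1{1, 2, 3, 4};
+///   std::vector<uint64_t> result(op1.size());
+///   intel::hexl::EltwiseSubMod(result.data(), op1.data(), /*operand2=*/2,
+///                              op1.size(), /*modulus=*/7);
+///   // result == {6, 0, 1, 2}
+/// \endcode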
+void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp b/hexl_omp/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp new file mode 100644 index 00000000..28a2dddf --- /dev/null +++ b/hexl_omp/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp @@ -0,0 +1,402 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +// ************************************ T1 ************************************ + +// ComplexLoadFwdInterleavedT1: +// Assumes ComplexLoadFwdInterleavedT2 was used before. +// Given input: 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0 +// Returns +// *out1 = (14, 12, 10, 8, 6, 4, 2, 0); +// *out2 = (15, 13, 11, 9, 7, 5, 3, 1); +// +// Given output: 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0 +inline void ComplexLoadFwdInterleavedT1(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512i vperm_idx = _mm512_set_epi64(6, 7, 4, 5, 2, 3, 0, 1); + + const __m512d* arg_512 = reinterpret_cast(arg); + + // 13 12 9 8 5 4 1 0 + __m512d v_7to0 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15 14 11 10 7 6 3 2 + __m512d v_15to8 = _mm512_loadu_pd(arg_512); + + // 12, 13, 8, 9, 4, 5, 0, 1 + __m512d perm_1 = _mm512_permutexvar_pd(vperm_idx, v_7to0); + // 14, 15, 10, 11, 6, 7, 2, 3 + __m512d perm_2 = _mm512_permutexvar_pd(vperm_idx, v_15to8); + + // 14, 12, 10, 8, 6, 4, 2, 0 + *out1 = _mm512_mask_blend_pd(0xaa, v_7to0, perm_2); + // 15, 13, 11, 9, 7, 5, 3, 1 + *out2 = _mm512_mask_blend_pd(0x55, v_15to8, perm_1); +} + +// ComplexWriteFwdInterleavedT1: +// Assumes ComplexLoadFwdInterleavedT1 was used before. 
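+// Re-interleaves the separated real and imaginary vectors back into the
+// 1-complex-interleaved (real part followed by imaginary part) output layout.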
+// Given inputs: +// 15i, 13i, 11i, 9i, 7i, 5i, 3i, 1i, 15r, 13r, 11r, 9r, 7r, 5r, 3r, 1r, +// 14i, 12i, 10i, 8i, 6i, 4i, 2i, 0i, 14r, 12r, 10r, 8r, 6r, 4r, 2r, 0r +// As seen with internal indexes: +// @param arg_yr = (15r, 14r, 13r, 12r, 11r, 10r, 9r, 8r); +// @param arg_xr = ( 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r); +// @param arg_yi = (15i, 14i, 13i, 12i, 11i, 10i, 9i, 8i); +// @param arg_xi = ( 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i); +// Writes out = +// {15i, 15r, 7i, 7r, 14i, 14r, 6i, 6r, 13i, 13r, 5i, 5r, 12i, 12r, 4i, 4r, +// 11i, 11r, 3i, 3r, 10i, 10r, 2i, 2r, 9i, 9r, 1i, 1r, 8i, 8r, 0i, 0r} +// +// Given output: +// 15i, 15r, 14i, 14r, 13i, 13r, 12i, 12r, 11i, 11r, 10i, 10r, 9i, 9r, 8i, 8r, +// 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +inline void ComplexWriteFwdInterleavedT1(__m512d arg_xr, __m512d arg_yr, + __m512d arg_xi, __m512d arg_yi, + __m512d* out) { + const __m512i vperm_4hi_4lo_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i v_X_out_idx = _mm512_set_epi64(3, 1, 7, 5, 2, 0, 6, 4); + const __m512i v_Y_out_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + // Real part + // in: 14r, 12r, 10r, 8r, 6r, 4r, 2r, 0r + // -> 6r, 4r, 2r, 0r, 14r, 12r, 10r, 8r + arg_xr = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, arg_xr); + + // arg_yr: 15r, 13r, 11r, 9r, 7r, 5r, 3r, 1r + // -> 6r, 4r, 2r, 0r, 7r, 5r, 3r, 1r + __m512d perm_1 = _mm512_mask_blend_pd(0x0f, arg_xr, arg_yr); + // -> 15r, 13r, 11r, 9r, 14r, 12r, 10r, 8r + __m512d perm_2 = _mm512_mask_blend_pd(0xf0, arg_xr, arg_yr); + + // 7r, 3r, 6r, 2r, 5r, 1r, 4r, 0r + arg_xr = _mm512_permutexvar_pd(v_X_out_idx, perm_1); + // 15r, 11r, 14r, 10r, 13r, 9r, 12r, 8r + arg_yr = _mm512_permutexvar_pd(v_Y_out_idx, perm_2); + + // Imaginary part + // in: 14i, 12i, 10i, 8i, 6i, 4i, 2i, 0i + // -> 6i, 4i, 2i, 0i, 14i, 12i, 10i, 8i + arg_xi = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, arg_xi); + + // arg_yr: 15i, 13i, 11i, 9i, 7i, 5i, 3i, 1i + // -> 6i, 4i, 2i, 0i, 7i, 5i, 3i, 1i + perm_1 = _mm512_mask_blend_pd(0x0f, arg_xi, arg_yi); + // -> 15i, 13i, 11i, 9i, 14i, 12i, 10i, 8i + perm_2 = _mm512_mask_blend_pd(0xf0, arg_xi, arg_yi); + + // 7i, 3i, 6i, 2i, 5i, 1i, 4i, 0i + arg_xi = _mm512_permutexvar_pd(v_X_out_idx, perm_1); + // 15i, 11i, 14i, 10i, 13i, 9i, 12i, 8i + arg_yi = _mm512_permutexvar_pd(v_Y_out_idx, perm_2); + + // Merge + // 00000000 > 3i 3r 2i 2r 1i 1r 0i 0r + __m512d out1 = _mm512_shuffle_pd(arg_xr, arg_xi, 0x00); + // 11111111 > 7i 7r 6i 6r 5i 5r 4i 4r + __m512d out2 = _mm512_shuffle_pd(arg_xr, arg_xi, 0xff); + + // 00000000 > 11i 11r 10i 10r 9i 9r 8i 8r + __m512d out3 = _mm512_shuffle_pd(arg_yr, arg_yi, 0x00); + // 11111111 > 15i 15r 14i 14r 13i 13r 12i 12r + __m512d out4 = _mm512_shuffle_pd(arg_yr, arg_yi, 0xff); + + _mm512_storeu_pd(out++, out1); + _mm512_storeu_pd(out++, out2); + _mm512_storeu_pd(out++, out3); + _mm512_storeu_pd(out++, out4); +} + +// ComplexLoadInvInterleavedT1: +// Given input: 15i 15r 14i 14r 13i 13r 12i 12r 11i 11r 10i 10r 9i 9r 8i 8r +// 7i 7r 6i 6r 5i 5r 4i 4r 3i 3r 2i 2r 1i 1r 0i 0r +// Returns +// *out1_r = (14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r); +// *out1_i = (14i, 10i, 6i, 2i, 12i, 8i, 4i, 0i); +// *out2_r = (15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r); +// *out2_i = (15i, 11i, 7i, 3i, 13i, 9i, 5i, 1i); +// +// Given output: +// 15i, 11i, 7i, 3i, 13i, 9i, 5i, 1i, 15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r, +// 14i, 10i, 6i, 2i, 12i, 8i, 4i, 0i, 14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r +inline void ComplexLoadInvInterleavedT1(const double* arg, __m512d* out1_r, + __m512d* out1_i, __m512d* out2_r, + __m512d* 
out2_i) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 3i 3r 2i 2r 1i 1r 0i 0r + __m512d v_3to0 = _mm512_loadu_pd(arg_512++); + // 7i 7r 6i 6r 5i 5r 4i 4r + __m512d v_7to4 = _mm512_loadu_pd(arg_512++); + // 11i 11r 10i 10r 9i 9r 8i 8r + __m512d v_11to8 = _mm512_loadu_pd(arg_512++); + // 15i 15r 14i 14r 13i 13r 12i 12r + __m512d v_15to12 = _mm512_loadu_pd(arg_512++); + + // 00000000 > 7r 3r 6r 2r 5r 1r 4r 0r + __m512d v_7to0_r = _mm512_shuffle_pd(v_3to0, v_7to4, 0x00); + // 11111111 > 7i 3i 6i 2i 5i 1i 4i 0i + __m512d v_7to0_i = _mm512_shuffle_pd(v_3to0, v_7to4, 0xff); + // 00000000 > 15r 11r 14r 10r 13r 9r 12r 8r + __m512d v_15to8_r = _mm512_shuffle_pd(v_11to8, v_15to12, 0x00); + // 11111111 > 15i 11i 14i 10i 13i 9i 12i 8i + __m512d v_15to8_i = _mm512_shuffle_pd(v_11to8, v_15to12, 0xff); + + // real + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + // 6 2 7 3 4 0 5 1 + __m512d v1r = _mm512_permutexvar_pd(v1_perm_idx, v_7to0_r); + // 14 10 15 11 12 8 13 9 + __m512d v2r = _mm512_permutexvar_pd(v1_perm_idx, v_15to8_r); + // 11001100 > 14 10 6 2 12 8 4 0 + *out1_r = _mm512_mask_blend_pd(0xcc, v_7to0_r, v2r); + // 11001100 > 15 11 7 3 13 9 5 1 + *out2_r = _mm512_mask_blend_pd(0xcc, v1r, v_15to8_r); + + // imag + // 6 2 7 3 4 0 5 1 + __m512d v1i = _mm512_permutexvar_pd(v1_perm_idx, v_7to0_i); + // 14 10 15 11 12 8 13 9 + __m512d v2i = _mm512_permutexvar_pd(v1_perm_idx, v_15to8_i); + // 11001100 > 14 10 6 2 12 8 4 0 + *out1_i = _mm512_mask_blend_pd(0xcc, v_7to0_i, v2i); + // 11001100 > 15 11 7 3 13 9 5 1 + *out2_i = _mm512_mask_blend_pd(0xcc, v1i, v_15to8_i); +} + +// ************************************ T2 ************************************ + +// ComplexLoadFwdInterleavedT2: +// Assumes ComplexLoadFwdInterleavedT4 was used before. +// Given input: 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0 +// Returns +// *out1 = (13, 12, 9, 8, 5, 4, 1, 0) +// *out2 = (15, 14, 11, 10, 7, 6, 3, 2) +// +// Given output: 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0 +inline void ComplexLoadFwdInterleavedT2(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // Values were swapped in T4 + // 11, 10, 9, 8, 3, 2, 1, 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15, 14, 13, 12, 7, 6, 5, 4 + __m512d v2 = _mm512_loadu_pd(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + + __m512d v1_perm = _mm512_permutexvar_pd(v1_perm_idx, v1); + __m512d v2_perm = _mm512_permutexvar_pd(v1_perm_idx, v2); + + *out1 = _mm512_mask_blend_pd(0xcc, v1, v2_perm); + *out2 = _mm512_mask_blend_pd(0xcc, v1_perm, v2); +} + +// ComplexLoadInvInterleavedT2: +// Assumes ComplexLoadInvInterleavedT1 was used before. 
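+// Pairs elements whose original indices differ by 2, as consumed by the T2
+// stage of the inverse butterflies.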
+// Given input: 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0 +// Returns +// *out1 = (13, 9, 5, 1, 12, 8, 4, 0) +// *out2 = (15, 11, 7, 3, 14, 10, 6, 2) +// +// Given output: 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0 +inline void ComplexLoadInvInterleavedT2(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 14 10 6 2 12 8 4 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15 11 7 3 13 9 5 1 + __m512d v2 = _mm512_loadu_pd(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + + // 12 8 4 0 14 10 6 2 + __m512d v1_perm = _mm512_permutexvar_pd(v1_perm_idx, v1); + // 13 9 5 1 15 11 7 3 + __m512d v2_perm = _mm512_permutexvar_pd(v1_perm_idx, v2); + + // 11110000 > 13 9 5 1 12 8 4 0 + *out1 = _mm512_mask_blend_pd(0xf0, v1, v2_perm); + // 11110000 > 15 11 7 3 14 10 6 2 + *out2 = _mm512_mask_blend_pd(0xf0, v1_perm, v2); +} + +// ************************************ T4 ************************************ + +// Complex LoadFwdInterleavedT4: +// Given input: 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +// Returns +// *out1 = (11, 10, 9, 8, 3, 2, 1, 0) +// *out2 = (15, 14, 13, 12, 7, 6, 5, 4) +// +// Given output: 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0 +inline void ComplexLoadFwdInterleavedT4(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + __m512d v_7to0 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + __m512d v_15to8 = _mm512_loadu_pd(arg_512); + __m512d perm_hi = _mm512_permutexvar_pd(vperm2_idx, v_15to8); + *out1 = _mm512_mask_blend_pd(0x0f, perm_hi, v_7to0); + *out2 = _mm512_mask_blend_pd(0xf0, perm_hi, v_7to0); + *out2 = _mm512_permutexvar_pd(vperm2_idx, *out2); +} + +// ComplexLoadInvInterleavedT4: +// Assumes ComplexLoadInvInterleavedT2 was used before. +// Given input: 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0 +// Returns +// *out1 = (11, 9, 3, 1, 10, 8, 2, 0) +// *out2 = (15, 13, 7, 5, 14, 12, 6, 4) +// +// Given output: 15, 13, 7, 5, 14, 12, 6, 4, 11, 9, 3, 1, 10, 8, 2, 0 + +inline void ComplexLoadInvInterleavedT4(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 13, 9, 5, 1, 12, 8, 4, 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15, 11, 7, 3, 14, 10, 6, 2 + __m512d v2 = _mm512_loadu_pd(arg_512); + + // 00000000 > 11 9 3 1 10 8 2 0 + *out1 = _mm512_shuffle_pd(v1, v2, 0x00); + // 11111111 > 15 13 7 5 14 12 6 4 + *out2 = _mm512_shuffle_pd(v1, v2, 0xff); +} + +// ComplexWriteInvInterleavedT4: +// Assuming ComplexLoadInvInterleavedT4 was used before. 
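+// Restores the natural 0..15 index order when storing the results back to
+// memory.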
+// Given inputs: 15, 13, 7, 5, 14, 12, 6, 4, 11, 9, 3, 1, 10, 8, 2, 0 +// Seen Internally: +// @param arg1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// @param arg2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 11, 14, 10, 7, 3, 6, 2, +// 13, 9, 12, 8, 5, 1, 4, 0} +// +// Given output: 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +inline void ComplexWriteInvInterleavedT4(__m512d arg1, __m512d arg2, + __m512d* out) { + const __m512i vperm_4hi_4lo_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i vperm1 = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + const __m512i vperm2 = _mm512_set_epi64(5, 1, 4, 0, 7, 3, 6, 2); + + // in: 11 9 3 1 10 8 2 0 + // -> 11 10 9 8 3 2 1 0 + arg1 = _mm512_permutexvar_pd(vperm1, arg1); + // in: 15 13 7 5 14 12 6 4 + // -> 7 6 5 4 15 14 13 12 + arg2 = _mm512_permutexvar_pd(vperm2, arg2); + + // 7 6 5 4 3 2 1 0 + __m512d out1 = _mm512_mask_blend_pd(0xf0, arg1, arg2); + // 11 10 9 8 15 14 13 12 + __m512d out2 = _mm512_mask_blend_pd(0x0f, arg1, arg2); + // 15 14 13 12 11 10 9 8 + out2 = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, out2); + + _mm512_storeu_pd(out, out1); + out += 2; + _mm512_storeu_pd(out, out2); +} + +// ************************************ T8 ************************************ + +// ComplexLoadFwdInterleavedT8: +// Given inputs: 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +// Seen Internally: +// v_X1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// v_X2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 13, 11, 9, 7, 5, 3, 1, +// 14, 12, 10, 8, 6, 4, 2, 0} +// +// Given output: 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i, 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r +inline void ComplexLoadFwdInterleavedT8(const __m512d* arg_x, + const __m512d* arg_y, __m512d* out1_r, + __m512d* out1_i, __m512d* out2_r, + __m512d* out2_i) { + const __m512i v_perm_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + // 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r + __m512d v_X1 = _mm512_loadu_pd(arg_x++); + // 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r + __m512d v_X2 = _mm512_loadu_pd(arg_x); + // 7r, 3r, 6r, 2r, 5r, 1r, 4r, 0r + *out1_r = _mm512_shuffle_pd(v_X1, v_X2, 0x00); + // 7i, 3i, 6i, 2i, 5i, 1i, 4i, 0i + *out1_i = _mm512_shuffle_pd(v_X1, v_X2, 0xff); + // 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r + *out1_r = _mm512_permutexvar_pd(v_perm_idx, *out1_r); + // 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i + *out1_i = _mm512_permutexvar_pd(v_perm_idx, *out1_i); + + __m512d v_Y1 = _mm512_loadu_pd(arg_y++); + __m512d v_Y2 = _mm512_loadu_pd(arg_y); + *out2_r = _mm512_shuffle_pd(v_Y1, v_Y2, 0x00); + *out2_i = _mm512_shuffle_pd(v_Y1, v_Y2, 0xff); + *out2_r = _mm512_permutexvar_pd(v_perm_idx, *out2_r); + *out2_i = _mm512_permutexvar_pd(v_perm_idx, *out2_i); +} + +// ComplexWriteInvInterleavedT8: +// Assuming ComplexLoadInvInterleavedT4 was used before. 
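+// Converts the separated 8-real/8-imaginary block layout back into the
+// 1-complex-interleaved output layout.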
+// Given inputs: 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i, 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r
+// Seen Internally:
+// @param arg1 = ( 7,  6,  5,  4,  3,  2, 1, 0);
+// @param arg2 = (15, 14, 13, 12, 11, 10, 9, 8);
+// Writes out = {15,  7, 14,  6, 13,  5, 12, 4,
+//               11,  3, 10,  2,  9,  1,  8, 0}
+//
+// Given output: 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r
+inline void ComplexWriteInvInterleavedT8(__m512d* v_X_real, __m512d* v_X_imag,
+                                         __m512d* v_Y_real, __m512d* v_Y_imag,
+                                         __m512d* v_X_pt, __m512d* v_Y_pt) {
+  const __m512i vperm = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0);
+  // in: 7r 6r 5r 4r 3r 2r 1r 0r
+  // ->  7r 3r 6r 2r 5r 1r 4r 0r
+  *v_X_real = _mm512_permutexvar_pd(vperm, *v_X_real);
+  // in: 7i 6i 5i 4i 3i 2i 1i 0i
+  // ->  7i 3i 6i 2i 5i 1i 4i 0i
+  *v_X_imag = _mm512_permutexvar_pd(vperm, *v_X_imag);
+  // in: 15r 14r 13r 12r 11r 10r 9r 8r
+  // ->  15r 11r 14r 10r 13r 9r 12r 8r
+  *v_Y_real = _mm512_permutexvar_pd(vperm, *v_Y_real);
+  // in: 15i 14i 13i 12i 11i 10i 9i 8i
+  // ->  15i 11i 14i 10i 13i 9i 12i 8i
+  *v_Y_imag = _mm512_permutexvar_pd(vperm, *v_Y_imag);
+
+  // 00000000 > 3i 3r 2i 2r 1i 1r 0i 0r
+  __m512d v_X1 = _mm512_shuffle_pd(*v_X_real, *v_X_imag, 0x00);
+  // 11111111 > 7i 7r 6i 6r 5i 5r 4i 4r
+  __m512d v_X2 = _mm512_shuffle_pd(*v_X_real, *v_X_imag, 0xff);
+  // 00000000 > 11i 11r 10i 10r 9i 9r 8i 8r
+  __m512d v_Y1 = _mm512_shuffle_pd(*v_Y_real, *v_Y_imag, 0x00);
+  // 11111111 > 15i 15r 14i 14r 13i 13r 12i 12r
+  __m512d v_Y2 = _mm512_shuffle_pd(*v_Y_real, *v_Y_imag, 0xff);
+
+  _mm512_storeu_pd(v_X_pt++, v_X1);
+  _mm512_storeu_pd(v_X_pt, v_X2);
+  _mm512_storeu_pd(v_Y_pt++, v_Y1);
+  _mm512_storeu_pd(v_Y_pt, v_Y2);
+}
+#endif  // HEXL_HAS_AVX512DQ
+
+}  // namespace hexl
+}  // namespace intel
diff --git a/hexl_omp/include/hexl/experimental/fft-like/fft-like-native.hpp b/hexl_omp/include/hexl/experimental/fft-like/fft-like-native.hpp
new file mode 100644
index 00000000..7e02492d
--- /dev/null
+++ b/hexl_omp/include/hexl/experimental/fft-like/fft-like-native.hpp
@@ -0,0 +1,38 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <complex>
+
+namespace intel {
+namespace hexl {
+
+/// @brief Radix-2 native C++ FFT like implementation of the forward FFT like
+/// @param[out] result Output data. Overwritten with FFT like output
+/// @param[in] operand Input data.
+/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a
+/// power of two.
+/// @param[in] root_of_unity_powers Powers of 2n'th root of unity. In
+/// bit-reversed order
+/// @param[in] scale Scale applied to output data
+void Forward_FFTLike_ToBitReverseRadix2(
+    std::complex<double>* result, const std::complex<double>* operand,
+    const std::complex<double>* root_of_unity_powers, const uint64_t n,
+    const double* scale = nullptr);
+
+/// @brief Radix-2 native C++ FFT like implementation of the inverse FFT like
+/// @param[out] result Output data. Overwritten with FFT like output
+/// @param[in] operand Input data.
+/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a
+/// power of two.
+/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity.
+/// In bit-reversed order.
+/// @param[in] scale Scale applied to output data
+void Inverse_FFTLike_FromBitReverseRadix2(
+    std::complex<double>* result, const std::complex<double>* operand,
+    const std::complex<double>* inv_root_of_unity_powers, const uint64_t n,
+    const double* scale = nullptr);
+
+}  // namespace hexl
+}  // namespace intel
diff --git a/hexl_omp/include/hexl/experimental/fft-like/fft-like.hpp b/hexl_omp/include/hexl/experimental/fft-like/fft-like.hpp
new file mode 100644
index 00000000..334de246
--- /dev/null
+++ b/hexl_omp/include/hexl/experimental/fft-like/fft-like.hpp
@@ -0,0 +1,147 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <complex>
+
+#include "hexl/experimental/fft-like/fwd-fft-like-avx512.hpp"
+#include "hexl/experimental/fft-like/inv-fft-like-avx512.hpp"
+#include "hexl/util/aligned-allocator.hpp"
+#include "hexl/util/allocator.hpp"
+
+namespace intel {
+namespace hexl {
+
+/// @brief Performs linear forward and inverse FFT like transform
+/// for CKKS encoding and decoding.
+class FFTLike {
+ public:
+  /// @brief Helper class for custom memory allocation
+  template <class Adaptee, class... Args>
+  struct AllocatorAdapter
+      : public AllocatorInterface<AllocatorAdapter<Adaptee, Args...>> {
+    explicit AllocatorAdapter(Adaptee&& _a, Args&&... args);
+    AllocatorAdapter(const Adaptee& _a, Args&... args);
+
+    // interface implementation
+    void* allocate_impl(size_t bytes_count);
+    void deallocate_impl(void* p, size_t n);
+
+   private:
+    Adaptee alloc;
+  };
+
+  /// @brief Initializes an empty FFTLike object
+  FFTLike() = default;
+
+  /// @brief Destructs the FFTLike object
+  ~FFTLike() = default;
+
+  /// @brief Initializes an FFTLike object with degree \p degree and scalar
+  /// \p in_scalar.
+  /// @param[in] degree also known as N. Size of the FFT like transform. Must be
+  /// a power of 2
+  /// @param[in] in_scalar Scalar value to calculate scale and inv scale
+  /// @param[in] alloc_ptr Custom memory allocator used for intermediate
+  /// calculations
+  /// @details Performs pre-computation necessary for forward and inverse
+  /// transforms
+  FFTLike(uint64_t degree, double* in_scalar,
+          std::shared_ptr<AllocatorBase> alloc_ptr = {});
+
+  template <class Allocator, class... AllocatorArgs>
+  FFTLike(uint64_t degree, double* in_scalar, Allocator&& a,
+          AllocatorArgs&&... args)
+      : FFTLike(
+            degree, in_scalar,
+            std::static_pointer_cast<AllocatorBase>(
+                std::make_shared<AllocatorAdapter<Allocator, AllocatorArgs...>>(
+                    std::move(a), std::forward<AllocatorArgs>(args)...))) {}
+
+  /// @brief Compute forward FFT like. Results are bit-reversed.
+  /// @param[out] result Stores the result
+  /// @param[in] operand Data on which to compute the FFT like
+  /// @param[in] in_scale Scale applied to output values
+  void ComputeForwardFFTLike(std::complex<double>* result,
+                             const std::complex<double>* operand,
+                             const double* in_scale = nullptr);
+
+  /// @brief Compute inverse FFT like. Results are bit-reversed.
+  /// @param[out] result Stores the result
+  /// @param[in] operand Data on which to compute the FFT like
+  /// @param[in] in_scale Scale applied to output values
+  void ComputeInverseFFTLike(std::complex<double>* result,
+                             const std::complex<double>* operand,
+                             const double* in_scale = nullptr);
+
+  /// @brief Construct floating-point values from CRT-composed polynomial with
+  /// integer coefficients.
+  /// @param[out] res Stores the result
+  /// @param[in] plain Plaintext
+  /// @param[in] threshold Upper half threshold with respect to the total
+  /// coefficient modulus
+  /// @param[in] decryption_modulus Product of all primes in the coefficient
+  /// modulus
+  /// @param[in] inv_scale Scale applied to output values
+  /// @param[in] mod_size Size of coefficient modulus parameter
+  /// @param[in] coeff_count Degree of the polynomial modulus parameter
+  void BuildFloatingPoints(std::complex<double>* res, const uint64_t* plain,
+                           const uint64_t* threshold,
+                           const uint64_t* decryption_modulus,
+                           const double inv_scale, size_t mod_size,
+                           size_t coeff_count);
+
+  /// @brief Returns the root of unity power at bit-reversed index i.
+  /// @param[in] i Index
+  std::complex<double> GetComplexRootOfUnity(size_t i) {
+    return GetComplexRootsOfUnity()[i];
+  }
+
+  /// @brief Returns the root of unity in bit-reversed order
+  const AlignedVector64<std::complex<double>>& GetComplexRootsOfUnity() const {
+    return m_complex_roots_of_unity;
+  }
+
+  /// @brief Returns the root of unity power at bit-reversed index i.
+  /// @param[in] i Index
+  std::complex<double> GetInvComplexRootOfUnity(size_t i) {
+    return GetInvComplexRootsOfUnity()[i];
+  }
+
+  /// @brief Returns the inverse root of unity in bit-reversed order
+  const AlignedVector64<std::complex<double>>& GetInvComplexRootsOfUnity()
+      const {
+    return m_inv_complex_roots_of_unity;
+  }
+
+  /// @brief Returns the degree N
+  uint64_t GetDegree() const { return m_degree; }
+
+ private:
+  // Computes 1~(n-1)-th powers and inv powers of the primitive 2n-th root
+  void ComputeComplexRootsOfUnity();
+
+  uint64_t m_degree;  // N: size of FFT like transform, should be power of 2
+
+  double* scalar;  // Pointer to scalar used for scale/inv_scale calculation
+
+  double scale;  // Scale value used for encoding (inv fft-like)
+
+  double inv_scale;  // Scale value used in decoding (fwd fft-like)
+
+  std::shared_ptr<AllocatorBase> m_alloc;
+
+  AlignedAllocator<double, 64> m_aligned_alloc;
+
+  uint64_t m_degree_bits;  // log_2(m_degree)
+
+  // Contains 0~(n-1)-th powers of the 2n-th primitive root.
+  AlignedVector64<std::complex<double>> m_complex_roots_of_unity;
+
+  // Contains 0~(n-1)-th inv powers of the 2n-th primitive inv root.
+  AlignedVector64<std::complex<double>> m_inv_complex_roots_of_unity;
+};
+
+}  // namespace hexl
+}  // namespace intel
diff --git a/hexl_omp/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp b/hexl_omp/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp
new file mode 100644
index 00000000..aba4ca4d
--- /dev/null
+++ b/hexl_omp/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp
@@ -0,0 +1,59 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "hexl/experimental/fft-like/fft-like.hpp"
+
+namespace intel {
+namespace hexl {
+
+#ifdef HEXL_HAS_AVX512DQ
+
+/// @brief AVX512 implementation of the forward FFT like
+/// @param[out] result_cmplx_intrlvd Output data. Overwritten with FFT like
+/// output. Result is a vector of double with interleaved real and imaginary
+/// numbers.
+/// @param[in] operand_cmplx_intrlvd Input data. A vector of double with
+/// interleaved real and imaginary numbers.
+/// @param[in] roots_of_unity_cmplx_intrlvd Powers of 2n'th root of unity. In
+/// bit-reversed order.
+/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a
+/// power of two.
+/// @param[in] scale Scale applied to output values +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// FFT like, where all the butterflies in a given stage are processed before +/// any butterflies in the next stage. The base case is small enough to fit in +/// the smallest cache. Larger FFTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. This reduces the number of cache misses, improving +/// performance on larger transform sizes. +void Forward_FFTLike_ToBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* roots_of_unity_cmplx_intrlvd, const uint64_t n, + const double* scale = nullptr, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +/// @brief Construct floating-point values from CRT-composed polynomial with +/// integer coefficients in AVX512. +/// @param[out] res_cmplx_intrlvd Stores the result +/// @param[in] plain Plaintext +/// @param[in] threshold Upper half threshold with respect to the total +/// coefficient modulus +/// @param[in] decryption_modulus Product of all primes in the coefficient +/// modulus +/// @param[in] inv_scale Scale applied to output values +/// @param[in] mod_size Size of coefficient modulus parameter +/// @param[in] coeff_count Degree of the polynomial modulus parameter +void BuildFloatingPointsAVX512(double* res_cmplx_intrlvd, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, const size_t mod_size, + const size_t coeff_count); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp b/hexl_omp/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp new file mode 100644 index 00000000..487e2828 --- /dev/null +++ b/hexl_omp/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/fft-like/fft-like.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the inverse FFT like +/// @param[out] result_cmplx_intrlvd Output data. Overwritten with FFT like +/// output. Result is a vector of double with interleaved real and imaginary +/// numbers. +/// @param[in] operand_cmplx_intrlvd Input data. A vector of double with +/// interleaved real and imaginary numbers. +/// @param[in] inv_roots_of_unity_cmplx_intrlvd Powers of 2n'th root of unity. +/// In bit-reversed order. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] scale Scale applied to output values +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// FFT like, where all the butterflies in a given stage are processed before +/// any butterflies in the next stage. The base case is small enough to fit in +/// the smallest cache. Larger FFTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. 
This reduces the number of cache misses, improving +/// performance on larger transform sizes. +void Inverse_FFTLike_FromBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* inv_root_of_unity_cmplxintrlvd, const uint64_t n, + const double* scale = nullptr, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/experimental/misc/lr-mat-vec-mult.hpp b/hexl_omp/include/hexl/experimental/misc/lr-mat-vec-mult.hpp new file mode 100644 index 00000000..df03df92 --- /dev/null +++ b/hexl_omp/include/hexl/experimental/misc/lr-mat-vec-mult.hpp @@ -0,0 +1,31 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes transposed linear regression +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (3 * n * num_moduli) elements +/// @param[in] operand1 Vector of ciphertext representing a matrix that encodes +/// a transposed logistic regression model. Has (num_weights * 2 * n * +/// num_moduli) elements. +/// @param[in] operand2 Vector of ciphertext representing a matrix that encodes +/// at most n/2 input samples with feature size num_weights. Has (num_weights * +/// 2 * n * num_moduli) elements. +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +/// @param[in] num_weights Feature size of the linear/logistic regression model +void LinRegMatrixVectorMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli, + uint64_t num_weights); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/experimental/seal/dyadic-multiply-internal.hpp b/hexl_omp/include/hexl/experimental/seal/dyadic-multiply-internal.hpp new file mode 100644 index 00000000..310a46b0 --- /dev/null +++ b/hexl_omp/include/hexl/experimental/seal/dyadic-multiply-internal.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { +namespace internal { + +/// @brief Computes dyadic multiplication +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (2 * n * num_moduli) elements +/// @param[in] operand1 First ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] operand2 Second ciphertext argument. Has (2 * n * num_moduli) +/// elements. 
+/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli); + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/experimental/seal/dyadic-multiply.hpp b/hexl_omp/include/hexl/experimental/seal/dyadic-multiply.hpp new file mode 100644 index 00000000..f7eacfdf --- /dev/null +++ b/hexl_omp/include/hexl/experimental/seal/dyadic-multiply.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes dyadic multiplication +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (2 * n * num_moduli) elements +/// @param[in] operand1 First ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] operand2 Second ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/experimental/seal/key-switch-internal.hpp b/hexl_omp/include/hexl/experimental/seal/key-switch-internal.hpp new file mode 100644 index 00000000..8fc9d53e --- /dev/null +++ b/hexl_omp/include/hexl/experimental/seal/key-switch-internal.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { +namespace internal { + +/// @brief Computes key switching in-place +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (n * decomp_modulus_size * key_component_count) elements +/// @param[in] t_target_iter_ptr Pointer to the last component of the input +/// ciphertext +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] decomp_modulus_size Number of moduli in the ciphertext at its +/// current level, excluding one auxiliary prime. +/// @param[in] key_modulus_size Number of moduli in the ciphertext at its top +/// level, including one auxiliary prime. +/// @param[in] rns_modulus_size Number of moduli in the ciphertext at its +/// current level, including one auxiliary prime. rns_modulus_size == +/// decomp_modulus_size + 1 +/// @param[in] key_component_count Number of components in the resulting +/// ciphertext, e.g. key_component_count == 2. +/// @param[in] moduli Array of word-sized coefficient moduli. There must be +/// key_modulus_size moduli in the array +/// @param[in] k_switch_keys Array of evaluation key data. 
Has +/// decomp_modulus_size entries, each with +/// coeff_count * ((key_modulus_size - 1)+ (key_component_count - 1) * +/// (key_modulus_size) + 1) entries +/// @param[in] modswitch_factors Array of modulus switch factors +/// @param[in] root_of_unity_powers_ptr Array of root of unity powers +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr = nullptr); + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/experimental/seal/key-switch.hpp b/hexl_omp/include/hexl/experimental/seal/key-switch.hpp new file mode 100644 index 00000000..9eda159c --- /dev/null +++ b/hexl_omp/include/hexl/experimental/seal/key-switch.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes key switching in-place +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (n * decomp_modulus_size * key_component_count) elements +/// @param[in] t_target_iter_ptr Pointer to the last component of the input +/// ciphertext +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] decomp_modulus_size Number of moduli in the ciphertext at its +/// current level, excluding one auxiliary prime. +/// @param[in] key_modulus_size Number of moduli in the ciphertext at its top +/// level, including one auxiliary prime. +/// @param[in] rns_modulus_size Number of moduli in the ciphertext at its +/// current level, including one auxiliary prime. rns_modulus_size == +/// decomp_modulus_size + 1 +/// @param[in] key_component_count Number of components in the resulting +/// ciphertext, e.g. key_component_count == 2. +/// @param[in] moduli Array of word-sized coefficient moduli. There must be +/// key_modulus_size moduli in the array +/// @param[in] k_switch_keys Array of evaluation key data. 
Has
+/// decomp_modulus_size entries, each with
+/// coeff_count * ((key_modulus_size - 1)+ (key_component_count - 1) *
+/// (key_modulus_size) + 1) entries
+/// @param[in] modswitch_factors Array of modulus switch factors
+/// @param[in] root_of_unity_powers_ptr Array of root of unity powers
+void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n,
+               uint64_t decomp_modulus_size, uint64_t key_modulus_size,
+               uint64_t rns_modulus_size, uint64_t key_component_count,
+               const uint64_t* moduli, const uint64_t** k_switch_keys,
+               const uint64_t* modswitch_factors,
+               const uint64_t* root_of_unity_powers_ptr = nullptr);
+
+}  // namespace hexl
+}  // namespace intel
diff --git a/hexl_omp/include/hexl/experimental/seal/locks.hpp b/hexl_omp/include/hexl/experimental/seal/locks.hpp
new file mode 100644
index 00000000..4595f4e5
--- /dev/null
+++ b/hexl_omp/include/hexl/experimental/seal/locks.hpp
@@ -0,0 +1,35 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <mutex>
+#include <shared_mutex>
+
+namespace intel {
+namespace hexl {
+
+using Lock = std::shared_mutex;
+using WriteLock = std::unique_lock<Lock>;
+using ReadLock = std::shared_lock<Lock>;
+
+class RWLock {
+ public:
+  RWLock() = default;
+  inline ReadLock AcquireRead() { return ReadLock(rw_mutex); }
+  inline WriteLock AcquireWrite() { return WriteLock(rw_mutex); }
+  inline ReadLock TryAcquireRead() noexcept {
+    return ReadLock(rw_mutex, std::try_to_lock);
+  }
+  inline WriteLock TryAcquireWrite() noexcept {
+    return WriteLock(rw_mutex, std::try_to_lock);
+  }
+
+ private:
+  RWLock(const RWLock& copy) = delete;
+  RWLock& operator=(const RWLock& assign) = delete;
+  Lock rw_mutex{};
+};
+
+}  // namespace hexl
+}  // namespace intel
diff --git a/hexl_omp/include/hexl/experimental/seal/ntt-cache.hpp b/hexl_omp/include/hexl/experimental/seal/ntt-cache.hpp
new file mode 100644
index 00000000..8f6c1046
--- /dev/null
+++ b/hexl_omp/include/hexl/experimental/seal/ntt-cache.hpp
@@ -0,0 +1,56 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "hexl/experimental/seal/locks.hpp"
+#include "ntt/ntt-internal.hpp"
+
+namespace intel {
+namespace hexl {
+
+struct HashPair {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& p) const {
+    auto hash1 = std::hash<T1>{}(p.first);
+    auto hash2 = std::hash<T2>{}(p.second);
+    return hash_combine(hash1, hash2);
+  }
+
+  // Golden Ratio Hashing with seeds
+  static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
+    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
+    return lhs;
+  }
+};
+
+NTT& GetNTT(size_t N, uint64_t modulus) {
+  static std::unordered_map<std::pair<uint64_t, uint64_t>, NTT, HashPair>
+      ntt_cache;
+  static RWLock ntt_cache_locker;
+
+  std::pair<uint64_t, uint64_t> key{N, modulus};
+
+  // Enable shared access to NTT already present
+  {
+    ReadLock reader_lock(ntt_cache_locker.AcquireRead());
+    auto ntt_it = ntt_cache.find(key);
+    if (ntt_it != ntt_cache.end()) {
+      return ntt_it->second;
+    }
+  }
+
+  // Deal with NTT not yet present
+  WriteLock write_lock(ntt_cache_locker.AcquireWrite());
+
+  // Check ntt_cache for value (may be added by another thread)
+  auto ntt_it = ntt_cache.find(key);
+  if (ntt_it == ntt_cache.end()) {
+    NTT ntt(N, modulus);
+    ntt_it = ntt_cache.emplace(std::move(key), std::move(ntt)).first;
+  }
+  return ntt_it->second;
+}
+
+}  // namespace hexl
+}  // namespace intel
diff --git a/hexl_omp/include/hexl/hexl.hpp b/hexl_omp/include/hexl/hexl.hpp
new file mode 100644
index 00000000..6f07ae57
--- /dev/null
+++
b/hexl_omp/include/hexl/hexl.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/eltwise/eltwise-cmp-add.hpp" +#include "hexl/eltwise/eltwise-cmp-sub-mod.hpp" +#include "hexl/eltwise/eltwise-fma-mod.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/eltwise/eltwise-sub-mod.hpp" +#include "hexl/experimental/fft-like/fft-like.hpp" +#include "hexl/experimental/misc/lr-mat-vec-mult.hpp" +#include "hexl/experimental/seal/dyadic-multiply-internal.hpp" +#include "hexl/experimental/seal/dyadic-multiply.hpp" +#include "hexl/experimental/seal/key-switch-internal.hpp" +#include "hexl/experimental/seal/key-switch.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" +#include "hexl/util/defines.hpp" +#include "hexl/util/types.hpp" +#include "hexl/util/util.hpp" diff --git a/hexl_omp/include/hexl/logging/logging.hpp b/hexl_omp/include/hexl/logging/logging.hpp new file mode 100644 index 00000000..af5bfcd8 --- /dev/null +++ b/hexl_omp/include/hexl/logging/logging.hpp @@ -0,0 +1,43 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "hexl/util/defines.hpp" + +// Wrap HEXL_VLOG with HEXL_DEBUG; this ensures no logging overhead in +// release mode +#ifdef HEXL_DEBUG + +// TODO(fboemer) Enable if needed +// #define ELPP_THREAD_SAFE +#define ELPP_CUSTOM_COUT std::cerr +#define ELPP_STL_LOGGING +#define ELPP_LOG_STD_ARRAY +#define ELPP_LOG_UNORDERED_MAP +#define ELPP_LOG_UNORDERED_SET +#define ELPP_NO_LOG_TO_FILE +#define ELPP_DISABLE_DEFAULT_CRASH_HANDLING +#define ELPP_WINSOCK2 + +#include + +#define HEXL_VLOG(N, rest) \ + do { \ + if (VLOG_IS_ON(N)) { \ + VLOG(N) << rest; \ + } \ + } while (0); + +#else + +#define HEXL_VLOG(N, rest) \ + {} + +#define START_EASYLOGGINGPP(X, Y) \ + {} + +#endif diff --git a/hexl_omp/include/hexl/ntt/ntt.hpp b/hexl_omp/include/hexl/ntt/ntt.hpp new file mode 100644 index 00000000..93ccba72 --- /dev/null +++ b/hexl_omp/include/hexl/ntt/ntt.hpp @@ -0,0 +1,296 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/allocator.hpp" + +namespace intel { +namespace hexl { + +/// @brief Performs negacyclic forward and inverse number-theoretic transform +/// (NTT), commonly used in RLWE cryptography. +/// @details The number-theoretic transform (NTT) specializes the discrete +/// Fourier transform (DFT) to the finite field \f$ \mathbb{Z}_q[X] / (X^N + 1) +/// \f$. +class NTT { + public: + /// @brief Helper class for custom memory allocation + template + struct AllocatorAdapter + : public AllocatorInterface> { + explicit AllocatorAdapter(Adaptee&& _a, Args&&... args); + AllocatorAdapter(const Adaptee& _a, Args&... args); + + // interface implementation + void* allocate_impl(size_t bytes_count); + void deallocate_impl(void* p, size_t n); + + private: + Adaptee alloc; + }; + + /// @brief Initializes an empty NTT object + NTT() = default; + + /// @brief Destructs the NTT object + ~NTT() = default; + + /// @brief Initializes an NTT object with degree \p degree and modulus \p q. + /// @param[in] degree also known as N. Size of the NTT transform. 
Must be a
+  /// power of
+  /// 2
+  /// @param[in] q Prime modulus. Must satisfy \f$ q == 1 \mod 2N \f$
+  /// @param[in] alloc_ptr Custom memory allocator used for intermediate
+  /// calculations
+  /// @brief Performs pre-computation necessary for forward and inverse
+  /// transforms
+  NTT(uint64_t degree, uint64_t q,
+      std::shared_ptr<AllocatorBase> alloc_ptr = {});
+
+  template <class Allocator, class... AllocatorArgs>
+  NTT(uint64_t degree, uint64_t q, Allocator&& a, AllocatorArgs&&... args)
+      : NTT(degree, q,
+            std::static_pointer_cast<AllocatorBase>(
+                std::make_shared<AllocatorAdapter<Allocator, AllocatorArgs...>>(
+                    std::move(a), std::forward<AllocatorArgs>(args)...))) {}
+
+  /// @brief Initializes an NTT object with degree \p degree and modulus
+  /// \p q.
+  /// @param[in] degree also known as N. Size of the NTT transform. Must be a
+  /// power of
+  /// 2
+  /// @param[in] q Prime modulus. Must satisfy \f$ q == 1 \mod 2N \f$
+  /// @param[in] root_of_unity 2N'th root of unity in \f$ \mathbb{Z_q} \f$.
+  /// @param[in] alloc_ptr Custom memory allocator used for intermediate
+  /// calculations
+  /// @details Performs pre-computation necessary for forward and inverse
+  /// transforms
+  NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity,
+      std::shared_ptr<AllocatorBase> alloc_ptr = {});
+
+  template <class Allocator, class... AllocatorArgs>
+  NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, Allocator&& a,
+      AllocatorArgs&&... args)
+      : NTT(degree, q, root_of_unity,
+            std::static_pointer_cast<AllocatorBase>(
+                std::make_shared<AllocatorAdapter<Allocator, AllocatorArgs...>>(
+                    std::move(a), std::forward<AllocatorArgs>(args)...))) {}
+
+  /// @brief Returns true if arguments satisfy constraints for negacyclic NTT
+  /// @param[in] degree N. Size of the transform, i.e. the polynomial degree.
+  /// Must be a power of two.
+  /// @param[in] modulus Prime modulus q. Must satisfy q mod 2N = 1
+  static bool CheckArguments(uint64_t degree, uint64_t modulus);
+
+  /// @brief Compute forward NTT. Results are bit-reversed.
+  /// @param[out] result Stores the result
+  /// @param[in] operand Data on which to compute the NTT
+  /// @param[in] input_mod_factor Assume input \p operand are in [0,
+  /// input_mod_factor * q). Must be 1, 2 or 4.
+  /// @param[in] output_mod_factor Returns output \p result in [0,
+  /// output_mod_factor * q). Must be 1 or 4.
+  void ComputeForward(uint64_t* result, const uint64_t* operand,
+                      uint64_t input_mod_factor, uint64_t output_mod_factor);
+
+  /// Compute inverse NTT. Results are bit-reversed.
+  /// @param[out] result Stores the result
+  /// @param[in] operand Data on which to compute the NTT
+  /// @param[in] input_mod_factor Assume input \p operand are in [0,
+  /// input_mod_factor * q). Must be 1 or 2.
+  /// @param[in] output_mod_factor Returns output \p result in [0,
+  /// output_mod_factor * q). Must be 1 or 2.
+  void ComputeInverse(uint64_t* result, const uint64_t* operand,
+                      uint64_t input_mod_factor, uint64_t output_mod_factor);
+
+  /// @brief Returns the minimal 2N'th root of unity
+  uint64_t GetMinimalRootOfUnity() const { return m_w; }
+
+  /// @brief Returns the degree N
+  uint64_t GetDegree() const { return m_degree; }
+
+  /// @brief Returns the word-sized prime modulus
+  uint64_t GetModulus() const { return m_q; }
+
+  /// @brief Returns the root of unity powers in bit-reversed order
+  const AlignedVector64<uint64_t>& GetRootOfUnityPowers() const {
+    return m_root_of_unity_powers;
+  }
+
+  /// @brief Returns the root of unity power at bit-reversed index i.
+ uint64_t GetRootOfUnityPower(size_t i) { return GetRootOfUnityPowers()[i]; } + + /// @brief Returns 32-bit pre-conditioned root of unity powers in + /// bit-reversed order + const AlignedVector64& GetPrecon32RootOfUnityPowers() const { + return m_precon32_root_of_unity_powers; + } + + /// @brief Returns 64-bit pre-conditioned root of unity powers in + /// bit-reversed order + const AlignedVector64& GetPrecon64RootOfUnityPowers() const { + return m_precon64_root_of_unity_powers; + } + + /// @brief Returns the root of unity powers in bit-reversed order with + /// modifications for use by AVX512 implementation + const AlignedVector64& GetAVX512RootOfUnityPowers() const { + return m_avx512_root_of_unity_powers; + } + + /// @brief Returns 32-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon32RootOfUnityPowers() const { + return m_avx512_precon32_root_of_unity_powers; + } + + /// @brief Returns 52-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon52RootOfUnityPowers() const { + return m_avx512_precon52_root_of_unity_powers; + } + + /// @brief Returns 64-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon64RootOfUnityPowers() const { + return m_avx512_precon64_root_of_unity_powers; + } + + /// @brief Returns the inverse root of unity powers in bit-reversed order + const AlignedVector64& GetInvRootOfUnityPowers() const { + return m_inv_root_of_unity_powers; + } + + /// @brief Returns the inverse root of unity power at bit-reversed index i. + uint64_t GetInvRootOfUnityPower(size_t i) { + return GetInvRootOfUnityPowers()[i]; + } + + /// @brief Returns the vector of 32-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. + const AlignedVector64& GetPrecon32InvRootOfUnityPowers() const { + return m_precon32_inv_root_of_unity_powers; + } + + /// @brief Returns the vector of 52-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. + const AlignedVector64& GetPrecon52InvRootOfUnityPowers() const { + return m_precon52_inv_root_of_unity_powers; + } + + /// @brief Returns the vector of 64-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. 
+ const AlignedVector64& GetPrecon64InvRootOfUnityPowers() const { + return m_precon64_inv_root_of_unity_powers; + } + + /// @brief Maximum power of 2 in degree + static size_t MaxDegreeBits() { return 20; } + + /// @brief Maximum number of bits in modulus; + static size_t MaxModulusBits() { return 62; } + + /// @brief Default bit shift used in Barrett precomputation + static const size_t s_default_shift_bits{64}; + + /// @brief Bit shift used in Barrett precomputation when AVX512-IFMA + /// acceleration is enabled + static const size_t s_ifma_shift_bits{52}; + + /// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the + /// forward transform + static const size_t s_max_fwd_32_modulus{1ULL << (32 - 2)}; + + /// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the + /// inverse transform + static const size_t s_max_inv_32_modulus{1ULL << (32 - 2)}; + + /// @brief Maximum modulus to use AVX512-IFMA acceleration for the forward + /// transform + static const size_t s_max_fwd_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)}; + + /// @brief Maximum modulus to use AVX512-IFMA acceleration for the inverse + /// transform + static const size_t s_max_inv_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)}; + + /// @brief Maximum modulus to use AVX512-DQ acceleration for the inverse + /// transform + static const size_t s_max_inv_dq_modulus{1ULL << (s_default_shift_bits - 2)}; + + static size_t s_max_fwd_modulus(int bit_shift) { + if (bit_shift == 32) { + return s_max_fwd_32_modulus; + } else if (bit_shift == 52) { + return s_max_fwd_ifma_modulus; + } else if (bit_shift == 64) { + return 1ULL << MaxModulusBits(); + } + HEXL_CHECK(false, "Invalid bit_shift " << bit_shift); + return 0; + } + + static size_t s_max_inv_modulus(int bit_shift) { + if (bit_shift == 32) { + return s_max_inv_32_modulus; + } else if (bit_shift == 52) { + return s_max_inv_ifma_modulus; + } else if (bit_shift == 64) { + return 1ULL << MaxModulusBits(); + } + HEXL_CHECK(false, "Invalid bit_shift " << bit_shift); + return 0; + } + + private: + void ComputeRootOfUnityPowers(); + + uint64_t m_degree; // N: size of NTT transform, should be power of 2 + uint64_t m_q; // prime modulus. 
Must satisfy q == 1 mod 2n + + uint64_t m_degree_bits; // log_2(m_degree) + + uint64_t m_w_inv; // Inverse of minimal root of unity + uint64_t m_w; // A 2N'th root of unity + + std::shared_ptr m_alloc; + + AlignedAllocator m_aligned_alloc; + + // powers of the minimal root of unity + AlignedVector64 m_root_of_unity_powers; + // vector of floor(W * 2**32 / m_q), with W the root of unity powers + AlignedVector64 m_precon32_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the root of unity powers + AlignedVector64 m_precon64_root_of_unity_powers; + + // powers of the minimal root of unity adjusted for use in AVX512 + // implementations + AlignedVector64 m_avx512_root_of_unity_powers; + // vector of floor(W * 2**32 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon32_root_of_unity_powers; + // vector of floor(W * 2**52 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon52_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon64_root_of_unity_powers; + + // vector of floor(W * 2**32 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon32_inv_root_of_unity_powers; + // vector of floor(W * 2**52 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon52_inv_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon64_inv_root_of_unity_powers; + + AlignedVector64 m_inv_root_of_unity_powers; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/number-theory/number-theory.hpp b/hexl_omp/include/hexl/number-theory/number-theory.hpp new file mode 100644 index 00000000..da8d1d2a --- /dev/null +++ b/hexl_omp/include/hexl/number-theory/number-theory.hpp @@ -0,0 +1,342 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" + +namespace intel { +namespace hexl { + +/// @brief Pre-computes a Barrett factor with which modular multiplication can +/// be performed more efficiently +class MultiplyFactor { + public: + MultiplyFactor() = default; + + /// @brief Computes and stores the Barrett factor floor((operand << bit_shift) + /// / modulus). This is useful when modular multiplication of the form + /// (x * operand) mod modulus is performed with same modulus and operand + /// several times. Note, passing operand=1 can be used to pre-compute a + /// Barrett factor for multiplications of the form (x * y) mod modulus, where + /// only the modulus is re-used across calls to modular multiplication. + MultiplyFactor(uint64_t operand, uint64_t bit_shift, uint64_t modulus) + : m_operand(operand) { + HEXL_CHECK(operand <= modulus, "operand " << operand + << " must be less than modulus " + << modulus); + HEXL_CHECK(bit_shift == 32 || bit_shift == 52 || bit_shift == 64, + "Unsupported BitShift " << bit_shift); + uint64_t op_hi = operand >> (64 - bit_shift); + uint64_t op_lo = (bit_shift == 64) ? 
0 : (operand << bit_shift); + + m_barrett_factor = DivideUInt128UInt64Lo(op_hi, op_lo, modulus); + } + + /// @brief Returns the pre-computed Barrett factor + inline uint64_t BarrettFactor() const { return m_barrett_factor; } + + /// @brief Returns the operand corresponding to the Barrett factor + inline uint64_t Operand() const { return m_operand; } + + private: + uint64_t m_operand; + uint64_t m_barrett_factor; +}; + +/// @brief Returns whether or not num is a power of two +inline bool IsPowerOfTwo(uint64_t num) { return num && !(num & (num - 1)); } + +/// @brief Returns floor(log2(x)) +inline uint64_t Log2(uint64_t x) { return MSB(x); } + +inline bool IsPowerOfFour(uint64_t num) { + return IsPowerOfTwo(num) && (Log2(num) % 2 == 0); +} + +/// @brief Returns the maximum value that can be represented using \p bits bits +inline uint64_t MaximumValue(uint64_t bits) { + HEXL_CHECK(bits <= 64, "MaximumValue requires bits <= 64; got " << bits); + if (bits == 64) { + return (std::numeric_limits::max)(); + } + return (1ULL << bits) - 1; +} + +/// @brief Reverses the bits +/// @param[in] x Input to reverse +/// @param[in] bit_width Number of bits in the input; must be >= MSB(x) +/// @return The bit-reversed representation of \p x using \p bit_width bits +uint64_t ReverseBits(uint64_t x, uint64_t bit_width); + +/// @brief Returns x^{-1} mod modulus +/// @details Requires x % modulus != 0 +uint64_t InverseMod(uint64_t x, uint64_t modulus); + +/// @brief Returns (x * y) mod modulus +/// @details Assumes x, y < modulus +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns (x * y) mod modulus +/// @param[in] x +/// @param[in] y +/// @param[in] y_precon 64-bit precondition factor floor(2**64 / modulus) +/// @param[in] modulus +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon, + uint64_t modulus); + +/// @brief Returns (x + y) mod modulus +/// @details Assumes x, y < modulus +uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns (x - y) mod modulus +/// @details Assumes x, y < modulus +uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns base^exp mod modulus +uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus); + +/// @brief Returns whether or not root is a degree-th root of unity mod modulus +/// @param[in] root Root of unity to check +/// @param[in] degree Degree of root of unity; must be a power of two +/// @param[in] modulus Modulus of finite field +bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus); + +/// @brief Tries to return a primitive degree-th root of unity +/// @details Returns 0 or throws an error if no root is found +uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus); + +/// @brief Returns whether or not root is a degree-th root of unity +/// @param[in] degree Must be a power of two +/// @param[in] modulus Modulus of finite field +uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus); + +/// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 * +/// modulus] +/// @param[in] x +/// @param[in] y_operand also denoted y +/// @param[in] modulus +/// @param[in] y_barrett_factor Pre-computed Barrett reduction factor floor((y +/// << BitShift) / modulus) +template +inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y_operand, + uint64_t y_barrett_factor, uint64_t modulus) { + HEXL_CHECK(y_operand < modulus, "y_operand " << y_operand + << " must be less than modulus " + << modulus); + HEXL_CHECK( + modulus <= 
MaximumValue(BitShift), + "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift)); + HEXL_CHECK(x <= MaximumValue(BitShift), + "Operand " << x << " exceeds bound " << MaximumValue(BitShift)); + + uint64_t Q = MultiplyUInt64Hi(x, y_barrett_factor); + return y_operand * x - Q * modulus; +} + +/// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 * +/// modulus] +/// @param[in] x +/// @param[in] y +/// @param[in] modulus +template +inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y, uint64_t modulus) { + HEXL_CHECK(BitShift == 64 || BitShift == 52, + "Unsupported BitShift " << BitShift); + HEXL_CHECK(x <= MaximumValue(BitShift), + "Operand " << x << " exceeds bound " << MaximumValue(BitShift)); + HEXL_CHECK(y < modulus, + "y " << y << " must be less than modulus " << modulus); + HEXL_CHECK( + modulus <= MaximumValue(BitShift), + "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift)); + + uint64_t y_barrett = MultiplyFactor(y, BitShift, modulus).BarrettFactor(); + return MultiplyModLazy(x, y, y_barrett, modulus); +} + +/// @brief Adds two unsigned 64-bit integers +/// @param operand1 Number to add +/// @param operand2 Number to add +/// @param result Stores the sum +/// @return The carry bit +inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2, + uint64_t* result) { + *result = operand1 + operand2; + return static_cast(*result < operand1); +} + +/// @brief Returns whether or not the input is prime +bool IsPrime(uint64_t n); + +/// @brief Generates a list of num_primes primes in the range [2^(bit_size), +// 2^(bit_size+1)]. Ensures each prime q satisfies +// q % (2*ntt_size+1)) == 1 +/// @param[in] num_primes Number of primes to generate +/// @param[in] bit_size Bit size of each prime +/// @param[in] prefer_small_primes When true, returns primes starting from +/// 2^(bit_size); when false, returns primes starting from 2^(bit_size+1) +/// @param[in] ntt_size N such that each prime q satisfies q % (2N) == 1. N must +/// be a power of two less than 2^bit_size. +std::vector GeneratePrimes(size_t num_primes, size_t bit_size, + bool prefer_small_primes, + size_t ntt_size = 1); + +/// @brief Returns input mod modulus, computed via 64-bit Barrett reduction +/// @param[in] input +/// @param[in] modulus +/// @param[in] q_barr floor(2^64 / modulus) +template +uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr) { + HEXL_CHECK(modulus != 0, "modulus == 0"); + uint64_t q = MultiplyUInt64Hi<64>(input, q_barr); + uint64_t q_times_input = input - q * modulus; + if (OutputModFactor == 2) { + return q_times_input; + } else { + return (q_times_input >= modulus) ? 
q_times_input - modulus : q_times_input; + } +} + +/// @brief Returns x mod modulus, assuming x < InputModFactor * modulus +/// @param[in] x +/// @param[in] modulus also denoted q +/// @param[in] twice_modulus 2 * q; must not be nullptr if InputModFactor == 4 +/// or 8 +/// @param[in] four_times_modulus 4 * q; must not be nullptr if InputModFactor +/// == 8 +template +uint64_t ReduceMod(uint64_t x, uint64_t modulus, + const uint64_t* twice_modulus = nullptr, + const uint64_t* four_times_modulus = nullptr) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || + InputModFactor == 4 || InputModFactor == 8, + "InputModFactor should be 1, 2, 4, or 8"); + if (InputModFactor == 1) { + return x; + } + if (InputModFactor == 2) { + if (x >= modulus) { + x -= modulus; + } + return x; + } + if (InputModFactor == 4) { + HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr"); + if (x >= *twice_modulus) { + x -= *twice_modulus; + } + if (x >= modulus) { + x -= modulus; + } + return x; + } + if (InputModFactor == 8) { + HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr"); + HEXL_CHECK(four_times_modulus != nullptr, + "four_times_modulus should not be nullptr"); + + if (x >= *four_times_modulus) { + x -= *four_times_modulus; + } + if (x >= *twice_modulus) { + x -= *twice_modulus; + } + if (x >= modulus) { + x -= modulus; + } + return x; + } + HEXL_CHECK(false, "Should be unreachable"); + return x; +} + +/// @brief Returns Montgomery form of ab mod q, computed via the REDC algorithm, +/// also known as Montgomery reduction. +/// @param[in] r +/// @param[in] q with R = 2^r such that gcd(R, q) = 1. R > q. +/// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R. +/// @param[in] mod_R_msk take r last bits to apply mod R. +/// @param[in] T_hi of T = ab in the range [0, Rq − 1]. +/// @param[in] T_lo of T. +/// @return Unsigned long int in the range [0, q − 1] such that S ≡ TR^−1 mod q +template +inline uint64_t MontgomeryReduce(uint64_t T_hi, uint64_t T_lo, uint64_t q, + int r, uint64_t mod_R_msk, uint64_t inv_mod) { + HEXL_CHECK(BitShift == 64 || BitShift == 52, + "Unsupported BitShift " << BitShift); + HEXL_CHECK((1ULL << r) > static_cast(q), + "R value should be greater than q = " << static_cast(q)); + + uint64_t mq_hi; + uint64_t mq_lo; + + uint64_t m = ((T_lo & mod_R_msk) * inv_mod) & mod_R_msk; + MultiplyUInt64(m, q, &mq_hi, &mq_lo); + + if (BitShift == 52) { + mq_hi = (mq_hi << 12) | (mq_lo >> 52); + mq_lo &= (1ULL << 52) - 1; + } + + uint64_t t_hi; + uint64_t t_lo; + + // first 64bit block + t_lo = T_lo + mq_lo; + unsigned int carry = static_cast(t_lo < T_lo); + t_hi = T_hi + mq_hi + carry; + + t_hi = t_hi << (BitShift - r); + t_lo = t_lo >> r; + t_lo = t_hi + t_lo; + + return (t_lo >= q) ? 
(t_lo - q) : t_lo; +} + +/// @brief Hensel's Lemma for 2-adic numbers +/// Find solution for qX + 1 = 0 mod 2^r +/// @param[in] r +/// @param[in] q such that gcd(2, q) = 1 +/// @return Unsigned long int in [0, 2^r − 1] such that q*x ≡ −1 mod 2^r +inline uint64_t HenselLemma2adicRoot(uint32_t r, uint64_t q) { + uint64_t a_prev = 1; + uint64_t c = 2; + uint64_t mod_mask = 3; + + // Root: + // f(x) = qX + 1 and a_(0) = 1 then f(1) ≡ 0 mod 2 + // General Case: + // - a_(n) ≡ a_(n-1) mod 2^(n) + // => a_(n) = a_(n-1) + 2^(n)*t + // - Find 't' such that f(a_(n)) = 0 mod 2^(n+1) + // First case in for: + // - a_(1) ≡ 1 mod 2 or a_(1) = 1 + 2t + // - Find 't' so f(a_(1)) ≡ 0 mod 4 => q(1 + 2t) + 1 ≡ 0 mod 4 + for (uint64_t k = 2; k <= r; k++) { + uint64_t f = 0; + uint64_t t = 0; + uint64_t a = 0; + + do { + a = a_prev + c * t++; + f = q * a + 1ULL; + } while (f & mod_mask); // f(a) ≡ 0 mod 2^(k) + + // Update vars + mod_mask = mod_mask * 2 + 1ULL; + c *= 2; + a_prev = a; + } + + return a_prev; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/util/aligned-allocator.hpp b/hexl_omp/include/hexl/util/aligned-allocator.hpp new file mode 100644 index 00000000..d175c734 --- /dev/null +++ b/hexl_omp/include/hexl/util/aligned-allocator.hpp @@ -0,0 +1,110 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/allocator.hpp" +#include "hexl/util/defines.hpp" + +namespace intel { +namespace hexl { + +/// @brief Allocater implementation using malloc and free +struct MallocStrategy : AllocatorBase { + void* allocate(size_t bytes_count) final { return std::malloc(bytes_count); } + + void deallocate(void* p, size_t n) final { + HEXL_UNUSED(n); + std::free(p); + } +}; + +using AllocatorStrategyPtr = std::shared_ptr; +extern AllocatorStrategyPtr mallocStrategy; + +/// @brief Allocates memory aligned to Alignment-byte sized boundaries +/// @details Alignment must be a power of two +template +class AlignedAllocator { + public: + template + friend class AlignedAllocator; + + using value_type = T; + + explicit AlignedAllocator(AllocatorStrategyPtr strategy = nullptr) noexcept + : m_alloc_impl((strategy != nullptr) ? 
strategy : mallocStrategy) {} + + AlignedAllocator(const AlignedAllocator& src) = default; + AlignedAllocator& operator=(const AlignedAllocator& src) = default; + + template + AlignedAllocator(const AlignedAllocator& src) + : m_alloc_impl(src.m_alloc_impl) {} + + ~AlignedAllocator() {} + + template + struct rebind { + using other = AlignedAllocator; + }; + + bool operator==(const AlignedAllocator&) { return true; } + + bool operator!=(const AlignedAllocator&) { return false; } + + /// @brief Allocates \p n elements aligned to Alignment-byte boundaries + /// @return Pointer to the aligned allocated memory + T* allocate(size_t n) { + if (!IsPowerOfTwo(Alignment)) { + return nullptr; + } + // Allocate enough space to ensure the alignment can be satisfied + size_t buffer_size = sizeof(T) * n + Alignment; + // Additionally, allocate a prefix to store the memory location of the + // unaligned buffer + size_t alloc_size = buffer_size + sizeof(void*); + void* buffer = m_alloc_impl->allocate(alloc_size); + if (!buffer) { + return nullptr; + } + + // Reserve first location for pointer to originally-allocated space + void* aligned_buffer = static_cast(buffer) + sizeof(void*); + std::align(Alignment, sizeof(T) * n, aligned_buffer, buffer_size); + if (!aligned_buffer) { + return nullptr; + } + + // Store allocated buffer address at aligned_buffer - sizeof(void*). + void* store_buffer_addr = + static_cast(aligned_buffer) - sizeof(void*); + *(static_cast(store_buffer_addr)) = buffer; + + return static_cast(aligned_buffer); + } + + void deallocate(T* p, size_t n) { + if (!p) { + return; + } + void* store_buffer_addr = (reinterpret_cast(p) - sizeof(void*)); + void* free_address = *(static_cast(store_buffer_addr)); + m_alloc_impl->deallocate(free_address, n); + } + + private: + AllocatorStrategyPtr m_alloc_impl; +}; + +/// @brief 64-byte aligned memory allocator +template +using AlignedVector64 = std::vector >; + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/util/allocator.hpp b/hexl_omp/include/hexl/util/allocator.hpp new file mode 100644 index 00000000..5f4a7a31 --- /dev/null +++ b/hexl_omp/include/hexl/util/allocator.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Base class for custom memory allocator +struct AllocatorBase { + virtual ~AllocatorBase() noexcept {} + + /// @brief Allocates byte_count bytes of memory + /// @param[in] bytes_count Number of bytes to allocate + /// @return A pointer to the allocated memory + virtual void* allocate(size_t bytes_count) = 0; + + /// @brief Deallocate memory + /// @param[in] p Pointer to memory to deallocate + /// @param[in] n Number of bytes to deallocate + virtual void deallocate(void* p, size_t n) = 0; +}; + +/// @brief Helper memory allocation struct which delegates implementation to +/// AllocatorImpl +template +struct AllocatorInterface : public AllocatorBase { + /// @brief Override interface and delegate implementation to AllocatorImpl + void* allocate(size_t bytes_count) override { + return static_cast(this)->allocate_impl(bytes_count); + } + + /// @brief Override interface and delegate implementation to AllocatorImpl + void deallocate(void* p, size_t n) override { + static_cast(this)->deallocate_impl(p, n); + } + + private: + // in case AllocatorImpl doesn't provide implementations, use default null + // behavior + void* allocate_impl(size_t bytes_count) { + HEXL_UNUSED(bytes_count); + 
return nullptr; + } + void deallocate_impl(void* p, size_t n) { + HEXL_UNUSED(p); + HEXL_UNUSED(n); + } +}; +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/util/check.hpp b/hexl_omp/include/hexl/util/check.hpp new file mode 100644 index 00000000..386eba89 --- /dev/null +++ b/hexl_omp/include/hexl/util/check.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/types.hpp" + +// Create logging/debug macros with no run-time overhead unless HEXL_DEBUG is +// enabled +#ifdef HEXL_DEBUG +#include "hexl/logging/logging.hpp" + +/// @brief If input condition is not true, logs the expression and throws an +/// error +/// @param[in] cond A boolean indication the condition +/// @param[in] expr The expression to be logged +#define HEXL_CHECK(cond, expr) \ + if (!(cond)) { \ + LOG(ERROR) << expr << " in function: " << __FUNCTION__ \ + << " in file: " __FILE__ << ":" << __LINE__; \ + throw std::runtime_error("Error. Check log output"); \ + } + +/// @brief If input has an element >= bound, logs the expression and throws an +/// error +/// @param[in] arg Input container which supports the [] operator. +/// @param[in] n Size of input +/// @param[in] bound Upper bound on the input +/// @param[in] expr The expression to be logged +#define HEXL_CHECK_BOUNDS(arg, n, bound, expr) \ + for (size_t hexl_check_idx = 0; hexl_check_idx < n; ++hexl_check_idx) { \ + HEXL_CHECK((arg)[hexl_check_idx] < bound, expr); \ + } + +#else // HEXL_DEBUG=OFF + +#define HEXL_CHECK(cond, expr) \ + {} +#define HEXL_CHECK_BOUNDS(...) \ + {} + +#endif // HEXL_DEBUG diff --git a/hexl_omp/include/hexl/util/clang.hpp b/hexl_omp/include/hexl/util/clang.hpp new file mode 100644 index 00000000..958bea7b --- /dev/null +++ b/hexl_omp/include/hexl/util/clang.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/types.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_USE_CLANG +// Return x * y as 128-bit integer +// Correctness if x * y < 128 bits +inline uint128_t MultiplyUInt64(uint64_t x, uint64_t y) { + return uint128_t(x) * uint128_t(y); +} + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint128_t n = (static_cast(input_hi) << 64) | + (static_cast(input_lo)); + + return n % modulus; + // TODO(fboemer): actually use barrett reduction if performance-critical +} + +// Returns low 64bit of 128b/64b where x1=high 64b, x0=low 64b +inline uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y) { + uint128_t n = + (static_cast(x1) << 64) | (static_cast(x0)); + uint128_t q = n / y; + + return static_cast(q); +} + +// Multiplies x * y as 128-bit integer. 
+// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + uint128_t prod = MultiplyUInt64(x, y); + *prod_hi = static_cast(prod >> 64); + *prod_lo = static_cast(prod); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + uint128_t product = static_cast(x) * y; + return static_cast(product >> BitShift); +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + return static_cast(std::log2l(input)); +} + +#define HEXL_LOOP_UNROLL_4 _Pragma("clang loop unroll_count(4)") +#define HEXL_LOOP_UNROLL_8 _Pragma("clang loop unroll_count(8)") + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/util/compiler.hpp b/hexl_omp/include/hexl/util/compiler.hpp new file mode 100644 index 00000000..7dd077df --- /dev/null +++ b/hexl_omp/include/hexl/util/compiler.hpp @@ -0,0 +1,14 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/util/defines.hpp" + +#ifdef HEXL_USE_MSVC +#include "hexl/util/msvc.hpp" +#elif defined HEXL_USE_GNU +#include "hexl/util/gcc.hpp" +#elif defined HEXL_USE_CLANG +#include "hexl/util/clang.hpp" +#endif diff --git a/hexl_omp/include/hexl/util/defines.hpp b/hexl_omp/include/hexl/util/defines.hpp new file mode 100644 index 00000000..93db376e --- /dev/null +++ b/hexl_omp/include/hexl/util/defines.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +/* #undef HEXL_USE_MSVC */ +#define HEXL_USE_GNU +/* #undef HEXL_USE_CLANG */ + +/* #undef HEXL_DEBUG */ + +// Avoid unused variable warnings +#define HEXL_UNUSED(x) (void)(x) diff --git a/hexl_omp/include/hexl/util/defines.hpp.in b/hexl_omp/include/hexl/util/defines.hpp.in new file mode 100644 index 00000000..0f146c26 --- /dev/null +++ b/hexl_omp/include/hexl/util/defines.hpp.in @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#cmakedefine HEXL_USE_MSVC +#cmakedefine HEXL_USE_GNU +#cmakedefine HEXL_USE_CLANG + +#cmakedefine HEXL_DEBUG + +// Avoid unused variable warnings +#define HEXL_UNUSED(x) (void)(x) diff --git a/hexl_omp/include/hexl/util/gcc.hpp b/hexl_omp/include/hexl/util/gcc.hpp new file mode 100644 index 00000000..828e3836 --- /dev/null +++ b/hexl_omp/include/hexl/util/gcc.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/types.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_USE_GNU +// Return x * y as 128-bit integer +// Correctness if x * y < 128 bits +inline uint128_t MultiplyUInt64(uint64_t x, uint64_t y) { + return uint128_t(x) * uint128_t(y); +} + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint128_t n = (static_cast(input_hi) << 64) | + (static_cast(input_lo)); + + return static_cast(n % modulus); + // TODO(fboemer): actually use barrett reduction if performance-critical +} + +// Returns low 64bit of 128b/64b where x1=high 64b, x0=low 64b +inline uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y) { + uint128_t n = + (static_cast(x1) << 64) | (static_cast(x0)); + uint128_t q = n 
/ y; + + return static_cast(q); +} + +// Multiplies x * y as 128-bit integer. +// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + uint128_t prod = MultiplyUInt64(x, y); + *prod_hi = static_cast(prod >> 64); + *prod_lo = static_cast(prod); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + uint128_t product = MultiplyUInt64(x, y); + return static_cast(product >> BitShift); +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + return static_cast(std::log2l(input)); +} + +#define HEXL_LOOP_UNROLL_4 _Pragma("GCC unroll 4") +#define HEXL_LOOP_UNROLL_8 _Pragma("GCC unroll 8") + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/util/msvc.hpp b/hexl_omp/include/hexl/util/msvc.hpp new file mode 100644 index 00000000..0ada2d45 --- /dev/null +++ b/hexl_omp/include/hexl/util/msvc.hpp @@ -0,0 +1,289 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#ifdef HEXL_USE_MSVC + +#define NOMINMAX // Avoid errors with std::min/std::max +#undef min +#undef max + +#include +#include +#include + +#include + +#include "hexl/util/check.hpp" + +#pragma intrinsic(_addcarry_u64, _BitScanReverse64, _subborrow_u64, _udiv128, \ + _umul128) + +#undef TRUE +#undef FALSE + +namespace intel { +namespace hexl { + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint64_t remainder; + _udiv128(input_hi, input_lo, modulus, &remainder); + + return remainder; +} + +// Multiplies x * y as 128-bit integer. 
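The main consumer of `MultiplyUInt64Hi` is preconditioned (Shoup-style) modular multiplication: the `*_precon` root-of-unity arrays used by the NTT code in this diff hold constants of the form floor(w * 2^64 / q), which let each butterfly replace a 128-bit division with one high and one low 64-bit multiply. Below is a hedged sketch of that pattern, assuming w < q < 2^63 and using illustrative names rather than the library's own helpers.

```cpp
#include <cstdint>

__extension__ typedef unsigned __int128 u128;

// Precompute w' = floor(w * 2^64 / q) once per constant w (e.g. per root of unity).
inline uint64_t precompute_shoup(uint64_t w, uint64_t q) {
  return static_cast<uint64_t>((static_cast<u128>(w) << 64) / q);
}

// Lazy modular product: for w < q < 2^63, returns a value in [0, 2q) congruent
// to w*x mod q, using one "high" multiply (as in MultiplyUInt64Hi<64>) plus one
// wrapping low multiply instead of a 128-bit division.
inline uint64_t mul_mod_lazy(uint64_t x, uint64_t w, uint64_t w_precon,
                             uint64_t q) {
  uint64_t q_est =
      static_cast<uint64_t>((static_cast<u128>(w_precon) * x) >> 64);
  return w * x - q_est * q;  // exact despite the mod-2^64 wrap, since result < 2q
}
```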
+// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + *prod_lo = _umul128(x, y, prod_hi); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid BitShift " << BitShift << "; expected 52 or 64"); + uint64_t prod_hi; + uint64_t prod_lo = _umul128(x, y, &prod_hi); + uint64_t result_hi; + uint64_t result_lo; + RightShift128(&result_hi, &result_lo, prod_hi, prod_lo, BitShift); + return result_lo; +} + +/// @brief Computes Left Shift op as 128-bit unsigned integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +/// @param[in] op_hi Stores high 64 bits of input +/// @param[in] op_lo Stores low 64 bits of input +inline void LeftShift128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op_hi, const uint64_t op_lo, + const uint64_t shift_value) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + HEXL_CHECK(shift_value <= 128, + "shift_value cannot be greater than 128 " << shift_value); + + if (shift_value == 0) { + *result_hi = op_hi; + *result_lo = op_lo; + } else if (shift_value == 64) { + *result_hi = op_lo; + *result_lo = 0ULL; + } else if (shift_value == 128) { + *result_hi = 0ULL; + *result_lo = 0ULL; + } else if (shift_value >= 1 && shift_value <= 63) { + *result_hi = (op_hi << shift_value) | (op_lo >> (64 - shift_value)); + *result_lo = op_lo << shift_value; + } else if (shift_value >= 65 && shift_value < 128) { + *result_hi = op_lo << (shift_value - 64); + *result_lo = 0ULL; + } +} + +/// @brief Computes Right Shift op as 128-bit unsigned integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +/// @param[in] op_hi Stores high 64 bits of input +/// @param[in] op_lo Stores low 64 bits of input +inline void RightShift128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op_hi, const uint64_t op_lo, + const uint64_t shift_value) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + HEXL_CHECK(shift_value <= 128, + "shift_value cannot be greater than 128 " << shift_value); + + if (shift_value == 0) { + *result_hi = op_hi; + *result_lo = op_lo; + } else if (shift_value == 64) { + *result_hi = 0ULL; + *result_lo = op_hi; + } else if (shift_value == 128) { + *result_hi = 0ULL; + *result_lo = 0ULL; + } else if (shift_value >= 1 && shift_value <= 63) { + *result_hi = op_hi >> shift_value; + *result_lo = (op_hi << (64 - shift_value)) | (op_lo >> shift_value); + } else if (shift_value >= 65 && shift_value < 128) { + *result_hi = 0ULL; + *result_lo = op_hi >> (shift_value - 64); + } +} + +/// @brief Adds op1 + op2 as 128-bit integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +inline void AddWithCarry128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op1_hi, const uint64_t op1_lo, + const uint64_t op2_hi, const uint64_t op2_lo) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + + // first 64bit block + *result_lo = op1_lo + op2_lo; 
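+  // The addition above wraps modulo 2^64, so a carry out of the low word
+  // occurred exactly when the wrapped sum is smaller than one of the addends;
+  // that flag is then propagated into the high word via _addcarry_u64.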
+ unsigned char carry = static_cast(*result_lo < op1_lo); + + // second 64bit block + _addcarry_u64(carry, op1_hi, op2_hi, result_hi); +} + +/// @brief Subtracts op1 - op2 as 128-bit integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +inline void SubWithCarry128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op1_hi, const uint64_t op1_lo, + const uint64_t op2_hi, const uint64_t op2_lo) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + + unsigned char borrow; + + // first 64bit block + *result_lo = op1_lo - op2_lo; + borrow = static_cast(op2_lo > op1_lo); + + // second 64bit block + _subborrow_u64(borrow, op1_hi, op2_hi, result_hi); +} + +/// @brief Computes and returns significant bit count +/// @param[in] value Input element at most 128 bits long +inline uint64_t SignificantBitLength(const uint64_t* value) { + HEXL_CHECK(value != nullptr, "Require value != nullptr"); + + unsigned long count = 0; // NOLINT(runtime/int) + + // second 64bit block + _BitScanReverse64(&count, *(value + 1)); + if (count >= 0 && *(value + 1) > 0) { + return static_cast(count) + 1 + 64; + } + + // first 64bit block + _BitScanReverse64(&count, *value); + if (count >= 0 && *(value) > 0) { + return static_cast(count) + 1; + } + return 0; +} + +/// @brief Checks if input is negative number +/// @param[in] input Input element to check for sign +inline bool CheckSign(const uint64_t* input) { + HEXL_CHECK(input != nullptr, "Require input != nullptr"); + + uint64_t input_temp[2]{0, 0}; + RightShift128(&input_temp[1], &input_temp[0], input[1], input[0], 127); + return (input_temp[0] == 1); +} + +/// @brief Divides numerator by denominator +/// @param[out] quotient Stores quotient as two 64-bit blocks after division +/// @param[in] numerator +/// @param[in] denominator +inline void DivideUInt128UInt64(uint64_t* quotient, const uint64_t* numerator, + const uint64_t denominator) { + HEXL_CHECK(quotient != nullptr, "Require quotient != nullptr"); + HEXL_CHECK(numerator != nullptr, "Require numerator != nullptr"); + HEXL_CHECK(denominator != 0, "denominator cannot be 0 " << denominator); + + // get bit count of divisor + uint64_t numerator_bits = SignificantBitLength(numerator); + const uint64_t numerator_bits_const = numerator_bits; + const uint64_t uint_128_bit = 128ULL; + + uint64_t MASK[2]{0x0000000000000001, 0x0000000000000000}; + uint64_t remainder[2]{0, 0}; + uint64_t quotient_temp[2]{0, 0}; + uint64_t denominator_temp[2]{denominator, 0}; + + quotient[0] = numerator[0]; + quotient[1] = numerator[1]; + + // align numerator + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], + (uint_128_bit - numerator_bits_const)); + + while (numerator_bits) { + // if remainder is negative + if (CheckSign(remainder)) { + LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1); + RightShift128("ient_temp[1], "ient_temp[0], quotient[1], + quotient[0], (uint_128_bit - 1)); + remainder[0] = remainder[0] | quotient_temp[0]; + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1); + // remainder=remainder+denominator_temp + AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } else { // if remainder is positive + LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1); + RightShift128("ient_temp[1], "ient_temp[0], quotient[1], + quotient[0], (uint_128_bit 
- 1)); + remainder[0] = remainder[0] | quotient_temp[0]; + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1); + // remainder=remainder-denominator_temp + SubWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } + + // if remainder is positive set MSB of quotient[0]=1 + if (!CheckSign(remainder)) { + MASK[0] = 0x0000000000000001; + MASK[1] = 0x0000000000000000; + LeftShift128(&MASK[1], &MASK[0], MASK[1], MASK[0], + (uint_128_bit - numerator_bits_const)); + quotient[0] = quotient[0] | MASK[0]; + quotient[1] = quotient[1] | MASK[1]; + } + quotient_temp[0] = 0; + quotient_temp[1] = 0; + numerator_bits--; + } + + if (CheckSign(remainder)) { + // remainder=remainder+denominator_temp + AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } + RightShift128("ient[1], "ient[0], quotient[1], quotient[0], + (uint_128_bit - numerator_bits_const)); +} + +/// @brief Returns low of dividing numerator by denominator +/// @param[in] numerator_hi Stores high 64 bit of numerator +/// @param[in] numerator_lo Stores low 64 bit of numerator +/// @param[in] denominator Stores denominator +inline uint64_t DivideUInt128UInt64Lo(const uint64_t numerator_hi, + const uint64_t numerator_lo, + const uint64_t denominator) { + uint64_t numerator[2]{numerator_lo, numerator_hi}; + uint64_t quotient[2]{0, 0}; + + DivideUInt128UInt64(quotient, numerator, denominator); + return quotient[0]; +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + unsigned long index{0}; // NOLINT(runtime/int) + _BitScanReverse64(&index, input); + return index; +} + +#define HEXL_LOOP_UNROLL_4 \ + {} +#define HEXL_LOOP_UNROLL_8 \ + {} + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/include/hexl/util/types.hpp b/hexl_omp/include/hexl/util/types.hpp new file mode 100644 index 00000000..2d2d8551 --- /dev/null +++ b/hexl_omp/include/hexl/util/types.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/defines.hpp" + +#if defined(HEXL_USE_GNU) || defined(HEXL_USE_CLANG) +__extension__ typedef __int128 int128_t; +__extension__ typedef unsigned __int128 uint128_t; +#endif diff --git a/hexl_omp/include/hexl/util/util.hpp b/hexl_omp/include/hexl/util/util.hpp new file mode 100644 index 00000000..bf878a98 --- /dev/null +++ b/hexl_omp/include/hexl/util/util.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +#undef TRUE // MSVC defines TRUE +#undef FALSE // MSVC defines FALSE + +/// @enum CMPINT +/// @brief Represents binary operations between two boolean values +enum class CMPINT { + EQ = 0, ///< Equal + LT = 1, ///< Less than + LE = 2, ///< Less than or equal + FALSE = 3, ///< False + NE = 4, ///< Not equal + NLT = 5, ///< Not less than + NLE = 6, ///< Not less than or equal + TRUE = 7 ///< True +}; + +/// @brief Returns the logical negation of a binary operation +/// @param[in] cmp The binary operation to negate +inline CMPINT Not(CMPINT cmp) { + switch (cmp) { + case CMPINT::EQ: + return CMPINT::NE; + case CMPINT::LT: + return CMPINT::NLT; + case CMPINT::LE: + return CMPINT::NLE; + case CMPINT::FALSE: + return CMPINT::TRUE; + case CMPINT::NE: + return CMPINT::EQ; + case CMPINT::NLT: + return CMPINT::LT; + case CMPINT::NLE: + 
return CMPINT::LE; + case CMPINT::TRUE: + return CMPINT::FALSE; + default: + return CMPINT::FALSE; + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/logging/logging.cpp b/hexl_omp/logging/logging.cpp new file mode 100644 index 00000000..c491b43e --- /dev/null +++ b/hexl_omp/logging/logging.cpp @@ -0,0 +1,8 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/logging/logging.hpp" + +#ifdef HEXL_DEBUG +INITIALIZE_EASYLOGGINGPP +#endif diff --git a/hexl_omp/ntt/fwd-ntt-avx512.cpp b/hexl_omp/ntt/fwd-ntt-avx512.cpp new file mode 100644 index 00000000..8ed0ac7e --- /dev/null +++ b/hexl_omp/ntt/fwd-ntt-avx512.cpp @@ -0,0 +1,409 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "ntt/fwd-ntt-avx512.hpp" + +#include +#include +#include + +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "ntt/ntt-avx512-util.hpp" +#include "ntt/ntt-internal.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512IFMA +template void ForwardTransformToBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half); +#endif + +#ifdef HEXL_HAS_AVX512DQ +template void ForwardTransformToBitReverseAVX512<32>( + uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half); + +template void ForwardTransformToBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half); +#endif + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief The Harvey butterfly: assume \p X, \p Y in [0, 4q), and return X', Y' +/// in [0, 4q) such that X', Y' = X + WY, X - WY (mod q). +/// @param[in,out] X Input representing 8 64-bit signed integers in SIMD form +/// @param[in,out] Y Input representing 8 64-bit signed integers in SIMD form +/// @param[in] W Root of unity represented as 8 64-bit signed integers in +/// SIMD form +/// @param[in] W_precon Preconditioned \p W for BitShift-bit Barrett +/// reduction +/// @param[in] neg_modulus Negative modulus, i.e. (-q) represented as 8 64-bit +/// signed integers in SIMD form +/// @param[in] twice_modulus Twice the modulus, i.e. 2*q represented as 8 64-bit +/// signed integers in SIMD form +/// @param InputLessThanMod If true, assumes \p X, \p Y < \p q. 
Otherwise, +/// assumes \p X, \p Y < 4*\p q +/// @details See Algorithm 4 of https://arxiv.org/pdf/1205.2926.pdf +template +void FwdButterfly(__m512i* X, __m512i* Y, __m512i W, __m512i W_precon, + __m512i neg_modulus, __m512i twice_modulus) { + if (!InputLessThanMod) { + *X = _mm512_hexl_small_mod_epu64(*X, twice_modulus); + } + + __m512i T; + if (BitShift == 32) { + __m512i Q = _mm512_hexl_mullo_epi<64>(W_precon, *Y); + Q = _mm512_srli_epi64(Q, 32); + __m512i W_Y = _mm512_hexl_mullo_epi<64>(W, *Y); + T = _mm512_hexl_mullo_add_lo_epi<64>(W_Y, Q, neg_modulus); + } else if (BitShift == 52) { + __m512i Q = _mm512_hexl_mulhi_epi(W_precon, *Y); + __m512i W_Y = _mm512_hexl_mullo_epi(W, *Y); + T = _mm512_hexl_mullo_add_lo_epi(W_Y, Q, neg_modulus); + } else if (BitShift == 64) { + // Perform approximate computation of Q, as described in page 7 of + // https://arxiv.org/pdf/2003.04510.pdf + __m512i Q = _mm512_hexl_mulhi_approx_epi(W_precon, *Y); + __m512i W_Y = _mm512_hexl_mullo_epi(W, *Y); + // Compute T in range [0, 4q) + T = _mm512_hexl_mullo_add_lo_epi(W_Y, Q, neg_modulus); + // Reduce T to range [0, 2q) + T = _mm512_hexl_small_mod_epu64<2>(T, twice_modulus); + } else { + HEXL_CHECK(false, "Invalid BitShift " << BitShift); + } + + __m512i twice_mod_minus_T = _mm512_sub_epi64(twice_modulus, T); + *Y = _mm512_add_epi64(*X, twice_mod_minus_T); + *X = _mm512_add_epi64(*X, T); +} + +template +void FwdT1(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, + uint64_t m, const uint64_t* W, const uint64_t* W_precon) { + const __m512i* v_W_pt = reinterpret_cast(W); + const __m512i* v_W_precon_pt = reinterpret_cast(W_precon); + size_t j1 = 0; + + // 8 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_8 + for (size_t i = m / 8; i > 0; --i) { + uint64_t* X = operand + j1; + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + + __m512i v_X; + __m512i v_Y; + LoadFwdInterleavedT1(X, &v_X, &v_Y); + __m512i v_W = _mm512_loadu_si512(v_W_pt++); + __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++); + + FwdButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, + v_twice_mod); + WriteFwdInterleavedT1(v_X, v_Y, v_X_pt); + + j1 += 16; + } +} + +template +void FwdT2(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, + uint64_t m, const uint64_t* W, const uint64_t* W_precon) { + const __m512i* v_W_pt = reinterpret_cast(W); + const __m512i* v_W_precon_pt = reinterpret_cast(W_precon); + + size_t j1 = 0; + // 4 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = m / 4; i > 0; --i) { + uint64_t* X = operand + j1; + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + + __m512i v_X; + __m512i v_Y; + LoadFwdInterleavedT2(X, &v_X, &v_Y); + + __m512i v_W = _mm512_loadu_si512(v_W_pt++); + __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++); + + HEXL_CHECK(ExtractValues(v_W)[0] == ExtractValues(v_W)[1], + "bad v_W " << ExtractValues(v_W)); + HEXL_CHECK(ExtractValues(v_W_precon)[0] == ExtractValues(v_W_precon)[1], + "bad v_W_precon " << ExtractValues(v_W_precon)); + FwdButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, + v_twice_mod); + + _mm512_storeu_si512(v_X_pt++, v_X); + _mm512_storeu_si512(v_X_pt, v_Y); + + j1 += 16; + } +} + +template +void FwdT4(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, + uint64_t m, const uint64_t* W, const uint64_t* W_precon) { + size_t j1 = 0; + const __m512i* v_W_pt = reinterpret_cast(W); + const __m512i* v_W_precon_pt = reinterpret_cast(W_precon); + + // 2 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = m / 2; i > 0; --i) { 
+ uint64_t* X = operand + j1; + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + + __m512i v_X; + __m512i v_Y; + LoadFwdInterleavedT4(X, &v_X, &v_Y); + + __m512i v_W = _mm512_loadu_si512(v_W_pt++); + __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++); + FwdButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, + v_twice_mod); + + _mm512_storeu_si512(v_X_pt++, v_X); + _mm512_storeu_si512(v_X_pt, v_Y); + + j1 += 16; + } +} + +// Out-of-place implementation +template +void FwdT8(uint64_t* result, const uint64_t* operand, __m512i v_neg_modulus, + __m512i v_twice_mod, uint64_t t, uint64_t m, const uint64_t* W, + const uint64_t* W_precon) { + size_t j1 = 0; + + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < m; i++) { + // Referencing operand + const uint64_t* X_op = operand + j1; + const uint64_t* Y_op = X_op + t; + + const __m512i* v_X_op_pt = reinterpret_cast(X_op); + const __m512i* v_Y_op_pt = reinterpret_cast(Y_op); + + // Referencing result + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + + __m512i* v_X_r_pt = reinterpret_cast<__m512i*>(X_r); + __m512i* v_Y_r_pt = reinterpret_cast<__m512i*>(Y_r); + + // Weights and weights' preconditions + __m512i v_W = _mm512_set1_epi64(static_cast(*W++)); + __m512i v_W_precon = _mm512_set1_epi64(static_cast(*W_precon++)); + + // assume 8 | t + for (size_t j = t / 8; j > 0; --j) { + __m512i v_X = _mm512_loadu_si512(v_X_op_pt); + __m512i v_Y = _mm512_loadu_si512(v_Y_op_pt); + + FwdButterfly(&v_X, &v_Y, v_W, v_W_precon, + v_neg_modulus, v_twice_mod); + + _mm512_storeu_si512(v_X_r_pt++, v_X); + _mm512_storeu_si512(v_Y_r_pt++, v_Y); + + // Increase operand pointers as well + v_X_op_pt++; + v_Y_op_pt++; + } + j1 += (t << 1); + } +} + +template +void ForwardTransformToBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(modulus < NTT::s_max_fwd_modulus(BitShift), + "modulus " << modulus << " too large for BitShift " << BitShift + << " => maximum value " + << NTT::s_max_fwd_modulus(BitShift)); + HEXL_CHECK_BOUNDS(precon_root_of_unity_powers, n, MaximumValue(BitShift), + "precon_root_of_unity_powers too large"); + HEXL_CHECK_BOUNDS(operand, n, MaximumValue(BitShift), "operand too large"); + // Skip input bound checking for recursive steps + HEXL_CHECK_BOUNDS(operand, (recursion_depth == 0) ? n : 0, + input_mod_factor * modulus, + "operand larger than input_mod_factor * modulus (" + << input_mod_factor << " * " << modulus << ")"); + HEXL_CHECK(n >= 16, + "Don't support small transforms. 
Need n >= 16, got n = " << n); + HEXL_CHECK( + input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4, + "input_mod_factor must be 1, 2, or 4; got " << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4, + "output_mod_factor must be 1 or 4; got " << output_mod_factor); + + uint64_t twice_mod = modulus << 1; + + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_neg_modulus = _mm512_set1_epi64(-static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(twice_mod)); + + HEXL_VLOG(5, "root_of_unity_powers " << std::vector( + root_of_unity_powers, root_of_unity_powers + n)) + HEXL_VLOG(5, + "precon_root_of_unity_powers " << std::vector( + precon_root_of_unity_powers, precon_root_of_unity_powers + n)); + HEXL_VLOG(5, "operand " << std::vector(operand, operand + n)); + + static const size_t base_ntt_size = 1024; + + if (n <= base_ntt_size) { // Perform breadth-first NTT + size_t t = (n >> 1); + size_t m = 1; + size_t W_idx = (m << recursion_depth) + (recursion_half * m); + + // Copy for out-of-place in case m is <= base_ntt_size from start + if (result != operand) { + std::memcpy(result, operand, n * sizeof(uint64_t)); + } + + // First iteration assumes input in [0,p) + if (m < (n >> 3)) { + const uint64_t* W = &root_of_unity_powers[W_idx]; + const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; + + if ((input_mod_factor <= 2) && (recursion_depth == 0)) { + FwdT8(result, result, v_neg_modulus, v_twice_mod, t, m, + W, W_precon); + } else { + FwdT8(result, result, v_neg_modulus, v_twice_mod, t, m, + W, W_precon); + } + + t >>= 1; + m <<= 1; + W_idx <<= 1; + } + for (; m < (n >> 3); m <<= 1) { + const uint64_t* W = &root_of_unity_powers[W_idx]; + const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; + FwdT8(result, result, v_neg_modulus, v_twice_mod, t, m, + W, W_precon); + t >>= 1; + W_idx <<= 1; + } + + // Do T=4, T=2, T=1 separately + { + // Correction step needed due to extra copies of roots of unity in the + // AVX512 vectors loaded for FwdT2 and FwdT4 + auto compute_new_W_idx = [&](size_t idx) { + // Originally, from root of unity vector index to loop: + // [0, N/8) => FwdT8 + // [N/8, N/4) => FwdT4 + // [N/4, N/2) => FwdT2 + // [N/2, N) => FwdT1 + // The new mapping from AVX512 root of unity vector index to loop: + // [0, N/8) => FwdT8 + // [N/8, 5N/8) => FwdT4 + // [5N/8, 9N/8) => FwdT2 + // [9N/8, 13N/8) => FwdT1 + size_t N = n << recursion_depth; + + // FwdT8 range + if (idx <= N / 8) { + return idx; + } + // FwdT4 range + if (idx <= N / 4) { + return (idx - N / 8) * 4 + (N / 8); + } + // FwdT2 range + if (idx <= N / 2) { + return (idx - N / 4) * 2 + (5 * N / 8); + } + // FwdT1 range + return idx + (5 * N / 8); + }; + + size_t new_W_idx = compute_new_W_idx(W_idx); + const uint64_t* W = &root_of_unity_powers[new_W_idx]; + const uint64_t* W_precon = &precon_root_of_unity_powers[new_W_idx]; + FwdT4(result, v_neg_modulus, v_twice_mod, m, W, W_precon); + + m <<= 1; + W_idx <<= 1; + new_W_idx = compute_new_W_idx(W_idx); + W = &root_of_unity_powers[new_W_idx]; + W_precon = &precon_root_of_unity_powers[new_W_idx]; + FwdT2(result, v_neg_modulus, v_twice_mod, m, W, W_precon); + + m <<= 1; + W_idx <<= 1; + new_W_idx = compute_new_W_idx(W_idx); + W = &root_of_unity_powers[new_W_idx]; + W_precon = &precon_root_of_unity_powers[new_W_idx]; + FwdT1(result, v_neg_modulus, v_twice_mod, m, W, W_precon); + } + + if (output_mod_factor == 1) { + // n power of two at least 8 => n divisible by 8 + 
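+      // The loop below brings each lane from [0, 4q) back to [0, q) with two
+      // conditional subtractions: first against 2q, then against q.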
HEXL_CHECK(n % 8 == 0, "n " << n << " not a power of 2"); + __m512i* v_X_pt = reinterpret_cast<__m512i*>(result); + for (size_t i = 0; i < n; i += 8) { + __m512i v_X = _mm512_loadu_si512(v_X_pt); + + // Reduce from [0, 4q) to [0, q) + v_X = _mm512_hexl_small_mod_epu64(v_X, v_twice_mod); + v_X = _mm512_hexl_small_mod_epu64(v_X, v_modulus); + + HEXL_CHECK_BOUNDS(ExtractValues(v_X).data(), 8, modulus, + "v_X exceeds bound " << modulus); + + _mm512_storeu_si512(v_X_pt, v_X); + + ++v_X_pt; + } + } + } else { + // Perform depth-first NTT via recursive call + size_t t = (n >> 1); + size_t W_idx = (1ULL << recursion_depth) + recursion_half; + const uint64_t* W = &root_of_unity_powers[W_idx]; + const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; + + FwdT8(result, operand, v_neg_modulus, v_twice_mod, t, 1, W, + W_precon); + + ForwardTransformToBitReverseAVX512( + result, result, n / 2, modulus, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor, + recursion_depth + 1, recursion_half * 2); + + ForwardTransformToBitReverseAVX512( + &result[n / 2], &result[n / 2], n / 2, modulus, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor, + recursion_depth + 1, recursion_half * 2 + 1); + } +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/ntt/fwd-ntt-avx512.hpp b/hexl_omp/ntt/fwd-ntt-avx512.hpp new file mode 100644 index 00000000..b3e4cdff --- /dev/null +++ b/hexl_omp/ntt/fwd-ntt-avx512.hpp @@ -0,0 +1,46 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/ntt/ntt.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the forward NTT +/// @param[in, out] operand Input data. Overwritten with NTT output +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity in F_q. In +/// bit-reversed order. +/// @param[in] precon_root_of_unity_powers Pre-conditioned Powers of 2n'th root +/// of unity in F_q. In bit-reversed order. +/// @param[in] input_mod_factor Upper bound for inputs; inputs must be in [0, +/// input_mod_factor * q) +/// @param[in] output_mod_factor Upper bound for result; result must be in [0, +/// output_mod_factor * q) +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// NTT, where all the butterflies in a given stage are processed before any +/// butterflies in the next stage. The base case is small enough to fit in the +/// smallest cache. Larger NTTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. This reduces the number of cache misses, improving +/// performance on larger transform sizes. 
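To make the stage structure described in this header concrete, here is a plain scalar sketch of the breadth-first forward NTT loop (natural order in, bit-reversed order out). It is illustrative only: it uses textbook butterflies with a `%`-based modular multiply instead of the lazy-reduction AVX512 kernels in this diff, and it assumes the roots are stored in bit-reversed order, n is a power of two, and q < 2^63.

```cpp
#include <cstdint>
#include <vector>

// Scalar reference of the breadth-first forward NTT stage structure.
// roots[m + i] is the i-th root used at the stage with m butterfly groups,
// stored in bit-reversed order (roots[0] is unused).
void FwdNttScalarSketch(std::vector<uint64_t>& a, uint64_t q,
                        const std::vector<uint64_t>& roots) {
  const size_t n = a.size();  // power of two
  size_t t = n >> 1;          // butterfly distance, halves each stage
  for (size_t m = 1; m < n; m <<= 1, t >>= 1) {
    for (size_t i = 0; i < m; ++i) {
      const uint64_t w = roots[m + i];
      uint64_t* X = &a[2 * i * t];
      uint64_t* Y = X + t;
      for (size_t j = 0; j < t; ++j) {
        // Textbook Cooley-Tukey butterfly: X' = X + wY, Y' = X - wY (mod q)
        const uint64_t wy = static_cast<uint64_t>(
            (static_cast<unsigned __int128>(w) * Y[j]) % q);
        const uint64_t x = X[j];
        X[j] = (x + wy >= q) ? x + wy - q : x + wy;
        Y[j] = (x >= wy) ? x - wy : x + q - wy;
      }
    }
  }
}
```

In the AVX512 implementation, FwdT8, FwdT4, FwdT2, and FwdT1 vectorize exactly these inner loops for t >= 8, t == 4, t == 2, and t == 1 respectively, keeping intermediate values in [0, 4q) so that reductions can be postponed.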
+template +void ForwardTransformToBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/ntt/inv-ntt-avx512.cpp b/hexl_omp/ntt/inv-ntt-avx512.cpp new file mode 100644 index 00000000..8d340fe1 --- /dev/null +++ b/hexl_omp/ntt/inv-ntt-avx512.cpp @@ -0,0 +1,438 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "ntt/inv-ntt-avx512.hpp" + +#include + +#include +#include +#include + +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "ntt/ntt-avx512-util.hpp" +#include "ntt/ntt-internal.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512IFMA +template void InverseTransformFromBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t degree, + uint64_t modulus, const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half); +#endif + +#ifdef HEXL_HAS_AVX512DQ +template void InverseTransformFromBitReverseAVX512<32>( + uint64_t* result, const uint64_t* operand, uint64_t degree, + uint64_t modulus, const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half); + +template void InverseTransformFromBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t degree, + uint64_t modulus, const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half); +#endif + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief The Harvey butterfly: assume X, Y in [0, 2q), and return X', Y' in +/// [0, 2q). such that X', Y' = X + Y (mod q), W(X - Y) (mod q). +/// @param[in,out] X Input representing 8 64-bit signed integers in SIMD form +/// @param[in,out] Y Input representing 8 64-bit signed integers in SIMD form +/// @param[in] W Root of unity representing 8 64-bit signed integers in SIMD +/// form +/// @param[in] W_precon Preconditioned \p W for BitShift-bit Barrett +/// reduction +/// @param[in] neg_modulus Negative modulus, i.e. (-q) represented as 8 64-bit +/// signed integers in SIMD form +/// @param[in] twice_modulus Twice the modulus, i.e. 2*q represented as 8 64-bit +/// signed integers in SIMD form +/// @param InputLessThanMod If true, assumes \p X, \p Y < \p q. Otherwise, +/// assumes \p X, \p Y < 2*\p q +/// @details See Algorithm 3 of https://arxiv.org/pdf/1205.2926.pdf +template +inline void InvButterfly(__m512i* X, __m512i* Y, __m512i W, __m512i W_precon, + __m512i neg_modulus, __m512i twice_modulus) { + // Compute T first to allow in-place update of X + __m512i Y_minus_2q = _mm512_sub_epi64(*Y, twice_modulus); + __m512i T = _mm512_sub_epi64(*X, Y_minus_2q); + + if (InputLessThanMod) { + // No need for modulus reduction, since inputs are in [0, q) + *X = _mm512_add_epi64(*X, *Y); + } else { + // Algorithm 3 computes (X >= 2q) ? 
(X - 2q) : X + // We instead compute (X - 2q >= 0) ? (X - 2q) : X + // This allows us to use the faster _mm512_movepi64_mask rather than + // _mm512_cmp_epu64_mask to create the mask. + *X = _mm512_add_epi64(*X, Y_minus_2q); + __mmask8 sign_bits = _mm512_movepi64_mask(*X); + *X = _mm512_mask_add_epi64(*X, sign_bits, *X, twice_modulus); + } + + if (BitShift == 32) { + __m512i Q = _mm512_hexl_mullo_epi<64>(W_precon, T); + Q = _mm512_srli_epi64(Q, 32); + __m512i Q_p = _mm512_hexl_mullo_epi<64>(Q, neg_modulus); + *Y = _mm512_hexl_mullo_add_lo_epi<64>(Q_p, W, T); + } else if (BitShift == 52) { + __m512i Q = _mm512_hexl_mulhi_epi(W_precon, T); + __m512i Q_p = _mm512_hexl_mullo_epi(Q, neg_modulus); + *Y = _mm512_hexl_mullo_add_lo_epi(Q_p, W, T); + } else if (BitShift == 64) { + // Perform approximate computation of Q, as described in page 7 of + // https://arxiv.org/pdf/2003.04510.pdf + __m512i Q = _mm512_hexl_mulhi_approx_epi(W_precon, T); + __m512i Q_p = _mm512_hexl_mullo_epi(Q, neg_modulus); + // Compute Y in range [0, 4q) + *Y = _mm512_hexl_mullo_add_lo_epi(Q_p, W, T); + // Reduce Y to range [0, 2q) + *Y = _mm512_hexl_small_mod_epu64<2>(*Y, twice_modulus); + } else { + HEXL_CHECK(false, "Invalid BitShift " << BitShift); + } +} + +template +void InvT1(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, + uint64_t m, const uint64_t* W, const uint64_t* W_precon) { + const __m512i* v_W_pt = reinterpret_cast(W); + const __m512i* v_W_precon_pt = reinterpret_cast(W_precon); + size_t j1 = 0; + + // 8 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_8 + for (size_t i = m / 8; i > 0; --i) { + uint64_t* X = operand + j1; + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + + __m512i v_X; + __m512i v_Y; + LoadInvInterleavedT1(X, &v_X, &v_Y); + + __m512i v_W = _mm512_loadu_si512(v_W_pt++); + __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++); + + InvButterfly(&v_X, &v_Y, v_W, v_W_precon, + v_neg_modulus, v_twice_mod); + + _mm512_storeu_si512(v_X_pt++, v_X); + _mm512_storeu_si512(v_X_pt, v_Y); + + j1 += 16; + } +} + +template +void InvT2(uint64_t* X, __m512i v_neg_modulus, __m512i v_twice_mod, uint64_t m, + const uint64_t* W, const uint64_t* W_precon) { + // 4 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = m / 4; i > 0; --i) { + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + + __m512i v_X; + __m512i v_Y; + LoadInvInterleavedT2(X, &v_X, &v_Y); + + __m512i v_W = LoadWOpT2(static_cast(W)); + __m512i v_W_precon = LoadWOpT2(static_cast(W_precon)); + + InvButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, + v_twice_mod); + + _mm512_storeu_si512(v_X_pt++, v_X); + _mm512_storeu_si512(v_X_pt, v_Y); + X += 16; + + W += 4; + W_precon += 4; + } +} + +template +void InvT4(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, + uint64_t m, const uint64_t* W, const uint64_t* W_precon) { + uint64_t* X = operand; + + // 2 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = m / 2; i > 0; --i) { + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + + __m512i v_X; + __m512i v_Y; + LoadInvInterleavedT4(X, &v_X, &v_Y); + + __m512i v_W = LoadWOpT4(static_cast(W)); + __m512i v_W_precon = LoadWOpT4(static_cast(W_precon)); + + InvButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, + v_twice_mod); + + WriteInvInterleavedT4(v_X, v_Y, v_X_pt); + X += 16; + + W += 2; + W_precon += 2; + } +} + +template +void InvT8(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, + uint64_t t, uint64_t m, const uint64_t* W, + const uint64_t* W_precon) { + size_t j1 = 0; + + 
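+  // j1 is the offset of the current butterfly block; each block spans 2*t
+  // elements, so j1 advances by (t << 1) at the end of the outer loop.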
HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < m; i++) { + uint64_t* X = operand + j1; + uint64_t* Y = X + t; + + __m512i v_W = _mm512_set1_epi64(static_cast(*W++)); + __m512i v_W_precon = _mm512_set1_epi64(static_cast(*W_precon++)); + + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + __m512i* v_Y_pt = reinterpret_cast<__m512i*>(Y); + + // assume 8 | t + for (size_t j = t / 8; j > 0; --j) { + __m512i v_X = _mm512_loadu_si512(v_X_pt); + __m512i v_Y = _mm512_loadu_si512(v_Y_pt); + + InvButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, + v_twice_mod); + + _mm512_storeu_si512(v_X_pt++, v_X); + _mm512_storeu_si512(v_Y_pt++, v_Y); + } + j1 += (t << 1); + } +} + +template +void InverseTransformFromBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(n >= 16, + "InverseTransformFromBitReverseAVX512 doesn't support small " + "transforms. Need n >= 16, got n = " + << n); + HEXL_CHECK(modulus < NTT::s_max_inv_modulus(BitShift), + "modulus " << modulus << " too large for BitShift " << BitShift + << " => maximum value " + << NTT::s_max_inv_modulus(BitShift)); + HEXL_CHECK_BOUNDS(precon_inv_root_of_unity_powers, n, MaximumValue(BitShift), + "precon_inv_root_of_unity_powers too large"); + HEXL_CHECK_BOUNDS(operand, n, MaximumValue(BitShift), "operand too large"); + // Skip input bound checking for recursive steps + HEXL_CHECK_BOUNDS(operand, (recursion_depth == 0) ? n : 0, + input_mod_factor * modulus, + "operand larger than input_mod_factor * modulus (" + << input_mod_factor << " * " << modulus << ")"); + HEXL_CHECK(input_mod_factor == 1 || input_mod_factor == 2, + "input_mod_factor must be 1 or 2; got " << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2; got " << output_mod_factor); + + uint64_t twice_mod = modulus << 1; + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_neg_modulus = _mm512_set1_epi64(-static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(twice_mod)); + + size_t t = 1; + size_t m = (n >> 1); + size_t W_idx = 1 + m * recursion_half; + + static const size_t base_ntt_size = 1024; + + if (n <= base_ntt_size) { // Perform breadth-first InvNTT + if (operand != result) { + std::memcpy(result, operand, n * sizeof(uint64_t)); + } + + // Extract t=1, t=2, t=4 loops separately + { + // t = 1 + const uint64_t* W = &inv_root_of_unity_powers[W_idx]; + const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx]; + if ((input_mod_factor == 1) && (recursion_depth == 0)) { + InvT1(result, v_neg_modulus, v_twice_mod, m, W, + W_precon); + } else { + InvT1(result, v_neg_modulus, v_twice_mod, m, W, + W_precon); + } + + t <<= 1; + m >>= 1; + uint64_t W_idx_delta = + m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_idx_delta; + + // t = 2 + W = &inv_root_of_unity_powers[W_idx]; + W_precon = &precon_inv_root_of_unity_powers[W_idx]; + InvT2(result, v_neg_modulus, v_twice_mod, m, W, W_precon); + + t <<= 1; + m >>= 1; + W_idx_delta >>= 1; + W_idx += W_idx_delta; + + // t = 4 + W = &inv_root_of_unity_powers[W_idx]; + W_precon = &precon_inv_root_of_unity_powers[W_idx]; + InvT4(result, v_neg_modulus, v_twice_mod, m, W, W_precon); + t <<= 1; + m >>= 1; + W_idx_delta 
>>= 1; + W_idx += W_idx_delta; + + // t >= 8 + for (; m > 1;) { + W = &inv_root_of_unity_powers[W_idx]; + W_precon = &precon_inv_root_of_unity_powers[W_idx]; + InvT8(result, v_neg_modulus, v_twice_mod, t, m, W, W_precon); + t <<= 1; + m >>= 1; + W_idx_delta >>= 1; + W_idx += W_idx_delta; + } + } + } else { + InverseTransformFromBitReverseAVX512( + result, operand, n / 2, modulus, inv_root_of_unity_powers, + precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor, + recursion_depth + 1, 2 * recursion_half); + InverseTransformFromBitReverseAVX512( + &result[n / 2], &operand[n / 2], n / 2, modulus, + inv_root_of_unity_powers, precon_inv_root_of_unity_powers, + input_mod_factor, output_mod_factor, recursion_depth + 1, + 2 * recursion_half + 1); + + uint64_t W_idx_delta = + m * ((1ULL << (recursion_depth + 1)) - recursion_half); + for (; m > 2; m >>= 1) { + t <<= 1; + W_idx_delta >>= 1; + W_idx += W_idx_delta; + } + if (m == 2) { + const uint64_t* W = &inv_root_of_unity_powers[W_idx]; + const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx]; + InvT8(result, v_neg_modulus, v_twice_mod, t, m, W, W_precon); + t <<= 1; + m >>= 1; + W_idx_delta >>= 1; + W_idx += W_idx_delta; + } + } + + // Final loop through data + if (recursion_depth == 0) { + HEXL_VLOG(4, "AVX512 intermediate result " + << std::vector(result, result + n)); + + const uint64_t W = inv_root_of_unity_powers[W_idx]; + MultiplyFactor mf_inv_n(InverseMod(n, modulus), BitShift, modulus); + const uint64_t inv_n = mf_inv_n.Operand(); + const uint64_t inv_n_prime = mf_inv_n.BarrettFactor(); + + MultiplyFactor mf_inv_n_w(MultiplyMod(inv_n, W, modulus), BitShift, + modulus); + const uint64_t inv_n_w = mf_inv_n_w.Operand(); + const uint64_t inv_n_w_prime = mf_inv_n_w.BarrettFactor(); + + HEXL_VLOG(4, "inv_n_w " << inv_n_w); + + uint64_t* X = result; + uint64_t* Y = X + (n >> 1); + + __m512i v_inv_n = _mm512_set1_epi64(static_cast(inv_n)); + __m512i v_inv_n_prime = + _mm512_set1_epi64(static_cast(inv_n_prime)); + __m512i v_inv_n_w = _mm512_set1_epi64(static_cast(inv_n_w)); + __m512i v_inv_n_w_prime = + _mm512_set1_epi64(static_cast(inv_n_w_prime)); + + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + __m512i* v_Y_pt = reinterpret_cast<__m512i*>(Y); + + // Merge final InvNTT loop with modulus reduction baked-in + HEXL_LOOP_UNROLL_4 + for (size_t j = n / 16; j > 0; --j) { + __m512i v_X = _mm512_loadu_si512(v_X_pt); + __m512i v_Y = _mm512_loadu_si512(v_Y_pt); + + // Slightly different from regular InvButterfly because different W is + // used for X and Y + __m512i Y_minus_2q = _mm512_sub_epi64(v_Y, v_twice_mod); + __m512i X_plus_Y_mod2q = + _mm512_hexl_small_add_mod_epi64(v_X, v_Y, v_twice_mod); + // T = *X + twice_mod - *Y + __m512i T = _mm512_sub_epi64(v_X, Y_minus_2q); + + if (BitShift == 32) { + __m512i Q1 = _mm512_hexl_mullo_epi<64>(v_inv_n_prime, X_plus_Y_mod2q); + Q1 = _mm512_srli_epi64(Q1, 32); + // X = inv_N * X_plus_Y_mod2q - Q1 * modulus; + __m512i inv_N_tx = _mm512_hexl_mullo_epi<64>(v_inv_n, X_plus_Y_mod2q); + v_X = _mm512_hexl_mullo_add_lo_epi<64>(inv_N_tx, Q1, v_neg_modulus); + + __m512i Q2 = _mm512_hexl_mullo_epi<64>(v_inv_n_w_prime, T); + Q2 = _mm512_srli_epi64(Q2, 32); + + // Y = inv_N_W * T - Q2 * modulus; + __m512i inv_N_W_T = _mm512_hexl_mullo_epi<64>(v_inv_n_w, T); + v_Y = _mm512_hexl_mullo_add_lo_epi<64>(inv_N_W_T, Q2, v_neg_modulus); + } else { + __m512i Q1 = + _mm512_hexl_mulhi_epi(v_inv_n_prime, X_plus_Y_mod2q); + // X = inv_N * X_plus_Y_mod2q - Q1 * modulus; + __m512i inv_N_tx = + 
_mm512_hexl_mullo_epi(v_inv_n, X_plus_Y_mod2q); + v_X = + _mm512_hexl_mullo_add_lo_epi(inv_N_tx, Q1, v_neg_modulus); + + __m512i Q2 = _mm512_hexl_mulhi_epi(v_inv_n_w_prime, T); + // Y = inv_N_W * T - Q2 * modulus; + __m512i inv_N_W_T = _mm512_hexl_mullo_epi(v_inv_n_w, T); + v_Y = _mm512_hexl_mullo_add_lo_epi(inv_N_W_T, Q2, + v_neg_modulus); + } + + if (output_mod_factor == 1) { + // Modulus reduction from [0, 2q), to [0, q) + v_X = _mm512_hexl_small_mod_epu64(v_X, v_modulus); + v_Y = _mm512_hexl_small_mod_epu64(v_Y, v_modulus); + } + + _mm512_storeu_si512(v_X_pt++, v_X); + _mm512_storeu_si512(v_Y_pt++, v_Y); + } + + HEXL_VLOG(5, "AVX512 returning result " + << std::vector(result, result + n)); + } +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/ntt/inv-ntt-avx512.hpp b/hexl_omp/ntt/inv-ntt-avx512.hpp new file mode 100644 index 00000000..143b9476 --- /dev/null +++ b/hexl_omp/ntt/inv-ntt-avx512.hpp @@ -0,0 +1,46 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ntt/ntt-internal.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the inverse NTT +/// @param[in, out] operand Input data. Overwritten with NTT output +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n +/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity in +/// F_q. In bit-reversed order. +/// @param[in] precon_root_of_unity_powers Pre-conditioned powers of inverse +/// 2n'th root of unity in F_q. In bit-reversed order. +/// @param[in] input_mod_factor Upper bound for inputs; inputs must be in [0, +/// input_mod_factor * q) +/// @param[in] output_mod_factor Upper bound for result; result must be in [0, +/// output_mod_factor * q) +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// NTT, where all the butterflies in a given stage are processed before any +/// butterflies in the next stage. The base case is small enough to fit in the +/// smallest cache. Larger NTTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. This reduces the number of cache misses, improving +/// performance on larger transform sizes. 
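As the `.cpp` above shows, the multiplication by n^-1 is folded into the last Gentleman-Sande stage instead of being done in a separate scaling pass. Below is a scalar sketch of that merged step, with illustrative names and assuming q < 2^63; inv_n and inv_n_w correspond to the precomputed inv_n and inv_n_w constants in the code above.

```cpp
#include <cstddef>
#include <cstdint>

__extension__ typedef unsigned __int128 u128;

inline uint64_t MulModSketch(uint64_t x, uint64_t y, uint64_t q) {
  return static_cast<uint64_t>((static_cast<u128>(x) * y) % q);
}

// Merged last inverse-NTT stage: it pairs a[j] with a[j + n/2], so the n^-1
// factor can be applied while the data is already being touched.
// inv_n = n^-1 mod q, inv_n_w = (n^-1 * w_last) mod q.
void InvNttFinalStageSketch(uint64_t* a, size_t n, uint64_t q,
                            uint64_t inv_n, uint64_t inv_n_w) {
  const size_t half = n / 2;
  for (size_t j = 0; j < half; ++j) {
    const uint64_t u = a[j];
    const uint64_t v = a[j + half];
    const uint64_t sum = (u + v >= q) ? u + v - q : u + v;
    const uint64_t diff = (u >= v) ? u - v : u + q - v;
    a[j] = MulModSketch(sum, inv_n, q);
    a[j + half] = MulModSketch(diff, inv_n_w, q);
  }
}
```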
+template +void InverseTransformFromBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/ntt/ntt-avx512-util.hpp b/hexl_omp/ntt/ntt-avx512-util.hpp new file mode 100644 index 00000000..c2342c01 --- /dev/null +++ b/hexl_omp/ntt/ntt-avx512-util.hpp @@ -0,0 +1,218 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +// Given input: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +// Returns +// *out1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); +// *out2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); +inline void LoadFwdInterleavedT1(const uint64_t* arg, __m512i* out1, + __m512i* out2) { + const __m512i* arg_512 = reinterpret_cast(arg); + + // 0, 1, 2, 3, 4, 5, 6, 7 + __m512i v1 = _mm512_loadu_si512(arg_512++); + // 8, 9, 10, 11, 12, 13, 14, 15 + __m512i v2 = _mm512_loadu_si512(arg_512); + + const __m512i perm_idx = _mm512_set_epi64(6, 7, 4, 5, 2, 3, 0, 1); + + // 1, 0, 3, 2, 5, 4, 7, 6 + __m512i v1_perm = _mm512_permutexvar_epi64(perm_idx, v1); + // 9, 8, 11, 10, 13, 12, 15, 14 + __m512i v2_perm = _mm512_permutexvar_epi64(perm_idx, v2); + + *out1 = _mm512_mask_blend_epi64(0xaa, v1, v2_perm); + *out2 = _mm512_mask_blend_epi64(0xaa, v1_perm, v2); +} + +// Given input: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +// Returns +// *out1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); +// *out2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); +inline void LoadInvInterleavedT1(const uint64_t* arg, __m512i* out1, + __m512i* out2) { + const __m512i vperm_hi_idx = _mm512_set_epi64(6, 4, 2, 0, 7, 5, 3, 1); + const __m512i vperm_lo_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + + const __m512i* arg_512 = reinterpret_cast(arg); + + // 7, 6, 5, 4, 3, 2, 1, 0 + __m512i v_7to0 = _mm512_loadu_si512(arg_512++); + // 15, 14, 13, 12, 11, 10, 9, 8 + __m512i v_15to8 = _mm512_loadu_si512(arg_512); + // 7, 5, 3, 1, 6, 4, 2, 0 + __m512i perm_lo = _mm512_permutexvar_epi64(vperm_lo_idx, v_7to0); + // 14, 12, 10, 8, 15, 13, 11, 9 + __m512i perm_hi = _mm512_permutexvar_epi64(vperm_hi_idx, v_15to8); + + *out1 = _mm512_mask_blend_epi64(0x0f, perm_hi, perm_lo); + *out2 = _mm512_mask_blend_epi64(0xf0, perm_hi, perm_lo); + *out2 = _mm512_permutexvar_epi64(vperm2_idx, *out2); +} + +// Given input: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +// Returns +// *out1 = _mm512_set_epi64(13, 12, 9, 8, 5, 4, 1, 0); +// *out2 = _mm512_set_epi64(15, 14, 11, 10, 7, 6, 3, 2) +inline void LoadFwdInterleavedT2(const uint64_t* arg, __m512i* out1, + __m512i* out2) { + const __m512i* arg_512 = reinterpret_cast(arg); + + // 11, 10, 9, 8, 3, 2, 1, 0 + __m512i v1 = _mm512_loadu_si512(arg_512++); + // 15, 14, 13, 12, 7, 6, 5, 4 + __m512i v2 = _mm512_loadu_si512(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + + __m512i v1_perm = _mm512_permutexvar_epi64(v1_perm_idx, v1); + __m512i v2_perm = 
_mm512_permutexvar_epi64(v1_perm_idx, v2); + + *out1 = _mm512_mask_blend_epi64(0xcc, v1, v2_perm); + *out2 = _mm512_mask_blend_epi64(0xcc, v1_perm, v2); +} + +// Given input: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +// Returns +// *out1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); +// *out2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); +inline void LoadInvInterleavedT2(const uint64_t* arg, __m512i* out1, + __m512i* out2) { + const __m512i* arg_512 = reinterpret_cast(arg); + + __m512i v1 = _mm512_loadu_si512(arg_512++); + __m512i v2 = _mm512_loadu_si512(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(6, 7, 4, 5, 2, 3, 0, 1); + + __m512i v1_perm = _mm512_permutexvar_epi64(v1_perm_idx, v1); + __m512i v2_perm = _mm512_permutexvar_epi64(v1_perm_idx, v2); + + *out1 = _mm512_mask_blend_epi64(0xaa, v1, v2_perm); + *out2 = _mm512_mask_blend_epi64(0xaa, v1_perm, v2); +} + +// Returns +// *out1 = _mm512_set_epi64(arg[11], arg[10], arg[9], arg[8], +// arg[3], arg[2], arg[1], arg[0]); +// *out2 = _mm512_set_epi64(arg[15], arg[14], arg[13], arg[12], +// arg[7], arg[6], arg[5], arg[4]); +inline void LoadFwdInterleavedT4(const uint64_t* arg, __m512i* out1, + __m512i* out2) { + const __m512i* arg_512 = reinterpret_cast(arg); + + const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + __m512i v_7to0 = _mm512_loadu_si512(arg_512++); + __m512i v_15to8 = _mm512_loadu_si512(arg_512); + __m512i perm_hi = _mm512_permutexvar_epi64(vperm2_idx, v_15to8); + *out1 = _mm512_mask_blend_epi64(0x0f, perm_hi, v_7to0); + *out2 = _mm512_mask_blend_epi64(0xf0, perm_hi, v_7to0); + *out2 = _mm512_permutexvar_epi64(vperm2_idx, *out2); +} + +inline void LoadInvInterleavedT4(const uint64_t* arg, __m512i* out1, + __m512i* out2) { + const __m512i* arg_512 = reinterpret_cast(arg); + + // 0, 1, 2, 3, 4, 5, 6, 7 + __m512i v1 = _mm512_loadu_si512(arg_512++); + // 8, 9, 10, 11, 12, 13, 14, 15 + __m512i v2 = _mm512_loadu_si512(arg_512); + const __m512i perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + + // 1, 0, 3, 2, 5, 4, 7, 6 + __m512i v1_perm = _mm512_permutexvar_epi64(perm_idx, v1); + // 9, 8, 11, 10, 13, 12, 15, 14 + __m512i v2_perm = _mm512_permutexvar_epi64(perm_idx, v2); + + *out1 = _mm512_mask_blend_epi64(0xcc, v1, v2_perm); + *out2 = _mm512_mask_blend_epi64(0xcc, v1_perm, v2); +} + +// Given inputs +// @param arg1 = _mm512_set_epi64(15, 14, 13, 12, 11, 10, 9, 8); +// @param arg2 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0); +// Writes out = {8, 0, 9, 1, 10, 2, 11, 3, +// 12, 4, 13, 5, 14, 6, 15, 7} +inline void WriteFwdInterleavedT1(__m512i arg1, __m512i arg2, __m512i* out) { + const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i v_X_out_idx = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + const __m512i v_Y_out_idx = _mm512_set_epi64(3, 7, 2, 6, 1, 5, 0, 4); + + // v_Y => (4, 5, 6, 7, 0, 1, 2, 3) + arg2 = _mm512_permutexvar_epi64(vperm2_idx, arg2); + // 4, 5, 6, 7, 12, 13, 14, 15 + __m512i perm_lo = _mm512_mask_blend_epi64(0x0f, arg1, arg2); + + // 8, 9, 10, 11, 0, 1, 2, 3 + __m512i perm_hi = _mm512_mask_blend_epi64(0xf0, arg1, arg2); + + arg1 = _mm512_permutexvar_epi64(v_X_out_idx, perm_hi); + arg2 = _mm512_permutexvar_epi64(v_Y_out_idx, perm_lo); + + _mm512_storeu_si512(out++, arg1); + _mm512_storeu_si512(out, arg2); +} + +// Given inputs +// @param arg1 = _mm512_set_epi64(15, 14, 13, 12, 11, 10, 9, 8); +// @param arg2 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0); +// Writes out = {8, 9, 10, 11, 0, 1, 2, 3, +// 12, 13, 14, 15, 4, 5, 6, 7} +inline void 
WriteInvInterleavedT4(__m512i arg1, __m512i arg2, __m512i* out) { + __m256i x0 = _mm512_extracti64x4_epi64(arg1, 0); + __m256i x1 = _mm512_extracti64x4_epi64(arg1, 1); + __m256i y0 = _mm512_extracti64x4_epi64(arg2, 0); + __m256i y1 = _mm512_extracti64x4_epi64(arg2, 1); + __m256i* out_256 = reinterpret_cast<__m256i*>(out); + _mm256_storeu_si256(out_256++, x0); + _mm256_storeu_si256(out_256++, y0); + _mm256_storeu_si256(out_256++, x1); + _mm256_storeu_si256(out_256++, y1); +} + +// Returns _mm512_set_epi64(arg[3], arg[3], arg[2], arg[2], +// arg[1], arg[1], arg[0], arg[0]); +inline __m512i LoadWOpT2(const void* arg) { + const __m512i vperm_w_idx = _mm512_set_epi64(3, 3, 2, 2, 1, 1, 0, 0); + + __m256i v_W_256 = _mm256_loadu_si256(reinterpret_cast(arg)); + __m512i v_W = _mm512_broadcast_i64x4(v_W_256); + v_W = _mm512_permutexvar_epi64(vperm_w_idx, v_W); + + return v_W; +} + +// Returns _mm512_set_epi64(arg[1], arg[1], arg[1], arg[1], +// arg[0], arg[0], arg[0], arg[0]); +inline __m512i LoadWOpT4(const void* arg) { + const __m512i vperm_w_idx = _mm512_set_epi64(1, 1, 1, 1, 0, 0, 0, 0); + + __m128i v_W_128 = _mm_loadu_si128(reinterpret_cast(arg)); + __m512i v_W = _mm512_broadcast_i64x2(v_W_128); + v_W = _mm512_permutexvar_epi64(vperm_w_idx, v_W); + + return v_W; +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/ntt/ntt-default.hpp b/hexl_omp/ntt/ntt-default.hpp new file mode 100644 index 00000000..7fe79db4 --- /dev/null +++ b/hexl_omp/ntt/ntt-default.hpp @@ -0,0 +1,159 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" + +namespace intel { +namespace hexl { + +/// @brief Out of place Harvey butterfly: assume \p X_op, \p Y_op in [0, 4q), +/// and return X_r, Y_r in [0, 4q) such that X_r = X_op + WY_op, Y_r = X_op - +/// WY_op (mod q). +/// @param[out] X_r Butterfly data +/// @param[out] Y_r Butterfly data +/// @param[in] X_op Butterfly data +/// @param[in] Y_op Butterfly data +/// @param[in] W Root of unity +/// @param[in] W_precon Preconditioned \p W for BitShift-bit Barrett +/// reduction +/// @param[in] modulus Negative modulus, i.e. (-q) represented as 8 64-bit +/// signed integers in SIMD form +/// @param[in] twice_modulus Twice the modulus, i.e. 
2*q represented as 8 64-bit +/// signed integers in SIMD form +/// @details See Algorithm 4 of https://arxiv.org/pdf/1205.2926.pdf +inline void FwdButterflyRadix2(uint64_t* X_r, uint64_t* Y_r, + const uint64_t* X_op, const uint64_t* Y_op, + uint64_t W, uint64_t W_precon, uint64_t modulus, + uint64_t twice_modulus) { + HEXL_VLOG(5, "FwdButterflyRadix2"); + HEXL_VLOG(5, "Inputs: X_op " << *X_op << ", Y_op " << *Y_op << ", W " << W + << ", modulus " << modulus); + uint64_t tx = ReduceMod<2>(*X_op, twice_modulus); + uint64_t T = MultiplyModLazy<64>(*Y_op, W, W_precon, modulus); + HEXL_VLOG(5, "T " << T); + *X_r = tx + T; + *Y_r = tx + twice_modulus - T; + + HEXL_VLOG(5, "Output X " << *X_r << ", Y " << *Y_r); +} + +// Assume X, Y in [0, n*q) and return X_r, Y_r in [0, (n+2)*q) +// such that X_r = X_op + WY_op mod q and Y_r = X_op - WY_op mod q +inline void FwdButterflyRadix4Lazy(uint64_t* X_r, uint64_t* Y_r, + const uint64_t X_op, const uint64_t Y_op, + uint64_t W, uint64_t W_precon, + uint64_t modulus, uint64_t twice_modulus) { + HEXL_VLOG(3, "FwdButterflyRadix4Lazy"); + HEXL_VLOG(3, "Inputs: X_op " << X_op << ", Y_op " << Y_op << ", W " << W + << ", modulus " << modulus); + + uint64_t T = MultiplyModLazy<64>(Y_op, W, W_precon, modulus); + HEXL_VLOG(3, "T " << T); + *X_r = X_op + T; + *Y_r = X_op + twice_modulus - T; + + HEXL_VLOG(3, "Outputs: X_r " << *X_r << ", Y_r " << *Y_r); +} + +// Assume X0, X1, X2, X3 in [0, 4q) and return X0, X1, X2, X3 in [0, 4q) +inline void FwdButterflyRadix4( + uint64_t* X_r0, uint64_t* X_r1, uint64_t* X_r2, uint64_t* X_r3, + const uint64_t* X_op0, const uint64_t* X_op1, const uint64_t* X_op2, + const uint64_t* X_op3, uint64_t W1, uint64_t W1_precon, uint64_t W2, + uint64_t W2_precon, uint64_t W3, uint64_t W3_precon, uint64_t modulus, + uint64_t twice_modulus, uint64_t four_times_modulus) { + HEXL_VLOG(3, "FwdButterflyRadix4"); + HEXL_UNUSED(four_times_modulus); + + FwdButterflyRadix2(X_r0, X_r2, X_op0, X_op2, W1, W1_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r1, X_r3, X_op1, X_op3, W1, W1_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r0, X_r1, X_r0, X_r1, W2, W2_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r2, X_r3, X_r2, X_r3, W3, W3_precon, modulus, + twice_modulus); + + // Alternate implementation + // // Returns Xs in [0, 6q) + // FwdButterflyRadix4Lazy(X0, X2, W1, W1_precon, modulus, twice_modulus); + // FwdButterflyRadix4Lazy(X1, X3, W1, W1_precon, modulus, twice_modulus); + + // // Returns Xs in [0, 8q) + // FwdButterflyRadix4Lazy(X0, X1, W2, W2_precon, modulus, twice_modulus); + // FwdButterflyRadix4Lazy(X2, X3, W3, W3_precon, modulus, twice_modulus); + + // // Reduce Xs to [0, 4q) + // *X0 = ReduceMod<2>(*X0, four_times_modulus); + // *X1 = ReduceMod<2>(*X1, four_times_modulus); + // *X2 = ReduceMod<2>(*X2, four_times_modulus); + // *X3 = ReduceMod<2>(*X3, four_times_modulus); +} + +/// @brief Out-of-place Harvey butterfly: assume X_op, Y_op in [0, 2q), and +/// return X_r, Y_r in [0, 2q) such that X_r = X_op + Y_op (mod q), +/// Y_r = W(X_op - Y_op) (mod q). +/// @param[out] X_r Butterfly data +/// @param[out] Y_r Butterfly data +/// @param[in] X_op Butterfly data +/// @param[in] Y_op Butterfly data +/// @param[in] W Root of unity +/// @param[in] W_precon Preconditioned root of unity for 64-bit Barrett +/// reduction +/// @param[in] modulus Modulus, i.e. (q) represented as 8 64-bit +/// signed integers in SIMD form +/// @param[in] twice_modulus Twice the modulus, i.e. 
2*q represented as 8 64-bit +/// signed integers in SIMD form +/// @details See Algorithm 3 of https://arxiv.org/pdf/1205.2926.pdf +inline void InvButterflyRadix2(uint64_t* X_r, uint64_t* Y_r, + const uint64_t* X_op, const uint64_t* Y_op, + uint64_t W, uint64_t W_precon, uint64_t modulus, + uint64_t twice_modulus) { + HEXL_VLOG(4, "InvButterflyRadix2 X_op " + << *X_op << ", Y_op " << *Y_op << " W " << W << " W_precon " + << W_precon << " modulus " << modulus); + uint64_t tx = *X_op + *Y_op; + *Y_r = *X_op + twice_modulus - *Y_op; + *X_r = ReduceMod<2>(tx, twice_modulus); + *Y_r = MultiplyModLazy<64>(*Y_r, W, W_precon, modulus); + + HEXL_VLOG(4, "InvButterflyRadix2 returning X_r " << *X_r << ", Y_r " << *Y_r); +} + +// Assume X0, X1, X2, X3 in [0, 2q) and return X0, X1, X2, X3 in [0, 2q) +inline void InvButterflyRadix4(uint64_t* X_r0, uint64_t* X_r1, uint64_t* X_r2, + uint64_t* X_r3, const uint64_t* X_op0, + const uint64_t* X_op1, const uint64_t* X_op2, + const uint64_t* X_op3, uint64_t W1, + uint64_t W1_precon, uint64_t W2, + uint64_t W2_precon, uint64_t W3, + uint64_t W3_precon, uint64_t modulus, + uint64_t twice_modulus) { + HEXL_VLOG(4, "InvButterflyRadix4 " // + << "X_op0 " << *X_op0 << ", X_op1 " << *X_op1 // + << ", X_op2 " << *X_op2 << " X_op3 " << *X_op3 // + << " W1 " << W1 << " W1_precon " << W1_precon // + << " W2 " << W2 << " W2_precon " << W2_precon // + << " W3 " << W3 << " W3_precon " << W3_precon // + << " modulus " << modulus); + + InvButterflyRadix2(X_r0, X_r1, X_op0, X_op1, W1, W1_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r2, X_r3, X_op2, X_op3, W2, W2_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r0, X_r2, X_r0, X_r2, W3, W3_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r1, X_r3, X_r1, X_r3, W3, W3_precon, modulus, + twice_modulus); + + HEXL_VLOG(4, "InvButterflyRadix4 returning X0 " << *X_r0 << ", X_r1 " << *X_r1 + << ", X_r2 " << *X_r2 // + << " X_r3 " << *X_r3); +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/ntt/ntt-internal.cpp b/hexl_omp/ntt/ntt-internal.cpp new file mode 100644 index 00000000..9ace8741 --- /dev/null +++ b/hexl_omp/ntt/ntt-internal.cpp @@ -0,0 +1,313 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "ntt/ntt-internal.hpp" + +#include +#include + +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/defines.hpp" +#include "ntt/fwd-ntt-avx512.hpp" +#include "ntt/inv-ntt-avx512.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +AllocatorStrategyPtr mallocStrategy = AllocatorStrategyPtr(new MallocStrategy); + +NTT::NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, + std::shared_ptr alloc_ptr) + : m_degree(degree), + m_q(q), + m_w(root_of_unity), + m_alloc(alloc_ptr), + m_aligned_alloc(AlignedAllocator(m_alloc)), + m_root_of_unity_powers(m_aligned_alloc), + m_precon32_root_of_unity_powers(m_aligned_alloc), + m_precon64_root_of_unity_powers(m_aligned_alloc), + m_avx512_root_of_unity_powers(m_aligned_alloc), + m_avx512_precon32_root_of_unity_powers(m_aligned_alloc), + m_avx512_precon52_root_of_unity_powers(m_aligned_alloc), + m_avx512_precon64_root_of_unity_powers(m_aligned_alloc), + m_precon32_inv_root_of_unity_powers(m_aligned_alloc), + m_precon52_inv_root_of_unity_powers(m_aligned_alloc), + m_precon64_inv_root_of_unity_powers(m_aligned_alloc), + 
m_inv_root_of_unity_powers(m_aligned_alloc) { + HEXL_CHECK(CheckArguments(degree, q), ""); + HEXL_CHECK(IsPrimitiveRoot(m_w, 2 * degree, q), + m_w << " is not a primitive 2*" << degree << "'th root of unity"); + + m_degree_bits = Log2(m_degree); + m_w_inv = InverseMod(m_w, m_q); + ComputeRootOfUnityPowers(); +} + +NTT::NTT(uint64_t degree, uint64_t q, std::shared_ptr alloc_ptr) + : NTT(degree, q, MinimalPrimitiveRoot(2 * degree, q), alloc_ptr) {} + +void NTT::ComputeRootOfUnityPowers() { + AlignedVector64 root_of_unity_powers(m_degree, 0, m_aligned_alloc); + AlignedVector64 inv_root_of_unity_powers(m_degree, 0, + m_aligned_alloc); + + // 64-bit preconditioned inverse and root of unity powers + root_of_unity_powers[0] = 1; + inv_root_of_unity_powers[0] = InverseMod(1, m_q); + uint64_t idx = 0; + uint64_t prev_idx = idx; + + for (size_t i = 1; i < m_degree; i++) { + idx = ReverseBits(i, m_degree_bits); + root_of_unity_powers[idx] = + MultiplyMod(root_of_unity_powers[prev_idx], m_w, m_q); + inv_root_of_unity_powers[idx] = InverseMod(root_of_unity_powers[idx], m_q); + + prev_idx = idx; + } + + m_root_of_unity_powers = root_of_unity_powers; + m_avx512_root_of_unity_powers = m_root_of_unity_powers; + + // Duplicate each root of unity at indices [N/4, N/2]. + // These are the roots of unity used in the FwdNTT FwdT2 function + // By creating these duplicates, we avoid extra permutations while loading the + // roots of unity + AlignedVector64 W2_roots; + W2_roots.reserve(m_degree / 2); + for (size_t i = m_degree / 4; i < m_degree / 2; ++i) { + W2_roots.push_back(m_root_of_unity_powers[i]); + W2_roots.push_back(m_root_of_unity_powers[i]); + } + m_avx512_root_of_unity_powers.erase( + m_avx512_root_of_unity_powers.begin() + m_degree / 4, + m_avx512_root_of_unity_powers.begin() + m_degree / 2); + m_avx512_root_of_unity_powers.insert( + m_avx512_root_of_unity_powers.begin() + m_degree / 4, W2_roots.begin(), + W2_roots.end()); + + // Duplicate each root of unity at indices [N/8, N/4]. 
+ // These are the roots of unity used in the FwdNTT FwdT4 function + // By creating these duplicates, we avoid extra permutations while loading the + // roots of unity + AlignedVector64 W4_roots; + W4_roots.reserve(m_degree / 2); + for (size_t i = m_degree / 8; i < m_degree / 4; ++i) { + W4_roots.push_back(m_root_of_unity_powers[i]); + W4_roots.push_back(m_root_of_unity_powers[i]); + W4_roots.push_back(m_root_of_unity_powers[i]); + W4_roots.push_back(m_root_of_unity_powers[i]); + } + m_avx512_root_of_unity_powers.erase( + m_avx512_root_of_unity_powers.begin() + m_degree / 8, + m_avx512_root_of_unity_powers.begin() + m_degree / 4); + m_avx512_root_of_unity_powers.insert( + m_avx512_root_of_unity_powers.begin() + m_degree / 8, W4_roots.begin(), + W4_roots.end()); + + auto compute_barrett_vector = [&](const AlignedVector64& values, + uint64_t bit_shift) { + AlignedVector64 barrett_vector(m_aligned_alloc); + for (uint64_t value : values) { + MultiplyFactor mf(value, bit_shift, m_q); + barrett_vector.push_back(mf.BarrettFactor()); + } + return barrett_vector; + }; + + m_precon32_root_of_unity_powers = + compute_barrett_vector(root_of_unity_powers, 32); + m_precon64_root_of_unity_powers = + compute_barrett_vector(root_of_unity_powers, 64); + + // 52-bit preconditioned root of unity powers + if (has_avx512ifma) { + m_avx512_precon52_root_of_unity_powers = + compute_barrett_vector(m_avx512_root_of_unity_powers, 52); + } + + if (has_avx512dq) { + m_avx512_precon32_root_of_unity_powers = + compute_barrett_vector(m_avx512_root_of_unity_powers, 32); + m_avx512_precon64_root_of_unity_powers = + compute_barrett_vector(m_avx512_root_of_unity_powers, 64); + } + + // Inverse root of unity powers + + // Reordering inv_root_of_powers + AlignedVector64 temp(m_degree, 0, m_aligned_alloc); + temp[0] = inv_root_of_unity_powers[0]; + idx = 1; + + for (size_t m = (m_degree >> 1); m > 0; m >>= 1) { + for (size_t i = 0; i < m; i++) { + temp[idx] = inv_root_of_unity_powers[m + i]; + idx++; + } + } + m_inv_root_of_unity_powers = std::move(temp); + + // 32-bit preconditioned inverse root of unity powers + m_precon32_inv_root_of_unity_powers = + compute_barrett_vector(m_inv_root_of_unity_powers, 32); + + // 52-bit preconditioned inverse root of unity powers + if (has_avx512ifma) { + m_precon52_inv_root_of_unity_powers = + compute_barrett_vector(m_inv_root_of_unity_powers, 52); + } + + // 64-bit preconditioned inverse root of unity powers + m_precon64_inv_root_of_unity_powers = + compute_barrett_vector(m_inv_root_of_unity_powers, 64); +} + +bool NTT::CheckArguments(uint64_t degree, uint64_t modulus) { + HEXL_UNUSED(degree); + HEXL_UNUSED(modulus); + HEXL_CHECK(IsPowerOfTwo(degree), + "degree " << degree << " is not a power of 2"); + HEXL_CHECK(degree <= (1ULL << NTT::MaxDegreeBits()), + "degree should be less than 2^" << NTT::MaxDegreeBits() << " got " + << degree); + HEXL_CHECK(modulus <= (1ULL << NTT::MaxModulusBits()), + "modulus should be less than 2^" << NTT::MaxModulusBits() + << " got " << modulus); + HEXL_CHECK(modulus % (2 * degree) == 1, "modulus mod 2n != 1"); + HEXL_CHECK(IsPrime(modulus), "modulus is not prime"); + + return true; +} + +void NTT::ComputeForward(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(result != nullptr, "result == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + HEXL_CHECK( + input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4, + "input_mod_factor must be 1, 2 or 4; got " << 
input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4, + "output_mod_factor must be 1 or 4; got " << output_mod_factor); + HEXL_CHECK_BOUNDS( + operand, m_degree, m_q * input_mod_factor, + "value in operand exceeds bound " << m_q * input_mod_factor); + +#ifdef HEXL_HAS_AVX512IFMA + if (has_avx512ifma && (m_q < s_max_fwd_ifma_modulus && (m_degree >= 16))) { + const uint64_t* root_of_unity_powers = GetAVX512RootOfUnityPowers().data(); + const uint64_t* precon_root_of_unity_powers = + GetAVX512Precon52RootOfUnityPowers().data(); + + HEXL_VLOG(3, "Calling 52-bit AVX512-IFMA FwdNTT"); + ForwardTransformToBitReverseAVX512( + result, operand, m_degree, m_q, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor); + return; + } +#endif + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq && m_degree >= 16) { + if (m_q < s_max_fwd_32_modulus) { + HEXL_VLOG(3, "Calling 32-bit AVX512-DQ FwdNTT"); + const uint64_t* root_of_unity_powers = + GetAVX512RootOfUnityPowers().data(); + const uint64_t* precon_root_of_unity_powers = + GetAVX512Precon32RootOfUnityPowers().data(); + ForwardTransformToBitReverseAVX512<32>( + result, operand, m_degree, m_q, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor); + } else { + HEXL_VLOG(3, "Calling 64-bit AVX512-DQ FwdNTT"); + const uint64_t* root_of_unity_powers = + GetAVX512RootOfUnityPowers().data(); + const uint64_t* precon_root_of_unity_powers = + GetAVX512Precon64RootOfUnityPowers().data(); + + ForwardTransformToBitReverseAVX512( + result, operand, m_degree, m_q, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor); + } + return; + } +#endif + + HEXL_VLOG(3, "Calling ForwardTransformToBitReverseRadix2"); + const uint64_t* root_of_unity_powers = GetRootOfUnityPowers().data(); + const uint64_t* precon_root_of_unity_powers = + GetPrecon64RootOfUnityPowers().data(); + + ForwardTransformToBitReverseRadix2( + result, operand, m_degree, m_q, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor); +} + +void NTT::ComputeInverse(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(result != nullptr, "result == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + HEXL_CHECK(input_mod_factor == 1 || input_mod_factor == 2, + "input_mod_factor must be 1 or 2; got " << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2; got " << output_mod_factor); + HEXL_CHECK_BOUNDS(operand, m_degree, m_q * input_mod_factor, + "operand exceeds bound " << m_q * input_mod_factor); + +#ifdef HEXL_HAS_AVX512IFMA + if (has_avx512ifma && (m_q < s_max_inv_ifma_modulus) && (m_degree >= 16)) { + HEXL_VLOG(3, "Calling 52-bit AVX512-IFMA InvNTT"); + const uint64_t* inv_root_of_unity_powers = GetInvRootOfUnityPowers().data(); + const uint64_t* precon_inv_root_of_unity_powers = + GetPrecon52InvRootOfUnityPowers().data(); + InverseTransformFromBitReverseAVX512( + result, operand, m_degree, m_q, inv_root_of_unity_powers, + precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor); + return; + } +#endif + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq && m_degree >= 16) { + if (m_q < s_max_inv_32_modulus) { + HEXL_VLOG(3, "Calling 32-bit AVX512-DQ InvNTT"); + const uint64_t* inv_root_of_unity_powers = + GetInvRootOfUnityPowers().data(); + const uint64_t* precon_inv_root_of_unity_powers = + 
GetPrecon32InvRootOfUnityPowers().data(); + InverseTransformFromBitReverseAVX512<32>( + result, operand, m_degree, m_q, inv_root_of_unity_powers, + precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor); + } else { + HEXL_VLOG(3, "Calling 64-bit AVX512 InvNTT"); + const uint64_t* inv_root_of_unity_powers = + GetInvRootOfUnityPowers().data(); + const uint64_t* precon_inv_root_of_unity_powers = + GetPrecon64InvRootOfUnityPowers().data(); + + InverseTransformFromBitReverseAVX512( + result, operand, m_degree, m_q, inv_root_of_unity_powers, + precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor); + } + return; + } +#endif + + HEXL_VLOG(3, "Calling 64-bit default InvNTT"); + const uint64_t* inv_root_of_unity_powers = GetInvRootOfUnityPowers().data(); + const uint64_t* precon_inv_root_of_unity_powers = + GetPrecon64InvRootOfUnityPowers().data(); + InverseTransformFromBitReverseRadix2( + result, operand, m_degree, m_q, inv_root_of_unity_powers, + precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor); +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/ntt/ntt-internal.hpp b/hexl_omp/ntt/ntt-internal.hpp new file mode 100644 index 00000000..420dca10 --- /dev/null +++ b/hexl_omp/ntt/ntt-internal.hpp @@ -0,0 +1,125 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include + +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/util.hpp" +#include "util/util-internal.hpp" + +namespace intel { +namespace hexl { + +/// @brief Radix-2 native C++ NTT implementation of the forward NTT +/// @param[out] result Output data. Overwritten with NTT output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity in F_q. In +/// bit-reversed order +/// @param[in] precon_root_of_unity_powers Pre-conditioned Powers of 2n'th root +/// of unity in F_q. In bit-reversed order. +/// @param[in] input_mod_factor Upper bound for inputs; inputs must be in [0, +/// input_mod_factor * q) +/// @param[in] output_mod_factor Upper bound for result; result must be in [0, +/// output_mod_factor * q) +void ForwardTransformToBitReverseRadix2( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor = 1, + uint64_t output_mod_factor = 1); + +/// @brief Radix-4 native C++ NTT implementation of the forward NTT +/// @param[out] result Output data. Overwritten with NTT output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity in F_q. In +/// bit-reversed order +/// @param[in] precon_root_of_unity_powers Pre-conditioned Powers of 2n'th root +/// of unity in F_q. In bit-reversed order. 
+/// @param[in] input_mod_factor Upper bound for inputs; inputs must be in [0, +/// input_mod_factor * q) +/// @param[in] output_mod_factor Upper bound for result; result must be in [0, +/// output_mod_factor * q) +void ForwardTransformToBitReverseRadix4( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor = 1, + uint64_t output_mod_factor = 1); + +/// @brief Reference forward NTT which is written for clarity rather than +/// performance +/// @param[in, out] operand Input data. Overwritten with NTT output +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus. Must satisfy q == 1 mod 2n +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity in F_q. In +/// bit-reversed order +void ReferenceForwardTransformToBitReverse( + uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers); + +/// @brief Reference inverse NTT which is written for clarity rather than +/// performance +/// @param[in, out] operand Input data. Overwritten with NTT output +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus. Must satisfy q == 1 mod 2n +/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity in +/// F_q. In bit-reversed order. +void ReferenceInverseTransformFromBitReverse( + uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers); + +/// @brief Radix-2 native C++ NTT implementation of the inverse NTT +/// @param[out] result Output data. Overwritten with NTT output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n +/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity in +/// F_q. In bit-reversed order. +/// @param[in] precon_root_of_unity_powers Pre-conditioned powers of inverse +/// 2n'th root of unity in F_q. In bit-reversed order. +/// @param[in] input_mod_factor Upper bound for inputs; inputs must be in [0, +/// input_mod_factor * q) +/// @param[in] output_mod_factor Upper bound for result; result must be in [0, +/// output_mod_factor * q) +void InverseTransformFromBitReverseRadix2( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, + uint64_t input_mod_factor = 1, uint64_t output_mod_factor = 1); + +/// @brief Radix-4 native C++ NTT implementation of the inverse NTT +/// @param[out] result Output data. Overwritten with NTT output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity in F_q. In +/// bit-reversed order +/// @param[in] precon_root_of_unity_powers Pre-conditioned Powers of 2n'th root +/// of unity in F_q. In bit-reversed order. 
+/// @param[in] input_mod_factor Upper bound for inputs; inputs must be in [0, +/// input_mod_factor * q) +/// @param[in] output_mod_factor Upper bound for result; result must be in [0, +/// output_mod_factor * q) +void InverseTransformFromBitReverseRadix4( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor = 1, + uint64_t output_mod_factor = 1); + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/ntt/ntt-radix-2.cpp b/hexl_omp/ntt/ntt-radix-2.cpp new file mode 100644 index 00000000..d0555d47 --- /dev/null +++ b/hexl_omp/ntt/ntt-radix-2.cpp @@ -0,0 +1,522 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "ntt/ntt-default.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void ForwardTransformToBitReverseRadix2( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK_BOUNDS(operand, n, modulus * input_mod_factor, + "operand exceeds bound " << modulus * input_mod_factor); + HEXL_CHECK(root_of_unity_powers != nullptr, + "root_of_unity_powers == nullptr"); + HEXL_CHECK(precon_root_of_unity_powers != nullptr, + "precon_root_of_unity_powers == nullptr"); + HEXL_CHECK( + input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4, + "input_mod_factor must be 1, 2, or 4; got " << input_mod_factor); + HEXL_UNUSED(input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4, + "output_mod_factor must be 1 or 4; got " << output_mod_factor); + + uint64_t twice_modulus = modulus << 1; + size_t t = (n >> 1); + + // In case of out-of-place operation do first pass and convert to in-place + { + const uint64_t W = root_of_unity_powers[1]; + const uint64_t W_precon = precon_root_of_unity_powers[1]; + + uint64_t* X_r = result; + uint64_t* Y_r = X_r + t; + + const uint64_t* X_op = operand; + const uint64_t* Y_op = X_op + t; + + // First pass for out-of-order case + switch (t) { + case 8: { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + break; + } + case 4: { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, 
X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + break; + } + case 2: { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + break; + } + case 1: { + FwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W, W_precon, modulus, + twice_modulus); + break; + } + default: { + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < t; j += 8) { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + } + } + t >>= 1; + } + + // Continue with in-place operation + for (size_t m = 2; m < n; m <<= 1) { + size_t offset = 0; + switch (t) { + case 8: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = root_of_unity_powers[m + i]; + const uint64_t W_precon = precon_root_of_unity_powers[m + i]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + break; + } + case 4: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = root_of_unity_powers[m + i]; + const uint64_t W_precon = precon_root_of_unity_powers[m + i]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + break; + } + case 2: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = root_of_unity_powers[m + i]; + const uint64_t W_precon = precon_root_of_unity_powers[m + i]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + 
FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + break; + } + case 1: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = root_of_unity_powers[m + i]; + const uint64_t W_precon = precon_root_of_unity_powers[m + i]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + FwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W, W_precon, modulus, + twice_modulus); + } + break; + } + default: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = root_of_unity_powers[m + i]; + const uint64_t W_precon = precon_root_of_unity_powers[m + i]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < t; j += 8) { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + } + } + } + } + t >>= 1; + } + if (output_mod_factor == 1) { + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<4>(result[i], modulus, &twice_modulus); + HEXL_CHECK(result[i] < modulus, "Incorrect modulus reduction in NTT " + << result[i] << " >= " << modulus); + } + } +} + +void ReferenceForwardTransformToBitReverse( + uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(root_of_unity_powers != nullptr, + "root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + + size_t t = (n >> 1); + for (size_t m = 1; m < n; m <<= 1) { + size_t offset = 0; + for (size_t i = 0; i < m; i++) { + size_t offset2 = offset + t; + const uint64_t W = root_of_unity_powers[m + i]; + + uint64_t* X = operand + offset; + uint64_t* Y = X + t; + for (size_t j = offset; j < offset2; j++) { + // X', Y' = X + WY, X - WY (mod q). 
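+        // Worked toy example (added for illustration; the values are
+        // hypothetical and not taken from the library): with q = 17, W = 3,
+        // X = 5, Y = 7, we get WY mod q = 21 mod 17 = 4, so
+        // X' = (5 + 4) mod 17 = 9 and Y' = (5 - 4) mod 17 = 1.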
+ uint64_t tx = *X; + uint64_t W_x_Y = MultiplyMod(*Y, W, modulus); + *X++ = AddUIntMod(tx, W_x_Y, modulus); + *Y++ = SubUIntMod(tx, W_x_Y, modulus); + } + offset += (t << 1); + } + t >>= 1; + } +} + +void ReferenceInverseTransformFromBitReverse( + uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(inv_root_of_unity_powers != nullptr, + "inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + + size_t t = 1; + size_t root_index = 1; + for (size_t m = (n >> 1); m >= 1; m >>= 1) { + size_t offset = 0; + for (size_t i = 0; i < m; i++, root_index++) { + const uint64_t W = inv_root_of_unity_powers[root_index]; + uint64_t* X_r = operand + offset; + uint64_t* Y_r = X_r + t; + for (size_t j = 0; j < t; j++) { + uint64_t X_op = *X_r; + uint64_t Y_op = *Y_r; + // Butterfly X' = (X + Y) mod q, Y' = W(X-Y) mod q + *X_r = AddUIntMod(X_op, Y_op, modulus); + *Y_r = MultiplyMod(W, SubUIntMod(X_op, Y_op, modulus), modulus); + X_r++; + Y_r++; + } + offset += (t << 1); + } + t <<= 1; + } + + // Final multiplication by N^{-1} + const uint64_t inv_n = InverseMod(n, modulus); + for (size_t i = 0; i < n; ++i) { + operand[i] = MultiplyMod(operand[i], inv_n, modulus); + } +} + +void InverseTransformFromBitReverseRadix2( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(inv_root_of_unity_powers != nullptr, + "inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(precon_inv_root_of_unity_powers != nullptr, + "precon_inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + HEXL_CHECK(input_mod_factor == 1 || input_mod_factor == 2, + "input_mod_factor must be 1 or 2; got " << input_mod_factor); + HEXL_UNUSED(input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2; got " << output_mod_factor); + + uint64_t twice_modulus = modulus << 1; + uint64_t n_div_2 = (n >> 1); + size_t t = 1; + size_t root_index = 1; + + for (size_t m = n_div_2; m > 1; m >>= 1) { + size_t offset = 0; + + switch (t) { + case 1: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = inv_root_of_unity_powers[root_index]; + const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = operand + offset; + const uint64_t* Y_op = X_op + t; + InvButterflyRadix2(X_r, Y_r, X_op, Y_op, W, W_precon, modulus, + twice_modulus); + } + break; + } + case 2: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = inv_root_of_unity_powers[root_index]; + const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + break; + } + case 4: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = 
inv_root_of_unity_powers[root_index]; + const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + break; + } + case 8: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = inv_root_of_unity_powers[root_index]; + const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + break; + } + default: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = inv_root_of_unity_powers[root_index]; + const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < t; j += 8) { + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + } + } + } + } + t <<= 1; + } + + // When M is too short it only needs the final stage butterfly. Copying here + // in the case of out-of-place. 
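+  // (Clarifying note: when n == 2, the loop above starts at m = n/2 = 1 and
+  // its condition m > 1 fails immediately, so nothing has been written to
+  // `result` yet; an out-of-place call must therefore copy the input before
+  // the final scaling stage below.)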
+ if (result != operand && n == 2) { + std::memcpy(result, operand, n * sizeof(uint64_t)); + } + + // Fold multiplication by N^{-1} to final stage butterfly + const uint64_t W = inv_root_of_unity_powers[n - 1]; + + const uint64_t inv_n = InverseMod(n, modulus); + uint64_t inv_n_precon = MultiplyFactor(inv_n, 64, modulus).BarrettFactor(); + const uint64_t inv_n_w = MultiplyMod(inv_n, W, modulus); + uint64_t inv_n_w_precon = + MultiplyFactor(inv_n_w, 64, modulus).BarrettFactor(); + + uint64_t* X = result; + uint64_t* Y = X + n_div_2; + for (size_t j = 0; j < n_div_2; ++j) { + // Assume X, Y in [0, 2q) and compute + // X' = N^{-1} (X + Y) (mod q) + // Y' = N^{-1} * W * (X - Y) (mod q) + uint64_t tx = AddUIntMod(X[j], Y[j], twice_modulus); + uint64_t ty = X[j] + twice_modulus - Y[j]; + X[j] = MultiplyModLazy<64>(tx, inv_n, inv_n_precon, modulus); + Y[j] = MultiplyModLazy<64>(ty, inv_n_w, inv_n_w_precon, modulus); + } + + if (output_mod_factor == 1) { + // Reduce from [0, 2q) to [0,q) + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<2>(result[i], modulus); + HEXL_CHECK(result[i] < modulus, "Incorrect modulus reduction in InvNTT" + << result[i] << " >= " << modulus); + } + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/ntt/ntt-radix-4.cpp b/hexl_omp/ntt/ntt-radix-4.cpp new file mode 100644 index 00000000..1d21fece --- /dev/null +++ b/hexl_omp/ntt/ntt-radix-4.cpp @@ -0,0 +1,622 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "ntt/ntt-default.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void ForwardTransformToBitReverseRadix4( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK_BOUNDS(operand, n, modulus * input_mod_factor, + "operand exceeds bound " << modulus * input_mod_factor); + HEXL_CHECK(root_of_unity_powers != nullptr, + "root_of_unity_powers == nullptr"); + HEXL_CHECK(precon_root_of_unity_powers != nullptr, + "precon_root_of_unity_powers == nullptr"); + HEXL_CHECK( + input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4, + "input_mod_factor must be 1, 2, or 4; got " << input_mod_factor); + HEXL_UNUSED(input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4, + "output_mod_factor must be 1 or 4; got " << output_mod_factor); + + HEXL_VLOG(3, "modulus " << modulus); + HEXL_VLOG(3, "n " << n); + + HEXL_VLOG(3, "operand " << std::vector(operand, operand + n)); + + HEXL_VLOG(3, "root_of_unity_powers " << std::vector( + root_of_unity_powers, root_of_unity_powers + n)); + + bool is_power_of_4 = IsPowerOfFour(n); + + uint64_t twice_modulus = modulus << 1; + uint64_t four_times_modulus = modulus << 2; + + // Radix-2 step for non-powers of 4 + if (!is_power_of_4) { + HEXL_VLOG(3, "Radix 2 step"); + + size_t t = (n >> 1); + + const uint64_t W = root_of_unity_powers[1]; + const uint64_t W_precon = precon_root_of_unity_powers[1]; + + uint64_t* X_r = result; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = operand; + const uint64_t* Y_op = X_op + t; + + // HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < t; j++) { + 
FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + // Data in [0, 4q) + HEXL_VLOG(3, "after radix 2 outputs " + << std::vector(result, result + n)); + } + + uint64_t m_start = 2; + size_t t = n >> 3; + if (is_power_of_4) { + t = n >> 2; + + uint64_t* X_r0 = result; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = operand; + const uint64_t* X_op1 = operand + t; + const uint64_t* X_op2 = operand + 2 * t; + const uint64_t* X_op3 = operand + 3 * t; + + uint64_t W1_ind = 1; + uint64_t W2_ind = 2 * W1_ind; + uint64_t W3_ind = 2 * W1_ind + 1; + + const uint64_t W1 = root_of_unity_powers[W1_ind]; + const uint64_t W2 = root_of_unity_powers[W2_ind]; + const uint64_t W3 = root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_root_of_unity_powers[W3_ind]; + + switch (t) { + case 4: { + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + break; + } + case 1: { + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + break; + } + default: { + for (size_t j = 0; j < t; j += 16) { + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + 
FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + } + } + } + t >>= 2; + m_start = 4; + } + + // uint64_t m_start = is_power_of_4 ? 1 : 2; + // size_t t = (n >> m_start) >> 1; + + for (size_t m = m_start; m < n; m <<= 2) { + HEXL_VLOG(3, "m " << m); + + size_t X0_offset = 0; + + switch (t) { + case 4: { + // HEXL_LOOP_UNROLL_8 + for (size_t i = 0; i < m; i++) { + if (i != 0) { + X0_offset += 4 * t; + } + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; + + uint64_t W1_ind = m + i; + uint64_t W2_ind = 2 * W1_ind; + uint64_t W3_ind = 2 * W1_ind + 1; + + const uint64_t W1 = root_of_unity_powers[W1_ind]; + const uint64_t W2 = root_of_unity_powers[W2_ind]; + const uint64_t W3 = root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_root_of_unity_powers[W3_ind]; + + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + } + break; + } + case 1: { + // HEXL_LOOP_UNROLL_8 + for (size_t i = 0; i < m; i++) { + if (i != 0) { + X0_offset += 4 * t; + } + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 
* t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; + + uint64_t W1_ind = m + i; + uint64_t W2_ind = 2 * W1_ind; + uint64_t W3_ind = 2 * W1_ind + 1; + + const uint64_t W1 = root_of_unity_powers[W1_ind]; + const uint64_t W2 = root_of_unity_powers[W2_ind]; + const uint64_t W3 = root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_root_of_unity_powers[W3_ind]; + + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + } + break; + } + default: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + X0_offset += 4 * t; + } + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; + + uint64_t W1_ind = m + i; + uint64_t W2_ind = 2 * W1_ind; + uint64_t W3_ind = 2 * W1_ind + 1; + + const uint64_t W1 = root_of_unity_powers[W1_ind]; + const uint64_t W2 = root_of_unity_powers[W2_ind]; + const uint64_t W3 = root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_root_of_unity_powers[W3_ind]; + + for (size_t j = 0; j < t; j += 16) { + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, 
X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + } + } + } + } + t >>= 2; + } + + if (output_mod_factor == 1) { + for (size_t i = 0; i < n; ++i) { + if (result[i] >= twice_modulus) { + result[i] -= twice_modulus; + } + if (result[i] >= modulus) { + result[i] -= modulus; + } + HEXL_CHECK(result[i] < modulus, "Incorrect modulus reduction in NTT " + << result[i] << " >= " << modulus); + } + } + + HEXL_VLOG(3, "outputs " << std::vector(result, result + n)); +} + +void InverseTransformFromBitReverseRadix4( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(inv_root_of_unity_powers != nullptr, + "inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(precon_inv_root_of_unity_powers != nullptr, + "precon_inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + HEXL_CHECK(input_mod_factor == 1 || input_mod_factor == 2, + "input_mod_factor must be 1 or 2; got " << input_mod_factor); + HEXL_UNUSED(input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2; got " << output_mod_factor); + + uint64_t twice_modulus = modulus << 1; + uint64_t n_div_2 = (n >> 1); + + bool is_power_of_4 = IsPowerOfFour(n); + // Radix-2 step for powers of 4 + if (is_power_of_4) { + uint64_t* X_r = result; + uint64_t* Y_r = X_r + 1; + const uint64_t* X_op = operand; + const uint64_t* Y_op = X_op + 1; + const uint64_t* W = inv_root_of_unity_powers + 1; + const uint64_t* W_precon = precon_inv_root_of_unity_powers + 1; + + // HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < n / 2; j++) { + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, *W++, *W_precon++, + modulus, twice_modulus); + X_r++; + Y_r++; + X_op++; + Y_op++; + } + // Data in [0, 2q) + } + + uint64_t m_start = n >> (is_power_of_4 ? 3 : 2); + size_t t = is_power_of_4 ? 2 : 1; + + size_t w1_root_index = 1 + (is_power_of_4 ? n_div_2 : 0); + size_t w3_root_index = n_div_2 + 1 + (is_power_of_4 ? 
(n / 4) : 0); + + HEXL_VLOG(4, "m_start " << m_start); + + for (size_t m = m_start; m > 0; m >>= 2) { + HEXL_VLOG(4, "m " << m); + HEXL_VLOG(4, "t " << t); + + size_t X0_offset = 0; + + switch (t) { + case 1: { + for (size_t i = 0; i < m; i++) { + HEXL_VLOG(4, "i " << i); + if (i != 0) { + X0_offset += 4 * t; + } + + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = operand + X0_offset; + const uint64_t* X_op1 = X_op0 + t; + const uint64_t* X_op2 = X_op0 + 2 * t; + const uint64_t* X_op3 = X_op0 + 3 * t; + + uint64_t W1_ind = w1_root_index++; + uint64_t W2_ind = w1_root_index++; + uint64_t W3_ind = w3_root_index++; + + const uint64_t W1 = inv_root_of_unity_powers[W1_ind]; + const uint64_t W2 = inv_root_of_unity_powers[W2_ind]; + const uint64_t W3 = inv_root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_inv_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_inv_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_inv_root_of_unity_powers[W3_ind]; + + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + } + break; + } + case 4: { + for (size_t i = 0; i < m; i++) { + HEXL_VLOG(4, "i " << i); + if (i != 0) { + X0_offset += 4 * t; + } + + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; + + uint64_t W1_ind = w1_root_index++; + uint64_t W2_ind = w1_root_index++; + uint64_t W3_ind = w3_root_index++; + + const uint64_t W1 = inv_root_of_unity_powers[W1_ind]; + const uint64_t W2 = inv_root_of_unity_powers[W2_ind]; + const uint64_t W3 = inv_root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_inv_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_inv_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_inv_root_of_unity_powers[W3_ind]; + + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + } + break; + } + default: { + // HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < m; i++) { + HEXL_VLOG(4, "i " << i); + if (i != 0) { + X0_offset += 4 * t; + } + + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; + + uint64_t W1_ind = w1_root_index++; + uint64_t W2_ind = w1_root_index++; + uint64_t W3_ind = w3_root_index++; + + const uint64_t W1 = inv_root_of_unity_powers[W1_ind]; + const uint64_t W2 = inv_root_of_unity_powers[W2_ind]; + const uint64_t W3 = inv_root_of_unity_powers[W3_ind]; + + const uint64_t 
W1_precon = precon_inv_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_inv_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_inv_root_of_unity_powers[W3_ind]; + + for (size_t j = 0; j < t; j++) { + HEXL_VLOG(4, "j " << j); + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus); + } + } + } + } + t <<= 2; + w1_root_index += m; + w3_root_index += m / 2; + } + + // When M is too short it only needs the final stage butterfly. Copying here + // in the case of out-of-place. + if (result != operand && n == 2) { + std::memcpy(result, operand, n * sizeof(uint64_t)); + } + + HEXL_VLOG(4, "Starting final invNTT stage"); + HEXL_VLOG(4, "operand " << std::vector(result, result + n)); + + // Fold multiplication by N^{-1} to final stage butterfly + const uint64_t W = inv_root_of_unity_powers[n - 1]; + HEXL_VLOG(4, "final W " << W); + + const uint64_t inv_n = InverseMod(n, modulus); + uint64_t inv_n_precon = MultiplyFactor(inv_n, 64, modulus).BarrettFactor(); + const uint64_t inv_n_w = MultiplyMod(inv_n, W, modulus); + uint64_t inv_n_w_precon = + MultiplyFactor(inv_n_w, 64, modulus).BarrettFactor(); + + uint64_t* X = result; + uint64_t* Y = X + n_div_2; + for (size_t j = 0; j < n_div_2; ++j) { + // Assume X, Y in [0, 2q) and compute + // X' = N^{-1} (X + Y) (mod q) + // Y' = N^{-1} * W * (X - Y) (mod q) + // with X', Y' in [0, 2q) + uint64_t tx = AddUIntMod(X[j], Y[j], twice_modulus); + uint64_t ty = X[j] + twice_modulus - Y[j]; + X[j] = MultiplyModLazy<64>(tx, inv_n, inv_n_precon, modulus); + Y[j] = MultiplyModLazy<64>(ty, inv_n_w, inv_n_w_precon, modulus); + } + + if (output_mod_factor == 1) { + // Reduce from [0, 2q) to [0,q) + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<2>(result[i], modulus); + HEXL_CHECK(result[i] < modulus, "Incorrect modulus reduction in InvNTT" + << result[i] << " >= " << modulus); + } + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/number-theory/number-theory.cpp b/hexl_omp/number-theory/number-theory.cpp new file mode 100644 index 00000000..d8d6f079 --- /dev/null +++ b/hexl_omp/number-theory/number-theory.cpp @@ -0,0 +1,264 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/number-theory/number-theory.hpp" + +#include "hexl/logging/logging.hpp" +#include "hexl/util/check.hpp" +#include "util/util-internal.hpp" + +namespace intel { +namespace hexl { + +uint64_t InverseMod(uint64_t input, uint64_t modulus) { + uint64_t a = input % modulus; + HEXL_CHECK(a != 0, input << " does not have a InverseMod"); + + if (modulus == 1) { + return 0; + } + + int64_t m0 = static_cast(modulus); + int64_t y = 0; + int64_t x = 1; + while (a > 1) { + // q is quotient + int64_t q = static_cast(a / modulus); + + int64_t t = static_cast(modulus); + modulus = a % modulus; + a = static_cast(t); + + // Update y and x + t = y; + y = x - q * y; + x = t; + } + + // Make x positive + if (x < 0) x += m0; + + return uint64_t(x); +} + +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0"); + HEXL_CHECK(x < modulus, "x " << x << " >= modulus " << modulus); + HEXL_CHECK(y < modulus, "y " << y << " >= modulus " << modulus); + uint64_t prod_hi, prod_lo; + MultiplyUInt64(x, y, &prod_hi, &prod_lo); + + return BarrettReduce128(prod_hi, prod_lo, modulus); +} + +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon, + uint64_t 
modulus) { + uint64_t q = MultiplyUInt64Hi<64>(x, y_precon); + q = x * y - q * modulus; + return q >= modulus ? q - modulus : q; +} + +uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus) { + HEXL_CHECK(x < modulus, "x " << x << " >= modulus " << modulus); + HEXL_CHECK(y < modulus, "y " << y << " >= modulus " << modulus); + uint64_t sum = x + y; + return (sum >= modulus) ? (sum - modulus) : sum; +} + +uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus) { + HEXL_CHECK(x < modulus, "x " << x << " >= modulus " << modulus); + HEXL_CHECK(y < modulus, "y " << y << " >= modulus " << modulus); + uint64_t diff = (x + modulus) - y; + return (diff >= modulus) ? (diff - modulus) : diff; +} + +// Returns base^exp mod modulus +uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus) { + base %= modulus; + uint64_t result = 1; + while (exp > 0) { + if (exp & 1) { + result = MultiplyMod(result, base, modulus); + } + base = MultiplyMod(base, base, modulus); + exp >>= 1; + } + return result; +} + +// Returns true whether root is a degree-th root of unity +// degree must be a power of two. +bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus) { + if (root == 0) { + return false; + } + HEXL_CHECK(IsPowerOfTwo(degree), degree << " not a power of 2"); + + HEXL_VLOG(4, "IsPrimitiveRoot root " << root << ", degree " << degree + << ", modulus " << modulus); + + // Check if root^(degree/2) == -1 mod modulus + return PowMod(root, degree / 2, modulus) == (modulus - 1); +} + +// Tries to return a primitive degree-th root of unity +// throw error if no root is found +uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus) { + // We need to divide modulus-1 by degree to get the size of the quotient group + uint64_t size_entire_group = modulus - 1; + + // Compute size of quotient group + uint64_t size_quotient_group = size_entire_group / degree; + + for (int trial = 0; trial < 200; ++trial) { + uint64_t root = GenerateInsecureUniformIntRandomValue(0, modulus); + root = PowMod(root, size_quotient_group, modulus); + + if (IsPrimitiveRoot(root, degree, modulus)) { + return root; + } + } + HEXL_CHECK(false, "no primitive root found for degree " + << degree << " modulus " << modulus); + return 0; +} + +// Returns true whether root is a degree-th root of unity +// degree must be a power of two. +uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus) { + HEXL_CHECK(IsPowerOfTwo(degree), + "Degere " << degree << " is not a power of 2"); + + uint64_t root = GeneratePrimitiveRoot(degree, modulus); + + uint64_t generator_sq = MultiplyMod(root, root, modulus); + uint64_t current_generator = root; + + uint64_t min_root = root; + + // Check if root^(degree/2) == -1 mod modulus + for (size_t i = 0; i < degree; ++i) { + if (current_generator < min_root) { + min_root = current_generator; + } + current_generator = MultiplyMod(current_generator, generator_sq, modulus); + } + + return min_root; +} + +uint64_t ReverseBits(uint64_t x, uint64_t bit_width) { + HEXL_CHECK(x == 0 || MSB(x) <= bit_width, "MSB(" << x << ") = " << MSB(x) + << " must be >= bit_width " + << bit_width) + if (bit_width == 0) { + return 0; + } + uint64_t rev = 0; + for (uint64_t i = bit_width; i > 0; i--) { + rev |= ((x & 1) << (i - 1)); + x >>= 1; + } + return rev; +} + +// Miller-Rabin primality test +bool IsPrime(uint64_t n) { + // n < 2^64, so it is enough to test a=2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, + // and 37. 
See + // https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test#Testing_against_small_sets_of_bases + static const std::vector as{2, 3, 5, 7, 11, 13, + 17, 19, 23, 29, 31, 37}; + + for (const uint64_t a : as) { + if (n == a) return true; + if (n % a == 0) return false; + } + + // Write n == 2**r * d + 1 with d odd. + uint64_t r = 63; + while (r > 0) { + uint64_t two_pow_r = (1ULL << r); + if ((n - 1) % two_pow_r == 0) { + break; + } + --r; + } + HEXL_CHECK(r != 0, "Error factoring n " << n); + uint64_t d = (n - 1) / (1ULL << r); + + HEXL_CHECK(n == (1ULL << r) * d + 1, "Error factoring n " << n); + HEXL_CHECK(d % 2 == 1, "d is even"); + + for (const uint64_t a : as) { + uint64_t x = PowMod(a, d, n); + if ((x == 1) || (x == n - 1)) { + continue; + } + + bool prime = false; + for (uint64_t i = 1; i < r; ++i) { + x = PowMod(x, 2, n); + if (x == n - 1) { + prime = true; + break; + } + } + if (!prime) { + return false; + } + } + return true; +} + +std::vector GeneratePrimes(size_t num_primes, size_t bit_size, + bool prefer_small_primes, + size_t ntt_size) { + HEXL_CHECK(num_primes > 0, "num_primes == 0"); + HEXL_CHECK(IsPowerOfTwo(ntt_size), + "ntt_size " << ntt_size << " is not a power of two"); + HEXL_CHECK(Log2(ntt_size) < bit_size, + "log2(ntt_size) " << Log2(ntt_size) + << " should be less than bit_size " << bit_size); + + int64_t prime_lower_bound = (1LL << bit_size) + 1LL; + int64_t prime_upper_bound = (1LL << (bit_size + 1LL)) - 1LL; + + // Keep signed to enable negative step + int64_t prime_candidate = + prefer_small_primes + ? prime_lower_bound + : prime_upper_bound - (prime_upper_bound % (2 * ntt_size)) + 1; + HEXL_CHECK(prime_candidate % (2 * ntt_size) == 1, "bad prime candidate"); + + // Ensure prime % 2 * ntt_size == 1 + int64_t prime_candidate_step = + (prefer_small_primes ? 
1 : -1) * 2 * static_cast(ntt_size); + + auto continue_condition = [&](int64_t local_candidate_prime) { + if (prefer_small_primes) { + return local_candidate_prime < prime_upper_bound; + } else { + return local_candidate_prime > prime_lower_bound; + } + }; + + std::vector ret; + + while (continue_condition(prime_candidate)) { + if (IsPrime(prime_candidate)) { + HEXL_CHECK(prime_candidate % (2 * ntt_size) == 1, "bad prime candidate"); + ret.emplace_back(static_cast(prime_candidate)); + if (ret.size() == num_primes) { + return ret; + } + } + prime_candidate += prime_candidate_step; + } + + HEXL_CHECK(false, "Failed to find enough primes"); + return ret; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/util/avx512-util.hpp b/hexl_omp/util/avx512-util.hpp new file mode 100644 index 00000000..c5e206a9 --- /dev/null +++ b/hexl_omp/util/avx512-util.hpp @@ -0,0 +1,523 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include + +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/defines.hpp" +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief Returns the unsigned 64-bit integer values in x as a vector +inline std::vector ExtractValues(__m512i x) { + __m256i x0 = _mm512_extracti64x4_epi64(x, 0); + __m256i x1 = _mm512_extracti64x4_epi64(x, 1); + + std::vector xs{static_cast(_mm256_extract_epi64(x0, 0)), + static_cast(_mm256_extract_epi64(x0, 1)), + static_cast(_mm256_extract_epi64(x0, 2)), + static_cast(_mm256_extract_epi64(x0, 3)), + static_cast(_mm256_extract_epi64(x1, 0)), + static_cast(_mm256_extract_epi64(x1, 1)), + static_cast(_mm256_extract_epi64(x1, 2)), + static_cast(_mm256_extract_epi64(x1, 3))}; + + return xs; +} + +/// @brief Returns the signed 64-bit integer values in x as a vector +inline std::vector ExtractIntValues(__m512i x) { + __m256i x0 = _mm512_extracti64x4_epi64(x, 0); + __m256i x1 = _mm512_extracti64x4_epi64(x, 1); + + std::vector xs{static_cast(_mm256_extract_epi64(x0, 0)), + static_cast(_mm256_extract_epi64(x0, 1)), + static_cast(_mm256_extract_epi64(x0, 2)), + static_cast(_mm256_extract_epi64(x0, 3)), + static_cast(_mm256_extract_epi64(x1, 0)), + static_cast(_mm256_extract_epi64(x1, 1)), + static_cast(_mm256_extract_epi64(x1, 2)), + static_cast(_mm256_extract_epi64(x1, 3))}; + + return xs; +} + +// Returns the 64-bit floating-point values in x as a vector +inline std::vector ExtractValues(__m512d x) { + const double* x_ptr = reinterpret_cast(&x); + return std::vector{x_ptr, x_ptr + 8}; +} + +// Returns lower NumBits bits from a 64-bit value +template +inline __m512i ClearTopBits64(__m512i x) { + const __m512i low52b_mask = _mm512_set1_epi64((1ULL << NumBits) - 1); + return _mm512_and_epi64(x, low52b_mask); +} + +// Multiply packed unsigned BitShift-bit integers in each 64-bit element of x +// and y to form a 2*BitShift-bit intermediate result. 
+// Returns the high BitShift-bit unsigned integer from the intermediate result +template +inline __m512i _mm512_hexl_mulhi_epi(__m512i x, __m512i y); + +// Dummy implementation to avoid template substitution errors +template <> +inline __m512i _mm512_hexl_mulhi_epi<32>(__m512i x, __m512i y) { + HEXL_CHECK(false, "Unimplemented"); + HEXL_UNUSED(x); + HEXL_UNUSED(y); + return x; +} + +template <> +inline __m512i _mm512_hexl_mulhi_epi<64>(__m512i x, __m512i y) { + // https://stackoverflow.com/questions/28807341/simd-signed-with-unsigned-multiplication-for-64-bit-64-bit-to-128-bit + __m512i lo_mask = _mm512_set1_epi64(0x00000000ffffffff); + // Shuffle high bits with low bits in each 64-bit integer => + // x0_lo, x0_hi, x1_lo, x1_hi, x2_lo, x2_hi, ... + __m512i x_hi = _mm512_shuffle_epi32(x, (_MM_PERM_ENUM)0xB1); + // y0_lo, y0_hi, y1_lo, y1_hi, y2_lo, y2_hi, ... + __m512i y_hi = _mm512_shuffle_epi32(y, (_MM_PERM_ENUM)0xB1); + __m512i z_lo_lo = _mm512_mul_epu32(x, y); // x_lo * y_lo + __m512i z_lo_hi = _mm512_mul_epu32(x, y_hi); // x_lo * y_hi + __m512i z_hi_lo = _mm512_mul_epu32(x_hi, y); // x_hi * y_lo + __m512i z_hi_hi = _mm512_mul_epu32(x_hi, y_hi); // x_hi * y_hi + + // x_hi | x_lo + // x y_hi | y_lo + // ------------------------------ + // [x_lo * y_lo] // z_lo_lo + // + [z_lo * y_hi] // z_lo_hi + // + [x_hi * y_lo] // z_hi_lo + // + [x_hi * y_hi] // z_hi_hi + // ^-----------^ <-- only bits needed + // sum_| hi | mid | lo | + + // Low bits of z_lo_lo are not needed + __m512i z_lo_lo_shift = _mm512_srli_epi64(z_lo_lo, 32); + + // [x_lo * y_lo] // z_lo_lo + // + [z_lo * y_hi] // z_lo_hi + // ------------------------ + // | sum_tmp | + // |sum_mid|sum_lo| + __m512i sum_tmp = _mm512_add_epi64(z_lo_hi, z_lo_lo_shift); + __m512i sum_lo = _mm512_and_si512(sum_tmp, lo_mask); + __m512i sum_mid = _mm512_srli_epi64(sum_tmp, 32); + // | |sum_lo| + // + [x_hi * y_lo] // z_hi_lo + // ------------------ + // [ sum_mid2 ] + __m512i sum_mid2 = _mm512_add_epi64(z_hi_lo, sum_lo); + __m512i sum_mid2_hi = _mm512_srli_epi64(sum_mid2, 32); + __m512i sum_hi = _mm512_add_epi64(z_hi_hi, sum_mid); + return _mm512_add_epi64(sum_hi, sum_mid2_hi); +} + +#ifdef HEXL_HAS_AVX512IFMA +template <> +inline __m512i _mm512_hexl_mulhi_epi<52>(__m512i x, __m512i y) { + __m512i zero = _mm512_set1_epi64(0); + return _mm512_madd52hi_epu64(zero, x, y); +} +#endif + +// Multiply packed unsigned BitShift-bit integers in each 64-bit element of x +// and y to form a 2*BitShift-bit intermediate result. +// Returns the high BitShift-bit unsigned integer from the intermediate result, +// with approximation error at most 1 +template +inline __m512i _mm512_hexl_mulhi_approx_epi(__m512i x, __m512i y); + +// Dummy implementation to avoid template substitution errors +template <> +inline __m512i _mm512_hexl_mulhi_approx_epi<32>(__m512i x, __m512i y) { + HEXL_CHECK(false, "Unimplemented"); + HEXL_UNUSED(x); + HEXL_UNUSED(y); + return x; +} + +template <> +inline __m512i _mm512_hexl_mulhi_approx_epi<64>(__m512i x, __m512i y) { + // https://stackoverflow.com/questions/28807341/simd-signed-with-unsigned-multiplication-for-64-bit-64-bit-to-128-bit + __m512i lo_mask = _mm512_set1_epi64(0x00000000ffffffff); + // Shuffle high bits with low bits in each 64-bit integer => + // x0_lo, x0_hi, x1_lo, x1_hi, x2_lo, x2_hi, ... + __m512i x_hi = _mm512_shuffle_epi32(x, (_MM_PERM_ENUM)0xB1); + // y0_lo, y0_hi, y1_lo, y1_hi, y2_lo, y2_hi, ... 
+ __m512i y_hi = _mm512_shuffle_epi32(y, (_MM_PERM_ENUM)0xB1); + __m512i z_lo_hi = _mm512_mul_epu32(x, y_hi); // x_lo * y_hi + __m512i z_hi_lo = _mm512_mul_epu32(x_hi, y); // x_hi * y_lo + __m512i z_hi_hi = _mm512_mul_epu32(x_hi, y_hi); // x_hi * y_hi + + // x_hi | x_lo + // x y_hi | y_lo + // ------------------------------ + // [x_lo * y_lo] // unused, resulting in approximation + // + [z_lo * y_hi] // z_lo_hi + // + [x_hi * y_lo] // z_hi_lo + // + [x_hi * y_hi] // z_hi_hi + // ^-----------^ <-- only bits needed + // sum_| hi | mid | lo | + + __m512i sum_lo = _mm512_and_si512(z_lo_hi, lo_mask); + __m512i sum_mid = _mm512_srli_epi64(z_lo_hi, 32); + // | |sum_lo| + // + [x_hi * y_lo] // z_hi_lo + // ------------------ + // [ sum_mid2 ] + __m512i sum_mid2 = _mm512_add_epi64(z_hi_lo, sum_lo); + __m512i sum_mid2_hi = _mm512_srli_epi64(sum_mid2, 32); + __m512i sum_hi = _mm512_add_epi64(z_hi_hi, sum_mid); + return _mm512_add_epi64(sum_hi, sum_mid2_hi); +} + +#ifdef HEXL_HAS_AVX512IFMA +template <> +inline __m512i _mm512_hexl_mulhi_approx_epi<52>(__m512i x, __m512i y) { + __m512i zero = _mm512_set1_epi64(0); + return _mm512_madd52hi_epu64(zero, x, y); +} +#endif + +// Multiply packed unsigned BitShift-bit integers in each 64-bit element of x +// and y to form a 2*BitShift-bit intermediate result. +// Returns the low BitShift-bit unsigned integer from the intermediate result +template +inline __m512i _mm512_hexl_mullo_epi(__m512i x, __m512i y); + +// Dummy implementation to avoid template substitution errors +template <> +inline __m512i _mm512_hexl_mullo_epi<32>(__m512i x, __m512i y) { + HEXL_CHECK(false, "Unimplemented"); + HEXL_UNUSED(x); + HEXL_UNUSED(y); + return x; +} + +template <> +inline __m512i _mm512_hexl_mullo_epi<64>(__m512i x, __m512i y) { + return _mm512_mullo_epi64(x, y); +} + +#ifdef HEXL_HAS_AVX512IFMA +template <> +inline __m512i _mm512_hexl_mullo_epi<52>(__m512i x, __m512i y) { + __m512i zero = _mm512_set1_epi64(0); + return _mm512_madd52lo_epu64(zero, x, y); +} +#endif + +// Multiply packed unsigned BitShift-bit integers in each 64-bit element of y +// and z to form a 2*BitShift-bit intermediate result. The low BitShift bits of +// the result are added to x, then the low BitShift bits of the result are +// returned. 
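+// Added note (illustrative, not in the original header): per 64-bit lane this
+// computes result = (x + y*z) mod 2^BitShift, i.e. the low BitShift bits of
+// the multiply-accumulate; the specializations below differ only in how the
+// truncation is done (IFMA madd52lo plus masking for 52, mullo plus add for 64).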
+template +inline __m512i _mm512_hexl_mullo_add_lo_epi(__m512i x, __m512i y, __m512i z); + +#ifdef HEXL_HAS_AVX512IFMA +template <> +inline __m512i _mm512_hexl_mullo_add_lo_epi<52>(__m512i x, __m512i y, + __m512i z) { + __m512i result = _mm512_madd52lo_epu64(x, y, z); + + // Clear high 12 bits from result + result = ClearTopBits64<52>(result); + return result; +} +#endif + +// Dummy implementation to avoid template substitution errors +template <> +inline __m512i _mm512_hexl_mullo_add_lo_epi<32>(__m512i x, __m512i y, + __m512i z) { + HEXL_CHECK(false, "Unimplemented"); + HEXL_UNUSED(x); + HEXL_UNUSED(y); + HEXL_UNUSED(z); + return x; +} + +template <> +inline __m512i _mm512_hexl_mullo_add_lo_epi<64>(__m512i x, __m512i y, + __m512i z) { + __m512i prod = _mm512_mullo_epi64(y, z); + return _mm512_add_epi64(x, prod); +} + +// Returns x mod q across each 64-bit integer SIMD lanes +// Assumes x < InputModFactor * q in all lanes +template +inline __m512i _mm512_hexl_small_mod_epu64(__m512i x, __m512i q, + __m512i* q_times_2 = nullptr, + __m512i* q_times_4 = nullptr) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || + InputModFactor == 4 || InputModFactor == 8, + "InputModFactor must be 1, 2, 4, or 8"); + if (InputModFactor == 1) { + return x; + } + if (InputModFactor == 2) { + return _mm512_min_epu64(x, _mm512_sub_epi64(x, q)); + } + if (InputModFactor == 4) { + HEXL_CHECK(q_times_2 != nullptr, "q_times_2 must not be nullptr"); + x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_2)); + return _mm512_min_epu64(x, _mm512_sub_epi64(x, q)); + } + if (InputModFactor == 8) { + HEXL_CHECK(q_times_2 != nullptr, "q_times_2 must not be nullptr"); + HEXL_CHECK(q_times_4 != nullptr, "q_times_4 must not be nullptr"); + x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_4)); + x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_2)); + return _mm512_min_epu64(x, _mm512_sub_epi64(x, q)); + } + HEXL_CHECK(false, "Invalid InputModFactor"); + return x; // Return dummy value +} + +// Returns (x + y) mod q; assumes 0 < x, y < q +inline __m512i _mm512_hexl_small_add_mod_epi64(__m512i x, __m512i y, + __m512i q) { + HEXL_CHECK_BOUNDS(ExtractValues(x).data(), 8, ExtractValues(q)[0], + "x exceeds bound " << ExtractValues(q)[0]); + HEXL_CHECK_BOUNDS(ExtractValues(y).data(), 8, ExtractValues(q)[0], + "y exceeds bound " << ExtractValues(q)[0]); + return _mm512_hexl_small_mod_epu64(_mm512_add_epi64(x, y), q); + + // Alternate implementation: + // x += y - q; + // if (x < 0) x+= q + // return x + // __m512i v_diff = _mm512_sub_epi64(y, q); + // x = _mm512_add_epi64(x, v_diff); + // __mmask8 sign_bits = _mm512_movepi64_mask(x); + // return _mm512_mask_add_epi64(x, sign_bits, x, q); +} + +// Returns (x - y) mod q; assumes 0 < x, y < q + +inline __m512i _mm512_hexl_small_sub_mod_epi64(__m512i x, __m512i y, + __m512i q) { + HEXL_CHECK_BOUNDS(ExtractValues(x).data(), 8, ExtractValues(q)[0], + "x exceeds bound " << ExtractValues(q)[0]); + HEXL_CHECK_BOUNDS(ExtractValues(y).data(), 8, ExtractValues(q)[0], + "y exceeds bound " << ExtractValues(q)[0]); + + // diff = x - y; + // return (diff < 0) ? 
(diff + q) : diff + __m512i v_diff = _mm512_sub_epi64(x, y); + __mmask8 sign_bits = _mm512_movepi64_mask(v_diff); + return _mm512_mask_add_epi64(v_diff, sign_bits, v_diff, q); +} + +inline __mmask8 _mm512_hexl_cmp_epu64_mask(__m512i a, __m512i b, CMPINT cmp) { + switch (cmp) { + case CMPINT::EQ: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::EQ)); + case CMPINT::LT: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::LT)); + case CMPINT::LE: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::LE)); + case CMPINT::FALSE: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::FALSE)); + case CMPINT::NE: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::NE)); + case CMPINT::NLT: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::NLT)); + case CMPINT::NLE: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::NLE)); + case CMPINT::TRUE: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::TRUE)); + } + __mmask8 dummy = 0; // Avoid end of non-void function warning + return dummy; +} + +// Returns c[i] = a[i] CMP b[i] ? match_value : 0 +inline __m512i _mm512_hexl_cmp_epi64(__m512i a, __m512i b, CMPINT cmp, + uint64_t match_value) { + __mmask8 mask = _mm512_hexl_cmp_epu64_mask(a, b, cmp); + return _mm512_maskz_broadcastq_epi64( + mask, _mm_set1_epi64x(static_cast(match_value))); +} + +// Returns c[i] = a[i] >= b[i] ? match_value : 0 +inline __m512i _mm512_hexl_cmpge_epu64(__m512i a, __m512i b, + uint64_t match_value) { + return _mm512_hexl_cmp_epi64(a, b, CMPINT::NLT, match_value); +} + +// Returns c[i] = a[i] < b[i] ? match_value : 0 +inline __m512i _mm512_hexl_cmplt_epu64(__m512i a, __m512i b, + uint64_t match_value) { + return _mm512_hexl_cmp_epi64(a, b, CMPINT::LT, match_value); +} + +// Returns c[i] = a[i] <= b[i] ? match_value : 0 +inline __m512i _mm512_hexl_cmple_epu64(__m512i a, __m512i b, + uint64_t match_value) { + return _mm512_hexl_cmp_epi64(a, b, CMPINT::LE, match_value); +} + +// Returns Montgomery form of ab mod q, computed via the REDC algorithm, +// also known as Montgomery reduction. +// Template: r with R = 2^r +// Inputs: q such that gcd(R, q) = 1. R > q. +// v_inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R, +// T = ab in the range [0, Rq − 1]. +// T_hi and T_lo for BitShift = 64 should be given in 63 bits. +// Output: Integer S in the range [0, q − 1] such that S ≡ TR^−1 mod q +template +inline __m512i _mm512_hexl_montgomery_reduce(__m512i T_hi, __m512i T_lo, + __m512i q, __m512i v_inv_mod, + __m512i v_rs_or_msk) { + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid bitshift " << BitShift << "; need 52 or 64"); + +#ifdef HEXL_HAS_AVX512IFMA + if (BitShift == 52) { + // Operation: + // m ← ((T mod R)N′) mod R | m ← ((T & mod_R_mask)*v_inv_mod) & mod_R_mask + __m512i m = _mm512_hexl_mullo_epi(T_lo, v_inv_mod); + m = ClearTopBits64(m); + + // Operation: t ← (T + mN) / R = (T + m*q) >> r + // Hi part + __m512i t_hi = _mm512_madd52hi_epu64(T_hi, m, q); + // Low part + __m512i t = _mm512_madd52lo_epu64(T_lo, m, q); + t = _mm512_srli_epi64(t, r); + // Join parts + t = _mm512_madd52lo_epu64(t, t_hi, v_rs_or_msk); + + // If this function exists for 52 bits we could save 1 cycle + // t = _mm512_shrdi_epi64 (t_hi, t, r) + + // Operation: t ≥ q? 
return (t - q) : return t + return _mm512_hexl_small_mod_epu64<2>(t, q); + } +#endif + + HEXL_CHECK(BitShift == 64, "Invalid bitshift " << BitShift << "; need 64"); + + // Operation: + // m ← ((T mod R)N′) mod R | m ← ((T & mod_R_mask)*v_inv_mod) & mod_R_mask + __m512i m = ClearTopBits64(T_lo); + m = _mm512_hexl_mullo_epi(m, v_inv_mod); + m = ClearTopBits64(m); + + __m512i mq_hi = _mm512_hexl_mulhi_epi(m, q); + __m512i mq_lo = _mm512_hexl_mullo_epi(m, q); + + // to 63 bits + mq_hi = _mm512_slli_epi64(mq_hi, 1); + __m512i tmp = _mm512_srli_epi64(mq_lo, 63); + mq_hi = _mm512_add_epi64(mq_hi, tmp); + mq_lo = _mm512_and_epi64(mq_lo, v_rs_or_msk); + + __m512i t_hi = _mm512_add_epi64(T_hi, mq_hi); + t_hi = _mm512_slli_epi64(t_hi, 63 - r); + __m512i t = _mm512_add_epi64(T_lo, mq_lo); + t = _mm512_srli_epi64(t, r); + + // Join parts + t = _mm512_add_epi64(t_hi, t); + + return _mm512_hexl_small_mod_epu64<2>(t, q); +} + +// Returns x mod q, computed via Barrett reduction +// @param q_barr floor(2^BitShift / q) +template +inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, + __m512i q_barr_64, + __m512i q_barr_52, + uint64_t prod_right_shift, + __m512i v_neg_mod) { + HEXL_UNUSED(q_barr_52); + HEXL_UNUSED(prod_right_shift); + HEXL_UNUSED(v_neg_mod); + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid bitshift " << BitShift << "; need 52 or 64"); + +#ifdef HEXL_HAS_AVX512IFMA + if (BitShift == 52) { + __m512i two_pow_fiftytwo = _mm512_set1_epi64(2251799813685248); + __mmask8 mask = + _mm512_hexl_cmp_epu64_mask(x, two_pow_fiftytwo, CMPINT::NLT); + if (mask != 0) { + // values above 2^52 + __m512i x_hi = _mm512_srli_epi64(x, static_cast(52ULL)); + __m512i x_lo = ClearTopBits64<52>(x); + + // c1 = floor(U / 2^{n + beta}) + __m512i c1_lo = + _mm512_srli_epi64(x_lo, static_cast(prod_right_shift)); + __m512i c1_hi = _mm512_slli_epi64( + x_hi, static_cast(52ULL - (prod_right_shift))); + __m512i c1 = _mm512_or_epi64(c1_lo, c1_hi); + + // alpha - beta == 52, so we only need high 52 bits + __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, q_barr_64); + // Z = prod_lo - (p * q_hat)_lo + x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod); + } else { + __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr_52); + __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q); + x = _mm512_sub_epi64(x, tmp1_times_mod); + } + } +#endif + if (BitShift == 64) { + __m512i rnd1_hi = _mm512_hexl_mulhi_epi<64>(x, q_barr_64); + __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<64>(rnd1_hi, q); + x = _mm512_sub_epi64(x, tmp1_times_mod); + } + + // Correction + if (OutputModFactor == 1) { + x = _mm512_hexl_small_mod_epu64<2>(x, q); + } + return x; +} + +// Concatenate packed 64-bit integers in x and y, producing an intermediate +// 128-bit result. Shift the result right by bit_shift bits, and return the +// lower 64 bits. The bit_shift is a run-time argument, rather than a +// compile-time template parameter, so we can't use _mm512_shrdi_epi64 +inline __m512i _mm512_hexl_shrdi_epi64(__m512i x, __m512i y, + unsigned int bit_shift) { + __m512i c_lo = _mm512_srli_epi64(x, bit_shift); + __m512i c_hi = _mm512_slli_epi64(y, 64 - bit_shift); + return _mm512_add_epi64(c_lo, c_hi); +} + +// Concatenate packed 64-bit integers in x and y, producing an intermediate +// 128-bit result. Shift the result right by BitShift bits, and return the lower +// 64 bits. 
+template +inline __m512i _mm512_hexl_shrdi_epi64(__m512i x, __m512i y) { +#ifdef HEXL_HAS_AVX512VBMI2 + return _mm512_shrdi_epi64(x, y, BitShift); +#else + return _mm512_hexl_shrdi_epi64(x, y, BitShift); +#endif +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/util/cpu-features.hpp b/hexl_omp/util/cpu-features.hpp new file mode 100644 index 00000000..ba408394 --- /dev/null +++ b/hexl_omp/util/cpu-features.hpp @@ -0,0 +1,35 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include + +#include "cpuinfo_x86.h" // NOLINT(build/include_subdir) + +namespace intel { +namespace hexl { + +// Use to disable avx512 dispatching at runtime +static const bool disable_avx512dq = + (std::getenv("HEXL_DISABLE_AVX512DQ") != nullptr); +static const bool disable_avx512ifma = + disable_avx512dq || (std::getenv("HEXL_DISABLE_AVX512IFMA") != nullptr); +static const bool disable_avx512vbmi2 = + disable_avx512dq || (std::getenv("HEXL_DISABLE_AVX512VBMI2") != nullptr); + +static const cpu_features::X86Features features = + cpu_features::GetX86Info().features; + +static const bool has_avx512dq = features.avx512f && features.avx512dq && + features.avx512vl && !disable_avx512dq; + +static const bool has_avx512ifma = features.avx512ifma && !disable_avx512ifma; + +static const bool has_avx512vbmi2 = + features.avx512vbmi2 && !disable_avx512vbmi2; + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp/util/util-internal.hpp b/hexl_omp/util/util-internal.hpp new file mode 100644 index 00000000..20e94da9 --- /dev/null +++ b/hexl_omp/util/util-internal.hpp @@ -0,0 +1,102 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +inline bool Compare(CMPINT cmp, uint64_t lhs, uint64_t rhs) { + switch (cmp) { + case CMPINT::EQ: + return lhs == rhs; + case CMPINT::LT: + return lhs < rhs; + break; + case CMPINT::LE: + return lhs <= rhs; + break; + case CMPINT::FALSE: + return false; + break; + case CMPINT::NE: + return lhs != rhs; + break; + case CMPINT::NLT: + return lhs >= rhs; + break; + case CMPINT::NLE: + return lhs > rhs; + case CMPINT::TRUE: + return true; + default: + return true; + } +} + +/// Generates a vector of size random values drawn uniformly from [min_value, +/// max_value) +/// NOTE: this function is not a cryptographically secure random number +/// generator and should be used for testing/benchmarking only +inline double GenerateInsecureUniformRealRandomValue(double min_value, + double max_value) { + HEXL_CHECK(min_value < max_value, "min_value must be > max_value"); + + static std::random_device rd; + static std::mt19937 mersenne_engine(rd()); + std::uniform_real_distribution distrib(min_value, max_value); + double res = distrib(mersenne_engine); + return (res == max_value) ? 
min_value : res; +} + +/// Generates a vector of size random values drawn uniformly from [min_value, +/// max_value) +/// NOTE: this function is not a cryptographically secure random number +/// generator and should be used for testing/benchmarking only +inline uint64_t GenerateInsecureUniformIntRandomValue(uint64_t min_value, + uint64_t max_value) { + HEXL_CHECK(min_value < max_value, "min_value must be > max_value"); + + static std::random_device rd; + static std::mt19937 mersenne_engine(rd()); + std::uniform_int_distribution distrib(min_value, max_value - 1); + return distrib(mersenne_engine); +} + +/// Generates a vector of size random values drawn uniformly from [min_value, +/// max_value) +/// NOTE: this function is not a cryptographically secure random +/// number generator and should be used for testing/benchmarking only +inline AlignedVector64 GenerateInsecureUniformRealRandomValues( + uint64_t size, double min_value, double max_value) { + AlignedVector64 values(size); + auto generator = [&]() { + return GenerateInsecureUniformRealRandomValue(min_value, max_value); + }; + std::generate(values.begin(), values.end(), generator); + return values; +} + +/// Generates a vector of size random values drawn uniformly from [min_value, +/// max_value) +/// NOTE: this function is not a cryptographically secure random +/// number generator and should be used for testing/benchmarking only +inline AlignedVector64 GenerateInsecureUniformIntRandomValues( + uint64_t size, uint64_t min_value, uint64_t max_value) { + AlignedVector64 values(size); + auto generator = [&]() { + return GenerateInsecureUniformIntRandomValue(min_value, max_value); + }; + std::generate(values.begin(), values.end(), generator); + return values; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_omp_out_0824_23_59.csv b/hexl_omp_out_0824_23_59.csv new file mode 100644 index 00000000..c7ede15d --- /dev/null +++ b/hexl_omp_out_0824_23_59.csv @@ -0,0 +1,64 @@ +Input Size = 4096 +Method Threads=1 Threads=2 Threads=4 Threads=6 Threads=8 +BM_EltwiseCmpAdd 0.00512657 0.00455576 0.00479712 0.00490144 0.00493344 +BM_EltwiseCmpSubMod 0.0334064 0.0227785 0.0187128 0.0185598 0.0163335 +BM_EltwiseFMAModAdd 0.0216051 0.0176496 0.0142812 0.0119878 0.0147407 +BM_EltwiseMultMod 0.0182516 0.0121492 0.0126691 0.0121201 0.178388 +BM_EltwiseReduceModInPlace 0.0443034 0.0559384 0.0470443 0.0366166 0.032561 +BM_EltwiseVectorScalarAddMod 0.00154469 0.00599568 0.00760386 0.00885226 0.0149704 +BM_EltwiseVectorVectorAddMod 0.00181662 0.00735656 0.00953703 0.01136 0.0259735 +BM_EltwiseVectorVectorSubMod 0.00220881 0.00739459 0.00978286 0.0113137 0.0199804 +BM_NTTInPlace 0.013323 0.0119631 0.0126235 0.0130003 0.0129615 + + +Input Size = 65536 +Method Threads=1 Threads=2 Threads=4 Threads=6 Threads=8 +BM_EltwiseCmpAdd 0.0695428 0.0705741 0.0740151 0.0758498 0.0763839 +BM_EltwiseCmpSubMod 0.48305 0.249345 0.126763 0.0918585 0.0757632 +BM_EltwiseFMAModAdd 0.395409 0.240749 0.159874 0.124983 0.134324 +BM_EltwiseMultMod 0.201719 0.199572 0.140779 0.151836 0.190951 +BM_EltwiseReduceModInPlace 0.501312 0.451995 0.294307 0.244569 0.231872 +BM_EltwiseVectorScalarAddMod 0.0193946 0.0514457 0.078022 0.0908455 0.120163 +BM_EltwiseVectorVectorAddMod 0.0293836 0.109481 0.131305 0.141342 0.183336 +BM_EltwiseVectorVectorSubMod 0.0293993 0.106441 0.130041 0.143904 0.216673 +BM_NTTInPlace 0.0205064 0.0207469 0.0215263 0.0219023 0.0220524 + + +Input Size = 1048576 +Method Threads=1 Threads=2 Threads=4 Threads=6 Threads=8 +BM_EltwiseCmpAdd 0.327546 0.618757 
0.573513 0.713836 0.784873 +BM_EltwiseCmpSubMod 4.2913 3.82233 2.12955 1.57278 1.45903 +BM_EltwiseFMAModAdd 3.55693 4.14268 2.16293 1.56032 1.84955 +BM_EltwiseMultMod 2.37014 2.65853 1.51409 2.05329 2.14015 +BM_EltwiseReduceModInPlace 3.02424 2.79114 1.60214 1.34462 2.15723 +BM_EltwiseVectorScalarAddMod 0.849692 0.755491 0.874564 0.916296 1.07004 +BM_EltwiseVectorVectorAddMod 1.4735 1.16771 5.17203 1.33028 5.18916 +BM_EltwiseVectorVectorSubMod 0.795287 1.04016 1.42112 1.58007 1.65453 +BM_NTTInPlace 0.196683 0.199438 0.207158 0.213862 0.215379 + + +Input Size = 16777216 +Method Threads=1 Threads=2 Threads=4 Threads=6 Threads=8 +BM_EltwiseCmpAdd 10.7087 8.32306 10.1689 9.18992 11.4777 +BM_EltwiseCmpSubMod 45.8251 54.1855 46.1839 37.9283 35.9951 +BM_EltwiseFMAModAdd 57.891 51.3104 29.8278 20.0244 16.8797 +BM_EltwiseMultMod 61.3749 54.5152 30.3268 21.6294 18.414 +BM_EltwiseReduceModInPlace 51.7744 50.7112 27.7881 20.9143 16.6882 +BM_EltwiseVectorScalarAddMod 17.3349 9.77281 6.80029 5.1626 5.9123 +BM_EltwiseVectorVectorAddMod 23.29 13.8666 9.78924 13.8197 12.3682 +BM_EltwiseVectorVectorSubMod 20.9593 13.0098 9.52602 9.52581 9.31579 +BM_NTTInPlace 4.06966 4.12731 4.31532 4.22544 4.10822 + + +Input Size = 268435456 +Method Threads=1 Threads=2 Threads=4 Threads=6 Threads=8 +BM_EltwiseCmpAdd 106.736 70.8016 50.8658 46.203 44.8108 +BM_EltwiseCmpSubMod 1092.54 1370.19 622.666 300.148 201.382 +BM_EltwiseFMAModAdd 900.52 791.323 479.051 260.392 172.116 +BM_EltwiseMultMod 1012.74 683.705 307.87 241.261 201.739 +BM_EltwiseReduceModInPlace 669.785 621.418 284.352 209.343 159.47 +BM_EltwiseVectorScalarAddMod 181.477 108.196 92.5808 69.1951 63.4043 +BM_EltwiseVectorVectorAddMod 217.719 140.504 117.334 85.744 79.2764 +BM_EltwiseVectorVectorSubMod 226.875 134.509 123.356 78.2418 74.9126 +BM_NTTInPlace 86.8708 90.5149 87.8874 91.6698 88.499 + diff --git a/hexl_ser/CMakeLists.txt b/hexl_ser/CMakeLists.txt new file mode 100644 index 00000000..2c92c389 --- /dev/null +++ b/hexl_ser/CMakeLists.txt @@ -0,0 +1,222 @@ +# Copyright (C) 2020 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +set(NATIVE_SRC + eltwise/eltwise-mult-mod.cpp + eltwise/eltwise-reduce-mod.cpp + eltwise/eltwise-sub-mod.cpp + eltwise/eltwise-add-mod.cpp + eltwise/eltwise-fma-mod.cpp + eltwise/eltwise-cmp-add.cpp + eltwise/eltwise-cmp-sub-mod.cpp + ntt/ntt-internal.cpp + ntt/ntt-radix-2.cpp + ntt/ntt-radix-4.cpp + number-theory/number-theory.cpp +) + +if (HEXL_EXPERIMENTAL) + list(APPEND NATIVE_SRC + experimental/seal/dyadic-multiply.cpp + experimental/seal/key-switch.cpp + experimental/seal/dyadic-multiply-internal.cpp + experimental/seal/key-switch-internal.cpp + experimental/misc/lr-mat-vec-mult.cpp + experimental/fft-like/fft-like.cpp + experimental/fft-like/fft-like-native.cpp + experimental/fft-like/fwd-fft-like-avx512.cpp + experimental/fft-like/inv-fft-like-avx512.cpp + ) +endif() + +if (HEXL_HAS_AVX512DQ) + set(AVX512_SRC + eltwise/eltwise-mult-mod-avx512dq.cpp + eltwise/eltwise-mult-mod-avx512ifma.cpp + eltwise/eltwise-reduce-mod-avx512.cpp + eltwise/eltwise-add-mod-avx512.cpp + eltwise/eltwise-cmp-sub-mod-avx512.cpp + eltwise/eltwise-cmp-add-avx512.cpp + eltwise/eltwise-sub-mod-avx512.cpp + eltwise/eltwise-fma-mod-avx512.cpp + ntt/fwd-ntt-avx512.cpp + ntt/inv-ntt-avx512.cpp + ) +endif() + +set(HEXL_SRC "${NATIVE_SRC};${AVX512_SRC}") + +if (HEXL_DEBUG) + list(APPEND HEXL_SRC logging/logging.cpp) +endif() + +if (HEXL_SHARED_LIB) + add_library(hexl SHARED ${HEXL_SRC}) +else() + add_library(hexl STATIC ${HEXL_SRC}) +endif() 
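+# Added note (illustrative sketch, not part of the original build logic): once
+# installed, a downstream CMake project could consume the exported target
+# roughly as follows; "my_app" and the install prefix are placeholders.
+#
+#   find_package(HEXL HINTS /path/to/install/lib/cmake)
+#   add_executable(my_app main.cpp)
+#   target_link_libraries(my_app PRIVATE HEXL::hexl)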
+add_library(HEXL::hexl ALIAS hexl) + +hexl_add_asan_flag(hexl) + +set(HEXL_DEFINES_IN_FILENAME ${CMAKE_CURRENT_SOURCE_DIR}/include/hexl/util/defines.hpp.in) +set(HEXL_DEFINES_FILENAME ${CMAKE_CURRENT_SOURCE_DIR}/include/hexl/util/defines.hpp) +configure_file(${HEXL_DEFINES_IN_FILENAME} ${HEXL_DEFINES_FILENAME}) + +set_target_properties(hexl PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(hexl PROPERTIES VERSION ${HEXL_VERSION}) +if (HEXL_DEBUG) + set_target_properties(hexl PROPERTIES OUTPUT_NAME "hexl_debug") +else() + set_target_properties(hexl PROPERTIES OUTPUT_NAME "hexl") +endif() + +target_include_directories(hexl + PRIVATE ${HEXL_SRC_ROOT_DIR} # Private headers + PUBLIC $ # Public headers + PUBLIC $ # Public headers +) +if(CpuFeatures_FOUND) + target_include_directories(hexl PUBLIC ${CpuFeatures_INCLUDE_DIR}) # Public headers +endif() + +if (HEXL_FPGA_COMPATIBILITY STREQUAL "1") + target_compile_options(hexl PRIVATE -DHEXL_FPGA_COMPATIBLE_DYADIC_MULTIPLY) +elseif (HEXL_FPGA_COMPATIBILITY STREQUAL "2") + target_compile_options(hexl PRIVATE -DHEXL_FPGA_COMPATIBLE_KEYSWITCH) +elseif (HEXL_FPGA_COMPATIBILITY STREQUAL "3") + target_compile_options(hexl PRIVATE + -DHEXL_FPGA_COMPATIBLE_DYADIC_MULTIPLY + -DHEXL_FPGA_COMPATIBLE_KEYSWITCH + ) +endif() + +if (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") + target_compile_options(hexl PRIVATE -Wall -Wconversion -Wshadow -pedantic -Wextra + -Wno-unknown-pragmas -march=native -O3 -fomit-frame-pointer + -Wno-sign-conversion + -Wno-implicit-int-conversion + ) + # Avoid 3rd-party dependency warnings when including HEXL as a dependency + target_compile_options(hexl PUBLIC + -Wno-unknown-warning + -Wno-unknown-warning-option + ) + +elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + # Inlining causes some tests to fail on MSVC with AVX512 in Release mode, HEXL_DEBUG=OFF, + # so we disable it here + target_compile_options(hexl PRIVATE /Wall /W4 /Ob0 + /wd4127 # warning C4127: conditional expression is constant; C++11 doesn't support if constexpr + /wd5105 # warning C5105: macro expansion producing 'defined' has undefined behavior + ) + target_compile_definitions(hexl PRIVATE -D_CRT_SECURE_NO_WARNINGS) +endif() + +install(DIRECTORY ${HEXL_INC_ROOT_DIR}/ + DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/ + FILES_MATCHING + PATTERN "*.hpp" + PATTERN "*.h") + +#twy +find_package(OpenMP) +if (OpenMP_CXX_FOUND) + target_link_libraries(hexl PUBLIC OpenMP::OpenMP_CXX) +endif() + +if (HEXL_SHARED_LIB) + target_link_libraries(hexl PRIVATE cpu_features) + if (HEXL_DEBUG) + target_link_libraries(hexl PUBLIC easyloggingpp) + # Manually add logging include directory + target_include_directories(hexl + PUBLIC $> + ) + endif() +else () + # For static library, if the dependencies are not found on the system, + # we manually add the dependencies for Intel HEXL in the exported library. 
+ + # Export logging only if in debug mode + if (HEXL_DEBUG) + # Manually add logging include directory + target_include_directories(hexl + PUBLIC $> + ) + if (EASYLOGGINGPP_FOUND) + target_link_libraries(hexl PRIVATE easyloggingpp) + else() + hexl_create_archive(hexl easyloggingpp) + endif() + endif() + + if (CpuFeatures_FOUND) + target_link_libraries(hexl PRIVATE cpu_features) + else() + hexl_create_archive(hexl cpu_features) + endif() + + # Manually add cpu_features include directory + target_include_directories(hexl + PRIVATE $) +endif() + +install(TARGETS hexl DESTINATION ${CMAKE_INSTALL_LIBDIR}) + +#------------------------------------------------------------------------------ +# Config export... +#------------------------------------------------------------------------------ + +# Config filenames +set(HEXL_TARGET_FILENAME ${CMAKE_CURRENT_BINARY_DIR}/cmake/hexl-${HEXL_VERSION}/HEXLTargets.cmake) +set(HEXL_CONFIG_IN_FILENAME ${HEXL_CMAKE_PATH}/HEXLConfig.cmake.in) +set(HEXL_CONFIG_FILENAME ${HEXL_ROOT_DIR}/cmake/hexl-${HEXL_VERSION}/HEXLConfig.cmake) +set(HEXL_CONFIG_VERSION_FILENAME ${CMAKE_CURRENT_BINARY_DIR}/cmake/hexl-${HEXL_VERSION}/HEXLConfigVersion.cmake) +set(HEXL_CONFIG_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR}/cmake/hexl-${HEXL_VERSION}/) + +# Create and install the CMake config and target file +install( + EXPORT HEXLTargets + NAMESPACE HEXL:: + DESTINATION ${HEXL_CONFIG_INSTALL_DIR} +) + +# Export version +write_basic_package_version_file( + ${HEXL_CONFIG_VERSION_FILENAME} + VERSION ${HEXL_VERSION} + COMPATIBILITY ExactVersion) + +include(CMakePackageConfigHelpers) + configure_package_config_file( + ${HEXL_CONFIG_IN_FILENAME} ${HEXL_CONFIG_FILENAME} + INSTALL_DESTINATION ${HEXL_CONFIG_INSTALL_DIR} + ) + +install( + TARGETS hexl + EXPORT HEXLTargets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + ) + +install(FILES ${HEXL_CONFIG_FILENAME} + ${HEXL_CONFIG_VERSION_FILENAME} + DESTINATION ${HEXL_CONFIG_INSTALL_DIR}) + +export(EXPORT HEXLTargets + FILE ${HEXL_TARGET_FILENAME}) + +# Pkgconfig +get_target_property(HEXL_TARGET_NAME hexl OUTPUT_NAME) + +configure_file(${HEXL_ROOT_DIR}/pkgconfig/hexl.pc.in + ${HEXL_ROOT_DIR}/pkgconfig/hexl.pc @ONLY) + +if(EXISTS ${HEXL_ROOT_DIR}/pkgconfig/hexl.pc) + install( + FILES ${HEXL_ROOT_DIR}/pkgconfig/hexl.pc + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) +endif() diff --git a/hexl_ser/eltwise/eltwise-add-mod-avx512.cpp b/hexl_ser/eltwise/eltwise-add-mod-avx512.cpp new file mode 100644 index 00000000..7c103ca3 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-add-mod-avx512.cpp @@ -0,0 +1,109 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-add-mod-avx512.hpp" + +#include +#include + +#include "eltwise/eltwise-add-mod-internal.hpp" +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/util/check.hpp" +#include "util/avx512-util.hpp" + +#ifdef HEXL_HAS_AVX512DQ + +namespace intel { +namespace hexl { + +void EltwiseAddModAVX512(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + 
HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-add value in operand1 exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(operand2, n, modulus, + "pre-add value in operand2 exceeds bound " << modulus); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseAddModNative(result, operand1, operand2, n_mod_8, modulus); + operand1 += n_mod_8; + operand2 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + const __m512i* vp_operand1 = reinterpret_cast(operand1); + const __m512i* vp_operand2 = reinterpret_cast(operand2); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_operand1 = _mm512_loadu_si512(vp_operand1); + __m512i v_operand2 = _mm512_loadu_si512(vp_operand2); + + __m512i v_result = + _mm512_hexl_small_add_mod_epi64(v_operand1, v_operand2, v_modulus); + + _mm512_storeu_si512(vp_result, v_result); + + ++vp_result; + ++vp_operand1; + ++vp_operand2; + } + + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +void EltwiseAddModAVX512(uint64_t* result, const uint64_t* operand1, + const uint64_t operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-add value in operand1 exceeds bound " << modulus); + HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseAddModNative(result, operand1, operand2, n_mod_8, modulus); + operand1 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + const __m512i* vp_operand1 = reinterpret_cast(operand1); + const __m512i v_operand2 = _mm512_set1_epi64(static_cast(operand2)); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_operand1 = _mm512_loadu_si512(vp_operand1); + + __m512i v_result = + _mm512_hexl_small_add_mod_epi64(v_operand1, v_operand2, v_modulus); + + _mm512_storeu_si512(vp_result, v_result); + + ++vp_result; + ++vp_operand1; + } + + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +} // namespace hexl +} // namespace intel + +#endif diff --git a/hexl_ser/eltwise/eltwise-add-mod-avx512.hpp b/hexl_ser/eltwise/eltwise-add-mod-avx512.hpp new file mode 100644 index 00000000..befb9a0e --- /dev/null +++ b/hexl_ser/eltwise/eltwise-add-mod-avx512.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +void EltwiseAddModAVX512(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +void EltwiseAddModAVX512(uint64_t* result, const uint64_t* operand1, + const uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-add-mod-internal.hpp b/hexl_ser/eltwise/eltwise-add-mod-internal.hpp new file mode 100644 index 00000000..74891811 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-add-mod-internal.hpp @@ -0,0 +1,33 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + 
+#pragma once + +namespace intel { +namespace hexl { + +/// @brief Adds two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add +/// @param[in] operand2 Vector of elements to add +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddModNative(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +/// @brief Adds a vector and scalar elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add +/// @param[in] operand2 Scalar add +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddModNative(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-add-mod.cpp b/hexl_ser/eltwise/eltwise-add-mod.cpp new file mode 100644 index 00000000..38cab458 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-add-mod.cpp @@ -0,0 +1,116 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/eltwise/eltwise-add-mod.hpp" + +#include "eltwise/eltwise-add-mod-avx512.hpp" +#include "eltwise/eltwise-add-mod-internal.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void EltwiseAddModNative(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-add value in operand1 exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(operand2, n, modulus, + "pre-add value in operand2 exceeds bound " << modulus); + + // HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < n; ++i) { + uint64_t sum = *operand1 + *operand2; + if (sum >= modulus) { + *result = sum - modulus; + } else { + *result = sum; + } + + ++operand1; + ++operand2; + ++result; + } +} + +void EltwiseAddModNative(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-add value in operand1 exceeds bound " << modulus); + HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + + uint64_t diff = modulus - operand2; + + // HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < n; ++i) { + if (*operand1 >= diff) { + *result = *operand1 - diff; + } else { + *result = *operand1 + operand2; + 
} + + ++operand1; + ++result; + } +} + +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-add value in operand1 exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(operand2, n, modulus, + "pre-add value in operand2 exceeds bound " << modulus); + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseAddModAVX512(result, operand1, operand2, n, modulus); + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseAddModNative"); + EltwiseAddModNative(result, operand1, operand2, n, modulus); +} + +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + const uint64_t operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-add value in operand1 exceeds bound " << modulus); + HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseAddModAVX512(result, operand1, operand2, n, modulus); + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseAddModNative"); + EltwiseAddModNative(result, operand1, operand2, n, modulus); +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-cmp-add-avx512.cpp b/hexl_ser/eltwise/eltwise-cmp-add-avx512.cpp new file mode 100644 index 00000000..7ead8ca2 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-cmp-add-avx512.cpp @@ -0,0 +1,49 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-cmp-add-avx512.hpp" + +#include +#include + +#include "eltwise/eltwise-cmp-add-internal.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/util.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ +void EltwiseCmpAddAVX512(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseCmpAddNative(result, operand1, n_mod_8, cmp, bound, diff); + operand1 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + __m512i v_bound = _mm512_set1_epi64(static_cast(bound)); + const __m512i* v_op_ptr = reinterpret_cast(operand1); + __m512i* v_result_ptr = reinterpret_cast<__m512i*>(result); + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op = _mm512_loadu_si512(v_op_ptr); + __m512i v_add_diff = _mm512_hexl_cmp_epi64(v_op, v_bound, cmp, diff); + v_op = _mm512_add_epi64(v_op, v_add_diff); + _mm512_storeu_si512(v_result_ptr, v_op); + + ++v_result_ptr; + ++v_op_ptr; + } +} +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-cmp-add-avx512.hpp b/hexl_ser/eltwise/eltwise-cmp-add-avx512.hpp new file mode 100644 index 00000000..4142325a --- /dev/null +++ 
b/hexl_ser/eltwise/eltwise-cmp-add-avx512.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional addition. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] cmp Comparison operation +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to conditionally add +/// @details Computes result[i] = cmp(operand1[i], bound) ? operand1[i] + +/// diff : operand1[i] for all \f$i=0, ..., n-1\f$. +void EltwiseCmpAddAVX512(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-cmp-add-internal.hpp b/hexl_ser/eltwise/eltwise-cmp-add-internal.hpp new file mode 100644 index 00000000..2a95eb5c --- /dev/null +++ b/hexl_ser/eltwise/eltwise-cmp-add-internal.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional addition. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] cmp Comparison operation +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to conditionally add +/// @details Computes result[i] = cmp(operand1[i], bound) ? operand1[i] + +/// diff : operand1[i] for all \f$i=0, ..., n-1\f$. 
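+/// Example (added for illustration): with cmp = CMPINT::NLT, bound = q and
+/// diff = d, every entry with operand1[i] >= q becomes operand1[i] + d, and
+/// all other entries are copied through unchanged.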
+void EltwiseCmpAddNative(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-cmp-add.cpp b/hexl_ser/eltwise/eltwise-cmp-add.cpp new file mode 100644 index 00000000..399f993c --- /dev/null +++ b/hexl_ser/eltwise/eltwise-cmp-add.cpp @@ -0,0 +1,109 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/eltwise/eltwise-cmp-add.hpp" + +#include "eltwise/eltwise-cmp-add-avx512.hpp" +#include "eltwise/eltwise-cmp-add-internal.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void EltwiseCmpAdd(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseCmpAddAVX512(result, operand1, n, cmp, bound, diff); + return; + } +#endif + EltwiseCmpAddNative(result, operand1, n, cmp, bound, diff); +} + +void EltwiseCmpAddNative(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + + switch (cmp) { + case CMPINT::EQ: { + for (size_t i = 0; i < n; ++i) { + if (operand1[i] == bound) { + result[i] = operand1[i] + diff; + } else { + result[i] = operand1[i]; + } + } + break; + } + case CMPINT::LT: + for (size_t i = 0; i < n; ++i) { + if (operand1[i] < bound) { + result[i] = operand1[i] + diff; + } else { + result[i] = operand1[i]; + } + } + break; + case CMPINT::LE: + for (size_t i = 0; i < n; ++i) { + if (operand1[i] <= bound) { + result[i] = operand1[i] + diff; + } else { + result[i] = operand1[i]; + } + } + break; + case CMPINT::FALSE: + for (size_t i = 0; i < n; ++i) { + result[i] = operand1[i]; + } + break; + case CMPINT::NE: + for (size_t i = 0; i < n; ++i) { + if (operand1[i] != bound) { + result[i] = operand1[i] + diff; + } else { + result[i] = operand1[i]; + } + } + break; + case CMPINT::NLT: + for (size_t i = 0; i < n; ++i) { + if (operand1[i] >= bound) { + result[i] = operand1[i] + diff; + } else { + result[i] = operand1[i]; + } + } + break; + case CMPINT::NLE: + for (size_t i = 0; i < n; ++i) { + if (operand1[i] > bound) { + result[i] = operand1[i] + diff; + } else { + result[i] = operand1[i]; + } + } + break; + case CMPINT::TRUE: + for (size_t i = 0; i < n; ++i) { + result[i] = operand1[i] + diff; + } + break; + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-cmp-sub-mod-avx512.cpp b/hexl_ser/eltwise/eltwise-cmp-sub-mod-avx512.cpp new file mode 100644 index 00000000..4cda51d3 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-cmp-sub-mod-avx512.cpp @@ -0,0 +1,39 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-cmp-sub-mod-avx512.hpp" + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional modular subtraction. 
+/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] modulus Modulus to reduce by +/// @param[in] cmp Comparison function +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to subtract by +/// @details Computes \p result[i] = (\p cmp(\p operand1, \p bound)) ? (\p +/// operand1 - \p diff) mod \p modulus : \p operand1 for all i=0, ..., n-1 + +#ifdef HEXL_HAS_AVX512DQ +template void EltwiseCmpSubModAVX512<64>(uint64_t* result, + const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, + uint64_t bound, uint64_t diff); +#endif + +#ifdef HEXL_HAS_AVX512IFMA +template void EltwiseCmpSubModAVX512<52>(uint64_t* result, + const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, + uint64_t bound, uint64_t diff); +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-cmp-sub-mod-avx512.hpp b/hexl_ser/eltwise/eltwise-cmp-sub-mod-avx512.hpp new file mode 100644 index 00000000..ff5a9421 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-cmp-sub-mod-avx512.hpp @@ -0,0 +1,87 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "eltwise/eltwise-cmp-sub-mod-internal.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ +template +void EltwiseCmpSubModAVX512(uint64_t* result, const uint64_t* operand1, + uint64_t n, uint64_t modulus, CMPINT cmp, + uint64_t bound, uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0") + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseCmpSubModNative(result, operand1, n_mod_8, modulus, cmp, bound, + diff); + operand1 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus); + + const __m512i* v_op_ptr = reinterpret_cast(operand1); + __m512i* v_result_ptr = reinterpret_cast<__m512i*>(result); + __m512i v_bound = _mm512_set1_epi64(static_cast(bound)); + __m512i v_diff = _mm512_set1_epi64(static_cast(diff)); + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + + uint64_t mu = MultiplyFactor(1, BitShift, modulus).BarrettFactor(); + __m512i v_mu = _mm512_set1_epi64(static_cast(mu)); + + // Multi-word Barrett reduction precomputation + constexpr int64_t beta = -2; + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(modulus)); + + uint64_t alpha = BitShift - 2; + uint64_t mu_64 = + MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - BitShift), BitShift, + modulus) + .BarrettFactor(); + + if (BitShift == 64) { + // Single-worded Barrett reduction. 
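+    // A rough scalar sketch of this case (the vector kernel below follows
+    // the same idea): mu_64 = floor(2^64 / modulus),
+    // q = floor((x * mu_64) / 2^64), r = x - q * modulus, and a final
+    // conditional subtraction brings r into [0, modulus).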
+ mu_64 = MultiplyFactor(1, 64, modulus).BarrettFactor(); + } + + __m512i v_mu_64 = _mm512_set1_epi64(static_cast(mu_64)); + + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op = _mm512_loadu_si512(v_op_ptr); + __mmask8 op_le_cmp = _mm512_hexl_cmp_epu64_mask(v_op, v_bound, Not(cmp)); + + v_op = _mm512_hexl_barrett_reduce64( + v_op, v_modulus, v_mu_64, v_mu, prod_right_shift, v_neg_mod); + + __m512i v_to_add = _mm512_hexl_cmp_epi64(v_op, v_diff, CMPINT::LT, modulus); + v_to_add = _mm512_sub_epi64(v_to_add, v_diff); + v_to_add = _mm512_mask_set1_epi64(v_to_add, op_le_cmp, 0); + + v_op = _mm512_add_epi64(v_op, v_to_add); + _mm512_storeu_si512(v_result_ptr, v_op); + ++v_op_ptr; + ++v_result_ptr; + } +} +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-cmp-sub-mod-internal.hpp b/hexl_ser/eltwise/eltwise-cmp-sub-mod-internal.hpp new file mode 100644 index 00000000..f988058e --- /dev/null +++ b/hexl_ser/eltwise/eltwise-cmp-sub-mod-internal.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional modular subtraction. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] modulus Modulus to reduce by +/// @param[in] cmp Comparison function +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to subtract by +/// @details Computes \p result[i] = (\p cmp(\p operand1, \p bound)) ? (\p +/// operand1 - \p diff) mod \p modulus : \p operand1 for all i=0, ..., n-1 +void EltwiseCmpSubModNative(uint64_t* result, const uint64_t* operand1, + uint64_t n, uint64_t modulus, CMPINT cmp, + uint64_t bound, uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-cmp-sub-mod.cpp b/hexl_ser/eltwise/eltwise-cmp-sub-mod.cpp new file mode 100644 index 00000000..b3526bd1 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-cmp-sub-mod.cpp @@ -0,0 +1,69 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/eltwise/eltwise-cmp-sub-mod.hpp" + +#include "eltwise/eltwise-cmp-sub-mod-avx512.hpp" +#include "eltwise/eltwise-cmp-sub-mod-internal.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/util.hpp" +#include "util/cpu-features.hpp" +#include "util/util-internal.hpp" + +namespace intel { +namespace hexl { + +void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, uint64_t bound, + uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + +#ifdef HEXL_HAS_AVX512IFMA + if (has_avx512ifma) { + if (modulus < (1ULL << 52)) { + EltwiseCmpSubModAVX512<52>(result, operand1, n, modulus, cmp, bound, + diff); + return; + } + } +#endif + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseCmpSubModAVX512<64>(result, operand1, n, modulus, cmp, bound, diff); + return; + } +#endif + EltwiseCmpSubModNative(result, operand1, n, modulus, cmp, bound, diff); + return; +} + +void EltwiseCmpSubModNative(uint64_t* result, const 
uint64_t* operand1, + uint64_t n, uint64_t modulus, CMPINT cmp, + uint64_t bound, uint64_t diff) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0") + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(diff != 0, "Require diff != 0"); + HEXL_CHECK(diff < modulus, "Diff " << diff << " >= modulus " << modulus); + + for (size_t i = 0; i < n; ++i) { + uint64_t op = operand1[i]; + bool op_cmp = Compare(cmp, op, bound); + op %= modulus; + if (op_cmp) { + op = SubUIntMod(op, diff, modulus); + } + result[i] = op; + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-fma-mod-avx512.cpp b/hexl_ser/eltwise/eltwise-fma-mod-avx512.cpp new file mode 100644 index 00000000..fa6a5453 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-fma-mod-avx512.cpp @@ -0,0 +1,156 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-fma-mod-avx512.hpp" + +#include + +#include "hexl/eltwise/eltwise-fma-mod.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512IFMA +template void EltwiseFMAModAVX512<52, 1>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +template void EltwiseFMAModAVX512<52, 2>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +template void EltwiseFMAModAVX512<52, 4>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +template void EltwiseFMAModAVX512<52, 8>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +#endif + +#ifdef HEXL_HAS_AVX512DQ +template void EltwiseFMAModAVX512<64, 1>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +template void EltwiseFMAModAVX512<64, 2>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +template void EltwiseFMAModAVX512<64, 4>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); +template void EltwiseFMAModAVX512<64, 8>(uint64_t* result, const uint64_t* arg1, + uint64_t arg2, const uint64_t* arg3, + uint64_t n, uint64_t modulus); + +#endif + +#ifdef HEXL_HAS_AVX512DQ + +/// uses Shoup's modular multiplication. 
See Algorithm 4 of +/// https://arxiv.org/pdf/2012.01968.pdf +template +void EltwiseFMAModAVX512(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus) { + HEXL_CHECK(modulus < MaximumValue(BitShift), + "Modulus " << modulus << " exceeds bit shift bound " + << MaximumValue(BitShift)); + HEXL_CHECK(modulus != 0, "Require modulus != 0"); + + HEXL_CHECK(arg1, "arg1 == nullptr"); + HEXL_CHECK(result, "result == nullptr"); + + HEXL_CHECK_BOUNDS(arg1, n, InputModFactor * modulus, + "arg1 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK_BOUNDS(&arg2, 1, InputModFactor * modulus, + "arg2 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid bitshift " << BitShift << "; need 52 or 64"); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseFMAModNative(result, arg1, arg2, arg3, n_mod_8, + modulus); + arg1 += n_mod_8; + if (arg3 != nullptr) { + arg3 += n_mod_8; + } + result += n_mod_8; + n -= n_mod_8; + } + + uint64_t twice_modulus = 2 * modulus; + uint64_t four_times_modulus = 4 * modulus; + arg2 = ReduceMod(arg2, modulus, &twice_modulus, + &four_times_modulus); + uint64_t arg2_barr = MultiplyFactor(arg2, BitShift, modulus).BarrettFactor(); + + __m512i varg2_barr = _mm512_set1_epi64(static_cast(arg2_barr)); + + __m512i vmodulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i vneg_modulus = _mm512_set1_epi64(-static_cast(modulus)); + __m512i v2_modulus = _mm512_set1_epi64(static_cast(2 * modulus)); + __m512i v4_modulus = _mm512_set1_epi64(static_cast(4 * modulus)); + const __m512i* vp_arg1 = reinterpret_cast(arg1); + __m512i varg2 = _mm512_set1_epi64(static_cast(arg2)); + varg2 = _mm512_hexl_small_mod_epu64(varg2, vmodulus, + &v2_modulus, &v4_modulus); + + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + + if (arg3) { + const __m512i* vp_arg3 = reinterpret_cast(arg3); + HEXL_LOOP_UNROLL_8 + for (size_t i = n / 8; i > 0; --i) { + __m512i varg1 = _mm512_loadu_si512(vp_arg1); + __m512i varg3 = _mm512_loadu_si512(vp_arg3); + + varg1 = _mm512_hexl_small_mod_epu64( + varg1, vmodulus, &v2_modulus, &v4_modulus); + varg3 = _mm512_hexl_small_mod_epu64( + varg3, vmodulus, &v2_modulus, &v4_modulus); + + __m512i va_times_b = _mm512_hexl_mullo_epi(varg1, varg2); + __m512i vq = _mm512_hexl_mulhi_epi(varg1, varg2_barr); + + // Compute vq in [0, 2 * p) where p is the modulus + // a * b - q * p + vq = _mm512_hexl_mullo_add_lo_epi(va_times_b, vq, vneg_modulus); + + // Add arg3, bringing vq to [0, 3 * p) + vq = _mm512_add_epi64(vq, varg3); + // Reduce to [0, p) + vq = _mm512_hexl_small_mod_epu64<4>(vq, vmodulus, &v2_modulus); + + _mm512_storeu_si512(vp_result, vq); + + ++vp_arg1; + ++vp_result; + ++vp_arg3; + } + } else { // arg3 == nullptr + HEXL_LOOP_UNROLL_8 + for (size_t i = n / 8; i > 0; --i) { + __m512i varg1 = _mm512_loadu_si512(vp_arg1); + varg1 = _mm512_hexl_small_mod_epu64( + varg1, vmodulus, &v2_modulus, &v4_modulus); + + __m512i va_times_b = _mm512_hexl_mullo_epi(varg1, varg2); + __m512i vq = _mm512_hexl_mulhi_epi(varg1, varg2_barr); + + // Compute vq in [0, 2 * p) where p is the modulus + // a * b - q * p + vq = _mm512_hexl_mullo_add_lo_epi(va_times_b, vq, vneg_modulus); + // Conditional Barrett subtraction + vq = _mm512_hexl_small_mod_epu64(vq, vmodulus); + _mm512_storeu_si512(vp_result, vq); + + ++vp_arg1; + ++vp_result; + } + } +} + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-fma-mod-avx512.hpp 
b/hexl_ser/eltwise/eltwise-fma-mod-avx512.hpp new file mode 100644 index 00000000..f0750165 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-fma-mod-avx512.hpp @@ -0,0 +1,22 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "eltwise/eltwise-fma-mod-internal.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +template +void EltwiseFMAModAVX512(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus); + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-fma-mod-internal.hpp b/hexl_ser/eltwise/eltwise-fma-mod-internal.hpp new file mode 100644 index 00000000..673ab61f --- /dev/null +++ b/hexl_ser/eltwise/eltwise-fma-mod-internal.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/number-theory/number-theory.hpp" + +namespace intel { +namespace hexl { + +template +void EltwiseFMAModNative(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus) { + uint64_t twice_modulus = 2 * modulus; + uint64_t four_times_modulus = 4 * modulus; + arg2 = ReduceMod(arg2, modulus, &twice_modulus, + &four_times_modulus); + + MultiplyFactor mf(arg2, 64, modulus); + if (arg3) { + for (size_t i = 0; i < n; ++i) { + uint64_t arg1_val = ReduceMod( + *arg1++, modulus, &twice_modulus, &four_times_modulus); + uint64_t arg3_val = ReduceMod( + *arg3++, modulus, &twice_modulus, &four_times_modulus); + + uint64_t result_val = + MultiplyMod(arg1_val, arg2, mf.BarrettFactor(), modulus); + *result = AddUIntMod(result_val, arg3_val, modulus); + result++; + } + } else { // arg3 == nullptr + for (size_t i = 0; i < n; ++i) { + uint64_t arg1_val = ReduceMod( + *arg1++, modulus, &twice_modulus, &four_times_modulus); + *result++ = MultiplyMod(arg1_val, arg2, mf.BarrettFactor(), modulus); + } + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-fma-mod.cpp b/hexl_ser/eltwise/eltwise-fma-mod.cpp new file mode 100644 index 00000000..03478fc0 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-fma-mod.cpp @@ -0,0 +1,104 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/eltwise/eltwise-fma-mod.hpp" + +#include + +#include "eltwise/eltwise-fma-mod-avx512.hpp" +#include "eltwise/eltwise-fma-mod-internal.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void EltwiseFMAMod(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(arg1 != nullptr, "Require arg1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0") + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 61), "Require modulus < (1ULL << 61)"); + HEXL_CHECK( + input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4 || + input_mod_factor == 8, + "input_mod_factor must be 1, 2, 4, or 8. 
Got " << input_mod_factor); + HEXL_CHECK( + arg2 < input_mod_factor * modulus, + "arg2 " << arg2 << " exceeds bound " << (input_mod_factor * modulus)); + + HEXL_CHECK_BOUNDS(arg1, n, input_mod_factor * modulus, + "arg1 value " << (*std::max_element(arg1, arg1 + n)) + << " in EltwiseFMAMod exceeds bound " + << (input_mod_factor * modulus)); + HEXL_CHECK(arg3 == nullptr || (*std::max_element(arg3, arg3 + n) < + (input_mod_factor * modulus)), + "arg3 value in EltwiseFMAMod exceeds bound " + << (input_mod_factor * modulus)); + +#ifdef HEXL_HAS_AVX512IFMA + if (has_avx512ifma && input_mod_factor * modulus < (1ULL << 51)) { + HEXL_VLOG(3, "Calling 52-bit EltwiseFMAModAVX512"); + + switch (input_mod_factor) { + case 1: + EltwiseFMAModAVX512<52, 1>(result, arg1, arg2, arg3, n, modulus); + break; + case 2: + EltwiseFMAModAVX512<52, 2>(result, arg1, arg2, arg3, n, modulus); + break; + case 4: + EltwiseFMAModAVX512<52, 4>(result, arg1, arg2, arg3, n, modulus); + break; + case 8: + EltwiseFMAModAVX512<52, 8>(result, arg1, arg2, arg3, n, modulus); + break; + } + return; + } +#endif + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + HEXL_VLOG(3, "Calling 64-bit EltwiseFMAModAVX512"); + + switch (input_mod_factor) { + case 1: + EltwiseFMAModAVX512<64, 1>(result, arg1, arg2, arg3, n, modulus); + break; + case 2: + EltwiseFMAModAVX512<64, 2>(result, arg1, arg2, arg3, n, modulus); + break; + case 4: + EltwiseFMAModAVX512<64, 4>(result, arg1, arg2, arg3, n, modulus); + break; + case 8: + EltwiseFMAModAVX512<64, 8>(result, arg1, arg2, arg3, n, modulus); + break; + } + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseFMAModNative"); + switch (input_mod_factor) { + case 1: + EltwiseFMAModNative<1>(result, arg1, arg2, arg3, n, modulus); + break; + case 2: + EltwiseFMAModNative<2>(result, arg1, arg2, arg3, n, modulus); + break; + case 4: + EltwiseFMAModNative<4>(result, arg1, arg2, arg3, n, modulus); + break; + case 8: + EltwiseFMAModNative<8>(result, arg1, arg2, arg3, n, modulus); + break; + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-mult-mod-avx512.hpp b/hexl_ser/eltwise/eltwise-mult-mod-avx512.hpp new file mode 100644 index 00000000..e00aa702 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-mult-mod-avx512.hpp @@ -0,0 +1,81 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "eltwise/eltwise-mult-mod-internal.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/number-theory/number-theory.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. 
+/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +/// @details Barrett's algorithm for vector-vector modular multiplication +/// (Algorithm 1 from https://hal.archives-ouvertes.fr/hal-01215845/document) +/// using AVX512IFMA +template +void EltwiseMultModAVX512IFMAInt(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. +/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +/// @details Barrett's algorithm for vector-vector modular multiplication +/// (Algorithm 1 from https://hal.archives-ouvertes.fr/hal-01215845/document) +/// using AVX512DQ +template +void EltwiseMultModAVX512DQInt(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. 
+/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +/// @details Function 18 on page 19 of https://arxiv.org/pdf/1407.3383.pdf +/// See also Algorithm 2/3 of +/// https://hal.archives-ouvertes.fr/hal-02552673/document +/// Uses floating-point arithmetic +template +void EltwiseMultModAVX512Float(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-mult-mod-avx512dq.cpp b/hexl_ser/eltwise/eltwise-mult-mod-avx512dq.cpp new file mode 100644 index 00000000..cbf0ec0f --- /dev/null +++ b/hexl_ser/eltwise/eltwise-mult-mod-avx512dq.cpp @@ -0,0 +1,838 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include + +#include "eltwise/eltwise-mult-mod-avx512.hpp" +#include "eltwise/eltwise-mult-mod-internal.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" +#include "hexl/util/defines.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +template void EltwiseMultModAVX512Float<1>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); +template void EltwiseMultModAVX512Float<2>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); +template void EltwiseMultModAVX512Float<4>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +template void EltwiseMultModAVX512DQInt<1>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); +template void EltwiseMultModAVX512DQInt<2>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); +template void EltwiseMultModAVX512DQInt<4>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +#endif + +#ifdef HEXL_HAS_AVX512DQ + +template +void EltwiseMultModAVX512DQIntLoopUnroll(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, + __m512i v_twice_mod) { + constexpr size_t manual_unroll_factor = 16; + constexpr size_t avx512_64bit_count = 8; + constexpr size_t loop_count = + CoeffCount / (manual_unroll_factor * avx512_64bit_count); + + static_assert(loop_count > 0, "loop_count too small for unrolling"); + static_assert(CoeffCount % (manual_unroll_factor * avx512_64bit_count) == 0, + "CoeffCount must be a factor of manual_unroll_factor * " + "avx512_64bit_count"); + + HEXL_UNUSED(v_twice_mod); + HEXL_LOOP_UNROLL_4 + for (size_t i = loop_count; i > 0; --i) { + __m512i x1 = _mm512_loadu_si512(vp_operand1++); + __m512i y1 = _mm512_loadu_si512(vp_operand2++); + __m512i x2 = _mm512_loadu_si512(vp_operand1++); + __m512i y2 = _mm512_loadu_si512(vp_operand2++); + __m512i x3 = _mm512_loadu_si512(vp_operand1++); + __m512i y3 = _mm512_loadu_si512(vp_operand2++); + __m512i x4 = _mm512_loadu_si512(vp_operand1++); + __m512i y4 = _mm512_loadu_si512(vp_operand2++); + __m512i x5 = _mm512_loadu_si512(vp_operand1++); + __m512i y5 = _mm512_loadu_si512(vp_operand2++); + __m512i x6 = _mm512_loadu_si512(vp_operand1++); + __m512i y6 = 
_mm512_loadu_si512(vp_operand2++); + __m512i x7 = _mm512_loadu_si512(vp_operand1++); + __m512i y7 = _mm512_loadu_si512(vp_operand2++); + __m512i x8 = _mm512_loadu_si512(vp_operand1++); + __m512i y8 = _mm512_loadu_si512(vp_operand2++); + __m512i x9 = _mm512_loadu_si512(vp_operand1++); + __m512i y9 = _mm512_loadu_si512(vp_operand2++); + __m512i x10 = _mm512_loadu_si512(vp_operand1++); + __m512i y10 = _mm512_loadu_si512(vp_operand2++); + __m512i x11 = _mm512_loadu_si512(vp_operand1++); + __m512i y11 = _mm512_loadu_si512(vp_operand2++); + __m512i x12 = _mm512_loadu_si512(vp_operand1++); + __m512i y12 = _mm512_loadu_si512(vp_operand2++); + __m512i x13 = _mm512_loadu_si512(vp_operand1++); + __m512i y13 = _mm512_loadu_si512(vp_operand2++); + __m512i x14 = _mm512_loadu_si512(vp_operand1++); + __m512i y14 = _mm512_loadu_si512(vp_operand2++); + __m512i x15 = _mm512_loadu_si512(vp_operand1++); + __m512i y15 = _mm512_loadu_si512(vp_operand2++); + __m512i x16 = _mm512_loadu_si512(vp_operand1++); + __m512i y16 = _mm512_loadu_si512(vp_operand2++); + + x1 = _mm512_hexl_small_mod_epu64(x1, v_modulus, + &v_twice_mod); + x2 = _mm512_hexl_small_mod_epu64(x2, v_modulus, + &v_twice_mod); + x3 = _mm512_hexl_small_mod_epu64(x3, v_modulus, + &v_twice_mod); + x4 = _mm512_hexl_small_mod_epu64(x4, v_modulus, + &v_twice_mod); + x5 = _mm512_hexl_small_mod_epu64(x5, v_modulus, + &v_twice_mod); + x6 = _mm512_hexl_small_mod_epu64(x6, v_modulus, + &v_twice_mod); + x7 = _mm512_hexl_small_mod_epu64(x7, v_modulus, + &v_twice_mod); + x8 = _mm512_hexl_small_mod_epu64(x8, v_modulus, + &v_twice_mod); + x9 = _mm512_hexl_small_mod_epu64(x9, v_modulus, + &v_twice_mod); + x10 = _mm512_hexl_small_mod_epu64(x10, v_modulus, + &v_twice_mod); + x11 = _mm512_hexl_small_mod_epu64(x11, v_modulus, + &v_twice_mod); + x12 = _mm512_hexl_small_mod_epu64(x12, v_modulus, + &v_twice_mod); + x13 = _mm512_hexl_small_mod_epu64(x13, v_modulus, + &v_twice_mod); + x14 = _mm512_hexl_small_mod_epu64(x14, v_modulus, + &v_twice_mod); + x15 = _mm512_hexl_small_mod_epu64(x15, v_modulus, + &v_twice_mod); + x16 = _mm512_hexl_small_mod_epu64(x16, v_modulus, + &v_twice_mod); + + y1 = _mm512_hexl_small_mod_epu64(y1, v_modulus, + &v_twice_mod); + y2 = _mm512_hexl_small_mod_epu64(y2, v_modulus, + &v_twice_mod); + y3 = _mm512_hexl_small_mod_epu64(y3, v_modulus, + &v_twice_mod); + y4 = _mm512_hexl_small_mod_epu64(y4, v_modulus, + &v_twice_mod); + y5 = _mm512_hexl_small_mod_epu64(y5, v_modulus, + &v_twice_mod); + y6 = _mm512_hexl_small_mod_epu64(y6, v_modulus, + &v_twice_mod); + y7 = _mm512_hexl_small_mod_epu64(y7, v_modulus, + &v_twice_mod); + y8 = _mm512_hexl_small_mod_epu64(y8, v_modulus, + &v_twice_mod); + y9 = _mm512_hexl_small_mod_epu64(y9, v_modulus, + &v_twice_mod); + y10 = _mm512_hexl_small_mod_epu64(y10, v_modulus, + &v_twice_mod); + y11 = _mm512_hexl_small_mod_epu64(y11, v_modulus, + &v_twice_mod); + y12 = _mm512_hexl_small_mod_epu64(y12, v_modulus, + &v_twice_mod); + y13 = _mm512_hexl_small_mod_epu64(y13, v_modulus, + &v_twice_mod); + y14 = _mm512_hexl_small_mod_epu64(y14, v_modulus, + &v_twice_mod); + y15 = _mm512_hexl_small_mod_epu64(y15, v_modulus, + &v_twice_mod); + y16 = _mm512_hexl_small_mod_epu64(y16, v_modulus, + &v_twice_mod); + + __m512i zhi1 = _mm512_hexl_mulhi_epi<64>(x1, y1); + __m512i zhi2 = _mm512_hexl_mulhi_epi<64>(x2, y2); + __m512i zhi3 = _mm512_hexl_mulhi_epi<64>(x3, y3); + __m512i zhi4 = _mm512_hexl_mulhi_epi<64>(x4, y4); + __m512i zhi5 = _mm512_hexl_mulhi_epi<64>(x5, y5); + __m512i zhi6 = _mm512_hexl_mulhi_epi<64>(x6, y6); + __m512i zhi7 
= _mm512_hexl_mulhi_epi<64>(x7, y7); + __m512i zhi8 = _mm512_hexl_mulhi_epi<64>(x8, y8); + __m512i zhi9 = _mm512_hexl_mulhi_epi<64>(x9, y9); + __m512i zhi10 = _mm512_hexl_mulhi_epi<64>(x10, y10); + __m512i zhi11 = _mm512_hexl_mulhi_epi<64>(x11, y11); + __m512i zhi12 = _mm512_hexl_mulhi_epi<64>(x12, y12); + __m512i zhi13 = _mm512_hexl_mulhi_epi<64>(x13, y13); + __m512i zhi14 = _mm512_hexl_mulhi_epi<64>(x14, y14); + __m512i zhi15 = _mm512_hexl_mulhi_epi<64>(x15, y15); + __m512i zhi16 = _mm512_hexl_mulhi_epi<64>(x16, y16); + + __m512i zlo1 = _mm512_hexl_mullo_epi<64>(x1, y1); + __m512i zlo2 = _mm512_hexl_mullo_epi<64>(x2, y2); + __m512i zlo3 = _mm512_hexl_mullo_epi<64>(x3, y3); + __m512i zlo4 = _mm512_hexl_mullo_epi<64>(x4, y4); + __m512i zlo5 = _mm512_hexl_mullo_epi<64>(x5, y5); + __m512i zlo6 = _mm512_hexl_mullo_epi<64>(x6, y6); + __m512i zlo7 = _mm512_hexl_mullo_epi<64>(x7, y7); + __m512i zlo8 = _mm512_hexl_mullo_epi<64>(x8, y8); + __m512i zlo9 = _mm512_hexl_mullo_epi<64>(x9, y9); + __m512i zlo10 = _mm512_hexl_mullo_epi<64>(x10, y10); + __m512i zlo11 = _mm512_hexl_mullo_epi<64>(x11, y11); + __m512i zlo12 = _mm512_hexl_mullo_epi<64>(x12, y12); + __m512i zlo13 = _mm512_hexl_mullo_epi<64>(x13, y13); + __m512i zlo14 = _mm512_hexl_mullo_epi<64>(x14, y14); + __m512i zlo15 = _mm512_hexl_mullo_epi<64>(x15, y15); + __m512i zlo16 = _mm512_hexl_mullo_epi<64>(x16, y16); + + __m512i c1 = _mm512_hexl_shrdi_epi64(zlo1, zhi1); + __m512i c2 = _mm512_hexl_shrdi_epi64(zlo2, zhi2); + __m512i c3 = _mm512_hexl_shrdi_epi64(zlo3, zhi3); + __m512i c4 = _mm512_hexl_shrdi_epi64(zlo4, zhi4); + __m512i c5 = _mm512_hexl_shrdi_epi64(zlo5, zhi5); + __m512i c6 = _mm512_hexl_shrdi_epi64(zlo6, zhi6); + __m512i c7 = _mm512_hexl_shrdi_epi64(zlo7, zhi7); + __m512i c8 = _mm512_hexl_shrdi_epi64(zlo8, zhi8); + __m512i c9 = _mm512_hexl_shrdi_epi64(zlo9, zhi9); + __m512i c10 = _mm512_hexl_shrdi_epi64(zlo10, zhi10); + __m512i c11 = _mm512_hexl_shrdi_epi64(zlo11, zhi11); + __m512i c12 = _mm512_hexl_shrdi_epi64(zlo12, zhi12); + __m512i c13 = _mm512_hexl_shrdi_epi64(zlo13, zhi13); + __m512i c14 = _mm512_hexl_shrdi_epi64(zlo14, zhi14); + __m512i c15 = _mm512_hexl_shrdi_epi64(zlo15, zhi15); + __m512i c16 = _mm512_hexl_shrdi_epi64(zlo16, zhi16); + + c1 = _mm512_hexl_mulhi_approx_epi<64>(c1, v_barr_lo); + c2 = _mm512_hexl_mulhi_approx_epi<64>(c2, v_barr_lo); + c3 = _mm512_hexl_mulhi_approx_epi<64>(c3, v_barr_lo); + c4 = _mm512_hexl_mulhi_approx_epi<64>(c4, v_barr_lo); + c5 = _mm512_hexl_mulhi_approx_epi<64>(c5, v_barr_lo); + c6 = _mm512_hexl_mulhi_approx_epi<64>(c6, v_barr_lo); + c7 = _mm512_hexl_mulhi_approx_epi<64>(c7, v_barr_lo); + c8 = _mm512_hexl_mulhi_approx_epi<64>(c8, v_barr_lo); + c9 = _mm512_hexl_mulhi_approx_epi<64>(c9, v_barr_lo); + c10 = _mm512_hexl_mulhi_approx_epi<64>(c10, v_barr_lo); + c11 = _mm512_hexl_mulhi_approx_epi<64>(c11, v_barr_lo); + c12 = _mm512_hexl_mulhi_approx_epi<64>(c12, v_barr_lo); + c13 = _mm512_hexl_mulhi_approx_epi<64>(c13, v_barr_lo); + c14 = _mm512_hexl_mulhi_approx_epi<64>(c14, v_barr_lo); + c15 = _mm512_hexl_mulhi_approx_epi<64>(c15, v_barr_lo); + c16 = _mm512_hexl_mulhi_approx_epi<64>(c16, v_barr_lo); + + __m512i vr1 = _mm512_hexl_mullo_epi<64>(c1, v_modulus); + __m512i vr2 = _mm512_hexl_mullo_epi<64>(c2, v_modulus); + __m512i vr3 = _mm512_hexl_mullo_epi<64>(c3, v_modulus); + __m512i vr4 = _mm512_hexl_mullo_epi<64>(c4, v_modulus); + __m512i vr5 = _mm512_hexl_mullo_epi<64>(c5, v_modulus); + __m512i vr6 = _mm512_hexl_mullo_epi<64>(c6, v_modulus); + __m512i vr7 = _mm512_hexl_mullo_epi<64>(c7, 
v_modulus); + __m512i vr8 = _mm512_hexl_mullo_epi<64>(c8, v_modulus); + __m512i vr9 = _mm512_hexl_mullo_epi<64>(c9, v_modulus); + __m512i vr10 = _mm512_hexl_mullo_epi<64>(c10, v_modulus); + __m512i vr11 = _mm512_hexl_mullo_epi<64>(c11, v_modulus); + __m512i vr12 = _mm512_hexl_mullo_epi<64>(c12, v_modulus); + __m512i vr13 = _mm512_hexl_mullo_epi<64>(c13, v_modulus); + __m512i vr14 = _mm512_hexl_mullo_epi<64>(c14, v_modulus); + __m512i vr15 = _mm512_hexl_mullo_epi<64>(c15, v_modulus); + __m512i vr16 = _mm512_hexl_mullo_epi<64>(c16, v_modulus); + + vr1 = _mm512_sub_epi64(zlo1, vr1); + vr2 = _mm512_sub_epi64(zlo2, vr2); + vr3 = _mm512_sub_epi64(zlo3, vr3); + vr4 = _mm512_sub_epi64(zlo4, vr4); + vr5 = _mm512_sub_epi64(zlo5, vr5); + vr6 = _mm512_sub_epi64(zlo6, vr6); + vr7 = _mm512_sub_epi64(zlo7, vr7); + vr8 = _mm512_sub_epi64(zlo8, vr8); + vr9 = _mm512_sub_epi64(zlo9, vr9); + vr10 = _mm512_sub_epi64(zlo10, vr10); + vr11 = _mm512_sub_epi64(zlo11, vr11); + vr12 = _mm512_sub_epi64(zlo12, vr12); + vr13 = _mm512_sub_epi64(zlo13, vr13); + vr14 = _mm512_sub_epi64(zlo14, vr14); + vr15 = _mm512_sub_epi64(zlo15, vr15); + vr16 = _mm512_sub_epi64(zlo16, vr16); + + vr1 = _mm512_hexl_small_mod_epu64<4>(vr1, v_modulus, &v_twice_mod); + vr2 = _mm512_hexl_small_mod_epu64<4>(vr2, v_modulus, &v_twice_mod); + vr3 = _mm512_hexl_small_mod_epu64<4>(vr3, v_modulus, &v_twice_mod); + vr4 = _mm512_hexl_small_mod_epu64<4>(vr4, v_modulus, &v_twice_mod); + vr5 = _mm512_hexl_small_mod_epu64<4>(vr5, v_modulus, &v_twice_mod); + vr6 = _mm512_hexl_small_mod_epu64<4>(vr6, v_modulus, &v_twice_mod); + vr7 = _mm512_hexl_small_mod_epu64<4>(vr7, v_modulus, &v_twice_mod); + vr8 = _mm512_hexl_small_mod_epu64<4>(vr8, v_modulus, &v_twice_mod); + vr9 = _mm512_hexl_small_mod_epu64<4>(vr9, v_modulus, &v_twice_mod); + vr10 = _mm512_hexl_small_mod_epu64<4>(vr10, v_modulus, &v_twice_mod); + vr11 = _mm512_hexl_small_mod_epu64<4>(vr11, v_modulus, &v_twice_mod); + vr12 = _mm512_hexl_small_mod_epu64<4>(vr12, v_modulus, &v_twice_mod); + vr13 = _mm512_hexl_small_mod_epu64<4>(vr13, v_modulus, &v_twice_mod); + vr14 = _mm512_hexl_small_mod_epu64<4>(vr14, v_modulus, &v_twice_mod); + vr15 = _mm512_hexl_small_mod_epu64<4>(vr15, v_modulus, &v_twice_mod); + vr16 = _mm512_hexl_small_mod_epu64<4>(vr16, v_modulus, &v_twice_mod); + + _mm512_storeu_si512(vp_result++, vr1); + _mm512_storeu_si512(vp_result++, vr2); + _mm512_storeu_si512(vp_result++, vr3); + _mm512_storeu_si512(vp_result++, vr4); + _mm512_storeu_si512(vp_result++, vr5); + _mm512_storeu_si512(vp_result++, vr6); + _mm512_storeu_si512(vp_result++, vr7); + _mm512_storeu_si512(vp_result++, vr8); + _mm512_storeu_si512(vp_result++, vr9); + _mm512_storeu_si512(vp_result++, vr10); + _mm512_storeu_si512(vp_result++, vr11); + _mm512_storeu_si512(vp_result++, vr12); + _mm512_storeu_si512(vp_result++, vr13); + _mm512_storeu_si512(vp_result++, vr14); + _mm512_storeu_si512(vp_result++, vr15); + _mm512_storeu_si512(vp_result++, vr16); + } +} + +/// @brief Algorithm 2 from +/// https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModAVX512DQIntLoopDefault(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, + __m512i v_twice_mod, uint64_t n) { + HEXL_UNUSED(v_twice_mod); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op1 = _mm512_loadu_si512(vp_operand1); + __m512i v_op2 = _mm512_loadu_si512(vp_operand2); + + v_op1 = _mm512_hexl_small_mod_epu64(v_op1, v_modulus, + &v_twice_mod); + + v_op2 
= _mm512_hexl_small_mod_epu64(v_op2, v_modulus, + &v_twice_mod); + + // Compute product U + __m512i v_prod_hi = _mm512_hexl_mulhi_epi<64>(v_op1, v_op2); + __m512i v_prod_lo = _mm512_hexl_mullo_epi<64>(v_op1, v_op2); + + __m512i c1 = _mm512_hexl_shrdi_epi64(v_prod_lo, v_prod_hi); + // alpha - beta == 64, so we only need high 64 bits + // Perform approximate computation of high bits, as described on page + // 7 of https://arxiv.org/pdf/2003.04510.pdf + __m512i q_hat = _mm512_hexl_mulhi_approx_epi<64>(c1, v_barr_lo); + __m512i v_result = _mm512_hexl_mullo_epi<64>(q_hat, v_modulus); + // Computes result in [0, 4q) + v_result = _mm512_sub_epi64(v_prod_lo, v_result); + + // Reduce result to [0, q) + v_result = + _mm512_hexl_small_mod_epu64<4>(v_result, v_modulus, &v_twice_mod); + _mm512_storeu_si512(vp_result, v_result); + + ++vp_operand1; + ++vp_operand2; + ++vp_result; + } +} + +/// @brief Algorithm 2 from +/// https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModAVX512DQIntLoopDefault(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, + __m512i v_twice_mod, uint64_t n, + uint64_t prod_right_shift) { + HEXL_UNUSED(v_twice_mod); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op1 = _mm512_loadu_si512(vp_operand1); + __m512i v_op2 = _mm512_loadu_si512(vp_operand2); + + v_op1 = _mm512_hexl_small_mod_epu64(v_op1, v_modulus, + &v_twice_mod); + + v_op2 = _mm512_hexl_small_mod_epu64(v_op2, v_modulus, + &v_twice_mod); + + __m512i v_prod_hi = _mm512_hexl_mulhi_epi<64>(v_op1, v_op2); + __m512i v_prod_lo = _mm512_hexl_mullo_epi<64>(v_op1, v_op2); + + // c1 = floor(U / 2^{n + beta}) + __m512i c1 = _mm512_hexl_shrdi_epi64( + v_prod_lo, v_prod_hi, static_cast(prod_right_shift)); + + // alpha - beta == 64, so we only need high 64 bits + // Perform approximate computation of high bits, as described on page + // 7 of https://arxiv.org/pdf/2003.04510.pdf + __m512i q_hat = _mm512_hexl_mulhi_approx_epi<64>(c1, v_barr_lo); + __m512i v_result = _mm512_hexl_mullo_epi<64>(q_hat, v_modulus); + // Computes result in [0, 4q) + v_result = _mm512_sub_epi64(v_prod_lo, v_result); + + // Reduce result to [0, q) + v_result = + _mm512_hexl_small_mod_epu64<4>(v_result, v_modulus, &v_twice_mod); + _mm512_storeu_si512(vp_result, v_result); + + ++vp_operand1; + ++vp_operand2; + ++vp_result; + } +} + +template +void EltwiseMultModAVX512DQIntLoop(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, + __m512i v_twice_mod, uint64_t n) { + switch (n) { + case 1024: + EltwiseMultModAVX512DQIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_twice_mod); + break; + + case 2048: + EltwiseMultModAVX512DQIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_twice_mod); + break; + + case 4096: + EltwiseMultModAVX512DQIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_twice_mod); + break; + + case 8192: + EltwiseMultModAVX512DQIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_twice_mod); + break; + + case 16384: + EltwiseMultModAVX512DQIntLoopUnroll(vp_result, vp_operand1, + vp_operand2, v_barr_lo, + v_modulus, v_twice_mod); + break; + + case 32768: + EltwiseMultModAVX512DQIntLoopUnroll(vp_result, vp_operand1, + vp_operand2, v_barr_lo, + v_modulus, v_twice_mod); + break; + + default: + EltwiseMultModAVX512DQIntLoopDefault( + vp_result, 
vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_twice_mod, n); + } +} + +#define ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(ProdRightShift, \ + InputModFactor) \ + case (ProdRightShift): { \ + EltwiseMultModAVX512DQIntLoop<(ProdRightShift), (InputModFactor)>( \ + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, \ + v_twice_mod, n); \ + break; \ + } + +// Algorithm 2 from https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModAVX512DQInt(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || InputModFactor == 4, + "Require InputModFactor = 1, 2, or 4") + HEXL_CHECK(InputModFactor * modulus > (1ULL << 50), + "Require InputModFactor * modulus > (1ULL << 50)") + HEXL_CHECK(InputModFactor * modulus < (1ULL << 63), + "Require InputModFactor * modulus < (1ULL << 63)"); + HEXL_CHECK(modulus < (1ULL << 62), "Require modulus < (1ULL << 62)"); + HEXL_CHECK_BOUNDS(operand1, n, InputModFactor * modulus, + "operand1 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK_BOUNDS(operand2, n, InputModFactor * modulus, + "operand2 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseMultModNative(result, operand1, operand2, n_mod_8, + modulus); + operand1 += n_mod_8; + operand2 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + constexpr int64_t beta = -2; + HEXL_CHECK(beta <= -2, "beta must be <= -2 for correctness"); + constexpr int64_t alpha = 62; // ensures alpha - beta = 64 + uint64_t gamma = Log2(InputModFactor); + HEXL_UNUSED(gamma); + HEXL_CHECK(alpha >= gamma + 1, "alpha must be >= gamma + 1 for correctness"); + + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + + // Barrett factor "mu" + // TODO(fboemer): Allow MultiplyFactor to take bit shifts != 64 + HEXL_CHECK(ceil_log_mod + alpha >= 64, "ceil_log_mod + alpha < 64"); + uint64_t barr_lo = + MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - 64), 64, modulus) + .BarrettFactor(); + + __m512i v_barr_lo = _mm512_set1_epi64(static_cast(barr_lo)); + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(2 * modulus)); + const __m512i* vp_operand1 = reinterpret_cast(operand1); + const __m512i* vp_operand2 = reinterpret_cast(operand2); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + + // Let d be the product operand1 * operand2. + // To ensure d >> prod_right_shift < (1ULL << 64), we need + // (input_mod_factor * modulus)^2 >> (prod_right_shift) < (1ULL << 64) + // This happens when 2*log_2(input_mod_factor) + prod_right_shift - beta < 63 + // If not, we need to reduce the inputs to be less than modulus for + // correctness. This is less efficient, so we avoid it when possible. + bool reduce_mod = 2 * Log2(InputModFactor) + prod_right_shift - beta >= 63; + + if (reduce_mod) { + // Here, we assume beta = -2 + HEXL_CHECK(beta == -2, "beta != -2 may skip some cases"); + // This reduce_mod case happens only when + // prod_right_shift >= 63 - 2 * log2(input_mod_factor) >= 57. + // Additionally, modulus < (1ULL << 62) implies + // prod_right_shift <= 61. So N == 57, 58, 59, 60, 61 are the + // only cases here. 
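+    // For instance (hypothetical parameters): InputModFactor = 4 with a
+    // 60-bit modulus gives ceil_log_mod = 60 and prod_right_shift = 58, so
+    // 2 * Log2(4) + 58 - beta = 64 >= 63 and this reduce_mod path is taken.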
+ switch (prod_right_shift) { + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(57, InputModFactor) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(58, InputModFactor) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(59, InputModFactor) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(60, InputModFactor) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(61, InputModFactor) + default: { + HEXL_CHECK(false, + "Bad value for prod_right_shift: " << prod_right_shift); + } + } + } else { // Input mod reduction not required; pass InputModFactor == 1. + // The template arguments are required for use of _mm512_hexl_shrdi_epi64, + // which requires a compile-time constant for the shift. + switch (prod_right_shift) { + // For prod_right_shift < 50, we should prefer EltwiseMultModAVX512Float + // or EltwiseMultModAVX512IFMAInt, so we don't generate those special + // cases here + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(50, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(51, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(52, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(53, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(54, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(55, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(56, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(57, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(58, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(59, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(60, 1) + ELTWISE_MULT_MOD_AVX512_DQ_INT_PROD_RIGHT_SHIFT_CASE(61, 1) + default: { + HEXL_VLOG(2, "calling EltwiseMultModAVX512DQIntLoopDefault"); + EltwiseMultModAVX512DQIntLoopDefault<1>( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_twice_mod, n, prod_right_shift); + } + } + } + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +// From Function 18, page 19 of https://arxiv.org/pdf/1407.3383.pdf +// See also Algorithm 2/3 of +// https://hal.archives-ouvertes.fr/hal-02552673/document +template +inline void EltwiseMultModAVX512FloatLoopDefault( + __m512i* vp_result, const __m512i* vp_operand1, const __m512i* vp_operand2, + __m512d v_u, __m512d v_p, __m512i v_modulus, __m512i v_twice_mod, + uint64_t n) { + HEXL_UNUSED(v_twice_mod); + + constexpr int round_mode = (_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op1 = _mm512_loadu_si512(vp_operand1); + v_op1 = _mm512_hexl_small_mod_epu64(v_op1, v_modulus, + &v_twice_mod); + + __m512i v_op2 = _mm512_loadu_si512(vp_operand2); + v_op2 = _mm512_hexl_small_mod_epu64(v_op2, v_modulus, + &v_twice_mod); + + __m512d v_x = _mm512_cvt_roundepu64_pd(v_op1, round_mode); + __m512d v_y = _mm512_cvt_roundepu64_pd(v_op2, round_mode); + + __m512d v_h = _mm512_mul_pd(v_x, v_y); + __m512d v_l = + _mm512_fmsub_pd(v_x, v_y, v_h); // rounding error; h + l == x * y + __m512d v_b = _mm512_mul_pd(v_h, v_u); // ~ (x * y) / p + __m512d v_c = _mm512_floor_pd(v_b); // ~ floor(x * y / p) + __m512d v_d = _mm512_fnmadd_pd(v_c, v_p, v_h); + __m512d v_g = _mm512_add_pd(v_d, v_l); + __mmask8 m = _mm512_cmp_pd_mask(v_g, _mm512_setzero_pd(), _CMP_LT_OQ); + v_g = _mm512_mask_add_pd(v_g, m, v_g, v_p); + + __m512i v_result = _mm512_cvt_roundpd_epu64(v_g, round_mode); + + _mm512_storeu_si512(vp_result, v_result); + + ++vp_operand1; + ++vp_operand2; + ++vp_result; + } +} + +template +inline void 
EltwiseMultModAVX512FloatLoopUnroll( + __m512i* vp_result, const __m512i* vp_operand1, const __m512i* vp_operand2, + __m512d v_u, __m512d v_p, __m512i v_modulus, __m512i v_twice_mod) { + constexpr size_t manual_unroll_factor = 4; + constexpr size_t avx512_64bit_count = 8; + constexpr size_t loop_count = + CoeffCount / (manual_unroll_factor * avx512_64bit_count); + + static_assert(loop_count > 0, "loop_count too small for unrolling"); + static_assert(CoeffCount % (manual_unroll_factor * avx512_64bit_count) == 0, + "CoeffCount must be a factor of manual_unroll_factor * " + "avx512_64bit_count"); + + constexpr int round_mode = (_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC); + + HEXL_LOOP_UNROLL_4 + for (size_t i = loop_count; i > 0; --i) { + __m512i op1_1 = _mm512_loadu_si512(vp_operand1++); + __m512i op1_2 = _mm512_loadu_si512(vp_operand1++); + __m512i op1_3 = _mm512_loadu_si512(vp_operand1++); + __m512i op1_4 = _mm512_loadu_si512(vp_operand1++); + + __m512i op2_1 = _mm512_loadu_si512(vp_operand2++); + __m512i op2_2 = _mm512_loadu_si512(vp_operand2++); + __m512i op2_3 = _mm512_loadu_si512(vp_operand2++); + __m512i op2_4 = _mm512_loadu_si512(vp_operand2++); + + op1_1 = _mm512_hexl_small_mod_epu64(op1_1, v_modulus, + &v_twice_mod); + op1_2 = _mm512_hexl_small_mod_epu64(op1_2, v_modulus, + &v_twice_mod); + op1_3 = _mm512_hexl_small_mod_epu64(op1_3, v_modulus, + &v_twice_mod); + op1_4 = _mm512_hexl_small_mod_epu64(op1_4, v_modulus, + &v_twice_mod); + + op2_1 = _mm512_hexl_small_mod_epu64(op2_1, v_modulus, + &v_twice_mod); + op2_2 = _mm512_hexl_small_mod_epu64(op2_2, v_modulus, + &v_twice_mod); + op2_3 = _mm512_hexl_small_mod_epu64(op2_3, v_modulus, + &v_twice_mod); + op2_4 = _mm512_hexl_small_mod_epu64(op2_4, v_modulus, + &v_twice_mod); + + __m512d v_x_1 = _mm512_cvt_roundepu64_pd(op1_1, round_mode); + __m512d v_x_2 = _mm512_cvt_roundepu64_pd(op1_2, round_mode); + __m512d v_x_3 = _mm512_cvt_roundepu64_pd(op1_3, round_mode); + __m512d v_x_4 = _mm512_cvt_roundepu64_pd(op1_4, round_mode); + + __m512d v_y_1 = _mm512_cvt_roundepu64_pd(op2_1, round_mode); + __m512d v_y_2 = _mm512_cvt_roundepu64_pd(op2_2, round_mode); + __m512d v_y_3 = _mm512_cvt_roundepu64_pd(op2_3, round_mode); + __m512d v_y_4 = _mm512_cvt_roundepu64_pd(op2_4, round_mode); + + __m512d v_h_1 = _mm512_mul_pd(v_x_1, v_y_1); + __m512d v_h_2 = _mm512_mul_pd(v_x_2, v_y_2); + __m512d v_h_3 = _mm512_mul_pd(v_x_3, v_y_3); + __m512d v_h_4 = _mm512_mul_pd(v_x_4, v_y_4); + + // ~ (x * y) / p + __m512d v_b_1 = _mm512_mul_pd(v_h_1, v_u); + __m512d v_b_2 = _mm512_mul_pd(v_h_2, v_u); + __m512d v_b_3 = _mm512_mul_pd(v_h_3, v_u); + __m512d v_b_4 = _mm512_mul_pd(v_h_4, v_u); + + // rounding_ error; h + l == x * y + __m512d v_l_1 = _mm512_fmsub_pd(v_x_1, v_y_1, v_h_1); + __m512d v_l_2 = _mm512_fmsub_pd(v_x_2, v_y_2, v_h_2); + __m512d v_l_3 = _mm512_fmsub_pd(v_x_3, v_y_3, v_h_3); + __m512d v_l_4 = _mm512_fmsub_pd(v_x_4, v_y_4, v_h_4); + + // ~ floor(_x * y / p) + __m512d v_c_1 = _mm512_floor_pd(v_b_1); + __m512d v_c_2 = _mm512_floor_pd(v_b_2); + __m512d v_c_3 = _mm512_floor_pd(v_b_3); + __m512d v_c_4 = _mm512_floor_pd(v_b_4); + + __m512d v_d_1 = _mm512_fnmadd_pd(v_c_1, v_p, v_h_1); + __m512d v_d_2 = _mm512_fnmadd_pd(v_c_2, v_p, v_h_2); + __m512d v_d_3 = _mm512_fnmadd_pd(v_c_3, v_p, v_h_3); + __m512d v_d_4 = _mm512_fnmadd_pd(v_c_4, v_p, v_h_4); + + __m512d v_g_1 = _mm512_add_pd(v_d_1, v_l_1); + __m512d v_g_2 = _mm512_add_pd(v_d_2, v_l_2); + __m512d v_g_3 = _mm512_add_pd(v_d_3, v_l_3); + __m512d v_g_4 = _mm512_add_pd(v_d_4, v_l_4); + + __mmask8 m_1 = 
_mm512_cmp_pd_mask(v_g_1, _mm512_setzero_pd(), _CMP_LT_OQ); + __mmask8 m_2 = _mm512_cmp_pd_mask(v_g_2, _mm512_setzero_pd(), _CMP_LT_OQ); + __mmask8 m_3 = _mm512_cmp_pd_mask(v_g_3, _mm512_setzero_pd(), _CMP_LT_OQ); + __mmask8 m_4 = _mm512_cmp_pd_mask(v_g_4, _mm512_setzero_pd(), _CMP_LT_OQ); + + v_g_1 = _mm512_mask_add_pd(v_g_1, m_1, v_g_1, v_p); + v_g_2 = _mm512_mask_add_pd(v_g_2, m_2, v_g_2, v_p); + v_g_3 = _mm512_mask_add_pd(v_g_3, m_3, v_g_3, v_p); + v_g_4 = _mm512_mask_add_pd(v_g_4, m_4, v_g_4, v_p); + + __m512i v_out_1 = _mm512_cvt_roundpd_epu64(v_g_1, round_mode); + __m512i v_out_2 = _mm512_cvt_roundpd_epu64(v_g_2, round_mode); + __m512i v_out_3 = _mm512_cvt_roundpd_epu64(v_g_3, round_mode); + __m512i v_out_4 = _mm512_cvt_roundpd_epu64(v_g_4, round_mode); + + _mm512_storeu_si512(vp_result++, v_out_1); + _mm512_storeu_si512(vp_result++, v_out_2); + _mm512_storeu_si512(vp_result++, v_out_3); + _mm512_storeu_si512(vp_result++, v_out_4); + } +} + +template +inline void EltwiseMultModAVX512FloatLoop(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512d v_u, __m512d v_p, + __m512i v_modulus, + __m512i v_twice_mod, uint64_t n) { + switch (n) { + case 1024: + EltwiseMultModAVX512FloatLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, + v_twice_mod); + break; + + case 2048: + EltwiseMultModAVX512FloatLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, + v_twice_mod); + break; + + case 4096: + EltwiseMultModAVX512FloatLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, + v_twice_mod); + break; + + case 8192: + EltwiseMultModAVX512FloatLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, + v_twice_mod); + break; + + case 16384: + EltwiseMultModAVX512FloatLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, + v_twice_mod); + break; + + case 32768: + EltwiseMultModAVX512FloatLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, + v_twice_mod); + break; + + default: + EltwiseMultModAVX512FloatLoopDefault( + vp_result, vp_operand1, vp_operand2, v_u, v_p, v_modulus, v_twice_mod, + n); + } +} + +// From Function 18, page 19 of https://arxiv.org/pdf/1407.3383.pdf +// See also Algorithm 2/3 of +// https://hal.archives-ouvertes.fr/hal-02552673/document +template +void EltwiseMultModAVX512Float(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(modulus < MaximumValue(50), + " modulus " << modulus << " exceeds bound " << MaximumValue(50)); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + + HEXL_CHECK_BOUNDS(operand1, n, InputModFactor * modulus, + "operand1 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK_BOUNDS(operand2, n, InputModFactor * modulus, + "operand2 exceeds bound " << (InputModFactor * modulus)); + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseMultModNative(result, operand1, operand2, n_mod_8, + modulus); + operand1 += n_mod_8; + operand2 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + __m512d v_p = _mm512_set1_pd(static_cast(modulus)); + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(modulus * 2)); + + // Add epsilon to ensure u * p >= 1.0 + // See Proposition 13 of https://arxiv.org/pdf/1407.3383.pdf + double u_bar = (1.0 + std::numeric_limits::epsilon()) / + static_cast(modulus); + __m512d v_u = _mm512_set1_pd(u_bar); + + const __m512i* vp_operand1 = 
reinterpret_cast(operand1); + const __m512i* vp_operand2 = reinterpret_cast(operand2); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + + // The implementation without modular reduction of the operands is correct + // as long as (InputModFactor * modulus)^2 < 2^50 * modulus, i.e. + // InputModFactor^2 * modulus < 2^50. + // See function 16 of https://arxiv.org/pdf/1407.3383.pdf. + bool no_input_reduce_mod = + (InputModFactor * InputModFactor * modulus) < (1ULL << 50); + if (no_input_reduce_mod) { + EltwiseMultModAVX512FloatLoop<1>(vp_result, vp_operand1, vp_operand2, v_u, + v_p, v_modulus, v_twice_mod, n); + } else { + EltwiseMultModAVX512FloatLoop(vp_result, vp_operand1, + vp_operand2, v_u, v_p, + v_modulus, v_twice_mod, n); + } + + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-mult-mod-avx512ifma.cpp b/hexl_ser/eltwise/eltwise-mult-mod-avx512ifma.cpp new file mode 100644 index 00000000..54374cff --- /dev/null +++ b/hexl_ser/eltwise/eltwise-mult-mod-avx512ifma.cpp @@ -0,0 +1,615 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include + +#include "eltwise/eltwise-mult-mod-avx512.hpp" +#include "eltwise/eltwise-mult-mod-internal.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" +#include "hexl/util/defines.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512IFMA + +template void EltwiseMultModAVX512IFMAInt<1>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, + uint64_t n, uint64_t modulus); +template void EltwiseMultModAVX512IFMAInt<2>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, + uint64_t n, uint64_t modulus); +template void EltwiseMultModAVX512IFMAInt<4>(uint64_t* result, + const uint64_t* operand1, + const uint64_t* operand2, + uint64_t n, uint64_t modulus); + +template +void EltwiseMultModAVX512IFMAIntLoopUnroll(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, + __m512i v_neg_mod, + __m512i v_twice_mod) { + constexpr size_t manual_unroll_factor = 16; + constexpr size_t avx512_64bit_count = 8; + constexpr size_t loop_count = + CoeffCount / (manual_unroll_factor * avx512_64bit_count); + + static_assert(loop_count > 0, "loop_count too small for unrolling"); + static_assert(CoeffCount % (manual_unroll_factor * avx512_64bit_count) == 0, + "CoeffCount must be a factor of manual_unroll_factor * " + "avx512_64bit_count"); + + constexpr unsigned int HiShift = + static_cast(52 - ProdRightShift); + + HEXL_UNUSED(v_twice_mod); + HEXL_LOOP_UNROLL_4 + for (size_t i = loop_count; i > 0; --i) { + __m512i v_op1_1 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_1 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_2 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_2 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_3 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_3 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_4 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_4 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_5 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_5 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_6 = _mm512_loadu_si512(vp_operand1++); + __m512i 
v_op2_6 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_7 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_7 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_8 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_8 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_9 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_9 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_10 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_10 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_11 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_11 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_12 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_12 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_13 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_13 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_14 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_14 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_15 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_15 = _mm512_loadu_si512(vp_operand2++); + __m512i v_op1_16 = _mm512_loadu_si512(vp_operand1++); + __m512i v_op2_16 = _mm512_loadu_si512(vp_operand2++); + + v_op1_1 = _mm512_hexl_small_mod_epu64(v_op1_1, v_modulus, + &v_twice_mod); + v_op1_2 = _mm512_hexl_small_mod_epu64(v_op1_2, v_modulus, + &v_twice_mod); + v_op1_3 = _mm512_hexl_small_mod_epu64(v_op1_3, v_modulus, + &v_twice_mod); + v_op1_4 = _mm512_hexl_small_mod_epu64(v_op1_4, v_modulus, + &v_twice_mod); + v_op1_5 = _mm512_hexl_small_mod_epu64(v_op1_5, v_modulus, + &v_twice_mod); + v_op1_6 = _mm512_hexl_small_mod_epu64(v_op1_6, v_modulus, + &v_twice_mod); + v_op1_7 = _mm512_hexl_small_mod_epu64(v_op1_7, v_modulus, + &v_twice_mod); + v_op1_8 = _mm512_hexl_small_mod_epu64(v_op1_8, v_modulus, + &v_twice_mod); + v_op1_9 = _mm512_hexl_small_mod_epu64(v_op1_9, v_modulus, + &v_twice_mod); + v_op1_10 = _mm512_hexl_small_mod_epu64(v_op1_10, v_modulus, + &v_twice_mod); + v_op1_11 = _mm512_hexl_small_mod_epu64(v_op1_11, v_modulus, + &v_twice_mod); + v_op1_12 = _mm512_hexl_small_mod_epu64(v_op1_12, v_modulus, + &v_twice_mod); + v_op1_13 = _mm512_hexl_small_mod_epu64(v_op1_13, v_modulus, + &v_twice_mod); + v_op1_14 = _mm512_hexl_small_mod_epu64(v_op1_14, v_modulus, + &v_twice_mod); + v_op1_15 = _mm512_hexl_small_mod_epu64(v_op1_15, v_modulus, + &v_twice_mod); + v_op1_16 = _mm512_hexl_small_mod_epu64(v_op1_16, v_modulus, + &v_twice_mod); + + v_op2_1 = _mm512_hexl_small_mod_epu64(v_op2_1, v_modulus, + &v_twice_mod); + v_op2_2 = _mm512_hexl_small_mod_epu64(v_op2_2, v_modulus, + &v_twice_mod); + v_op2_3 = _mm512_hexl_small_mod_epu64(v_op2_3, v_modulus, + &v_twice_mod); + v_op2_4 = _mm512_hexl_small_mod_epu64(v_op2_4, v_modulus, + &v_twice_mod); + v_op2_5 = _mm512_hexl_small_mod_epu64(v_op2_5, v_modulus, + &v_twice_mod); + v_op2_6 = _mm512_hexl_small_mod_epu64(v_op2_6, v_modulus, + &v_twice_mod); + v_op2_7 = _mm512_hexl_small_mod_epu64(v_op2_7, v_modulus, + &v_twice_mod); + v_op2_8 = _mm512_hexl_small_mod_epu64(v_op2_8, v_modulus, + &v_twice_mod); + v_op2_9 = _mm512_hexl_small_mod_epu64(v_op2_9, v_modulus, + &v_twice_mod); + v_op2_10 = _mm512_hexl_small_mod_epu64(v_op2_10, v_modulus, + &v_twice_mod); + v_op2_11 = _mm512_hexl_small_mod_epu64(v_op2_11, v_modulus, + &v_twice_mod); + v_op2_12 = _mm512_hexl_small_mod_epu64(v_op2_12, v_modulus, + &v_twice_mod); + v_op2_13 = _mm512_hexl_small_mod_epu64(v_op2_13, v_modulus, + &v_twice_mod); + v_op2_14 = _mm512_hexl_small_mod_epu64(v_op2_14, v_modulus, + &v_twice_mod); + v_op2_15 = _mm512_hexl_small_mod_epu64(v_op2_15, 
v_modulus, + &v_twice_mod); + v_op2_16 = _mm512_hexl_small_mod_epu64(v_op2_16, v_modulus, + &v_twice_mod); + + __m512i v_prod_hi_1 = _mm512_hexl_mulhi_epi<52>(v_op1_1, v_op2_1); + __m512i v_prod_hi_2 = _mm512_hexl_mulhi_epi<52>(v_op1_2, v_op2_2); + __m512i v_prod_hi_3 = _mm512_hexl_mulhi_epi<52>(v_op1_3, v_op2_3); + __m512i v_prod_hi_4 = _mm512_hexl_mulhi_epi<52>(v_op1_4, v_op2_4); + __m512i v_prod_hi_5 = _mm512_hexl_mulhi_epi<52>(v_op1_5, v_op2_5); + __m512i v_prod_hi_6 = _mm512_hexl_mulhi_epi<52>(v_op1_6, v_op2_6); + __m512i v_prod_hi_7 = _mm512_hexl_mulhi_epi<52>(v_op1_7, v_op2_7); + __m512i v_prod_hi_8 = _mm512_hexl_mulhi_epi<52>(v_op1_8, v_op2_8); + __m512i v_prod_hi_9 = _mm512_hexl_mulhi_epi<52>(v_op1_9, v_op2_9); + __m512i v_prod_hi_10 = _mm512_hexl_mulhi_epi<52>(v_op1_10, v_op2_10); + __m512i v_prod_hi_11 = _mm512_hexl_mulhi_epi<52>(v_op1_11, v_op2_11); + __m512i v_prod_hi_12 = _mm512_hexl_mulhi_epi<52>(v_op1_12, v_op2_12); + __m512i v_prod_hi_13 = _mm512_hexl_mulhi_epi<52>(v_op1_13, v_op2_13); + __m512i v_prod_hi_14 = _mm512_hexl_mulhi_epi<52>(v_op1_14, v_op2_14); + __m512i v_prod_hi_15 = _mm512_hexl_mulhi_epi<52>(v_op1_15, v_op2_15); + __m512i v_prod_hi_16 = _mm512_hexl_mulhi_epi<52>(v_op1_16, v_op2_16); + + __m512i v_prod_lo_1 = _mm512_hexl_mullo_epi<52>(v_op1_1, v_op2_1); + __m512i v_prod_lo_2 = _mm512_hexl_mullo_epi<52>(v_op1_2, v_op2_2); + __m512i v_prod_lo_3 = _mm512_hexl_mullo_epi<52>(v_op1_3, v_op2_3); + __m512i v_prod_lo_4 = _mm512_hexl_mullo_epi<52>(v_op1_4, v_op2_4); + __m512i v_prod_lo_5 = _mm512_hexl_mullo_epi<52>(v_op1_5, v_op2_5); + __m512i v_prod_lo_6 = _mm512_hexl_mullo_epi<52>(v_op1_6, v_op2_6); + __m512i v_prod_lo_7 = _mm512_hexl_mullo_epi<52>(v_op1_7, v_op2_7); + __m512i v_prod_lo_8 = _mm512_hexl_mullo_epi<52>(v_op1_8, v_op2_8); + __m512i v_prod_lo_9 = _mm512_hexl_mullo_epi<52>(v_op1_9, v_op2_9); + __m512i v_prod_lo_10 = _mm512_hexl_mullo_epi<52>(v_op1_10, v_op2_10); + __m512i v_prod_lo_11 = _mm512_hexl_mullo_epi<52>(v_op1_11, v_op2_11); + __m512i v_prod_lo_12 = _mm512_hexl_mullo_epi<52>(v_op1_12, v_op2_12); + __m512i v_prod_lo_13 = _mm512_hexl_mullo_epi<52>(v_op1_13, v_op2_13); + __m512i v_prod_lo_14 = _mm512_hexl_mullo_epi<52>(v_op1_14, v_op2_14); + __m512i v_prod_lo_15 = _mm512_hexl_mullo_epi<52>(v_op1_15, v_op2_15); + __m512i v_prod_lo_16 = _mm512_hexl_mullo_epi<52>(v_op1_16, v_op2_16); + + __m512i c1_lo_1 = _mm512_srli_epi64(v_prod_lo_1, ProdRightShift); + __m512i c1_lo_2 = _mm512_srli_epi64(v_prod_lo_2, ProdRightShift); + __m512i c1_lo_3 = _mm512_srli_epi64(v_prod_lo_3, ProdRightShift); + __m512i c1_lo_4 = _mm512_srli_epi64(v_prod_lo_4, ProdRightShift); + __m512i c1_lo_5 = _mm512_srli_epi64(v_prod_lo_5, ProdRightShift); + __m512i c1_lo_6 = _mm512_srli_epi64(v_prod_lo_6, ProdRightShift); + __m512i c1_lo_7 = _mm512_srli_epi64(v_prod_lo_7, ProdRightShift); + __m512i c1_lo_8 = _mm512_srli_epi64(v_prod_lo_8, ProdRightShift); + __m512i c1_lo_9 = _mm512_srli_epi64(v_prod_lo_9, ProdRightShift); + __m512i c1_lo_10 = _mm512_srli_epi64(v_prod_lo_10, ProdRightShift); + __m512i c1_lo_11 = _mm512_srli_epi64(v_prod_lo_11, ProdRightShift); + __m512i c1_lo_12 = _mm512_srli_epi64(v_prod_lo_12, ProdRightShift); + __m512i c1_lo_13 = _mm512_srli_epi64(v_prod_lo_13, ProdRightShift); + __m512i c1_lo_14 = _mm512_srli_epi64(v_prod_lo_14, ProdRightShift); + __m512i c1_lo_15 = _mm512_srli_epi64(v_prod_lo_15, ProdRightShift); + __m512i c1_lo_16 = _mm512_srli_epi64(v_prod_lo_16, ProdRightShift); + + __m512i c1_hi_1 = _mm512_slli_epi64(v_prod_hi_1, HiShift); + __m512i c1_hi_2 = 
_mm512_slli_epi64(v_prod_hi_2, HiShift); + __m512i c1_hi_3 = _mm512_slli_epi64(v_prod_hi_3, HiShift); + __m512i c1_hi_4 = _mm512_slli_epi64(v_prod_hi_4, HiShift); + __m512i c1_hi_5 = _mm512_slli_epi64(v_prod_hi_5, HiShift); + __m512i c1_hi_6 = _mm512_slli_epi64(v_prod_hi_6, HiShift); + __m512i c1_hi_7 = _mm512_slli_epi64(v_prod_hi_7, HiShift); + __m512i c1_hi_8 = _mm512_slli_epi64(v_prod_hi_8, HiShift); + __m512i c1_hi_9 = _mm512_slli_epi64(v_prod_hi_9, HiShift); + __m512i c1_hi_10 = _mm512_slli_epi64(v_prod_hi_10, HiShift); + __m512i c1_hi_11 = _mm512_slli_epi64(v_prod_hi_11, HiShift); + __m512i c1_hi_12 = _mm512_slli_epi64(v_prod_hi_12, HiShift); + __m512i c1_hi_13 = _mm512_slli_epi64(v_prod_hi_13, HiShift); + __m512i c1_hi_14 = _mm512_slli_epi64(v_prod_hi_14, HiShift); + __m512i c1_hi_15 = _mm512_slli_epi64(v_prod_hi_15, HiShift); + __m512i c1_hi_16 = _mm512_slli_epi64(v_prod_hi_16, HiShift); + + __m512i c1_1 = _mm512_or_epi64(c1_lo_1, c1_hi_1); + __m512i c1_2 = _mm512_or_epi64(c1_lo_2, c1_hi_2); + __m512i c1_3 = _mm512_or_epi64(c1_lo_3, c1_hi_3); + __m512i c1_4 = _mm512_or_epi64(c1_lo_4, c1_hi_4); + __m512i c1_5 = _mm512_or_epi64(c1_lo_5, c1_hi_5); + __m512i c1_6 = _mm512_or_epi64(c1_lo_6, c1_hi_6); + __m512i c1_7 = _mm512_or_epi64(c1_lo_7, c1_hi_7); + __m512i c1_8 = _mm512_or_epi64(c1_lo_8, c1_hi_8); + __m512i c1_9 = _mm512_or_epi64(c1_lo_9, c1_hi_9); + __m512i c1_10 = _mm512_or_epi64(c1_lo_10, c1_hi_10); + __m512i c1_11 = _mm512_or_epi64(c1_lo_11, c1_hi_11); + __m512i c1_12 = _mm512_or_epi64(c1_lo_12, c1_hi_12); + __m512i c1_13 = _mm512_or_epi64(c1_lo_13, c1_hi_13); + __m512i c1_14 = _mm512_or_epi64(c1_lo_14, c1_hi_14); + __m512i c1_15 = _mm512_or_epi64(c1_lo_15, c1_hi_15); + __m512i c1_16 = _mm512_or_epi64(c1_lo_16, c1_hi_16); + + __m512i q_hat_1 = _mm512_hexl_mulhi_epi<52>(c1_1, v_barr_lo); + __m512i q_hat_2 = _mm512_hexl_mulhi_epi<52>(c1_2, v_barr_lo); + __m512i q_hat_3 = _mm512_hexl_mulhi_epi<52>(c1_3, v_barr_lo); + __m512i q_hat_4 = _mm512_hexl_mulhi_epi<52>(c1_4, v_barr_lo); + __m512i q_hat_5 = _mm512_hexl_mulhi_epi<52>(c1_5, v_barr_lo); + __m512i q_hat_6 = _mm512_hexl_mulhi_epi<52>(c1_6, v_barr_lo); + __m512i q_hat_7 = _mm512_hexl_mulhi_epi<52>(c1_7, v_barr_lo); + __m512i q_hat_8 = _mm512_hexl_mulhi_epi<52>(c1_8, v_barr_lo); + __m512i q_hat_9 = _mm512_hexl_mulhi_epi<52>(c1_9, v_barr_lo); + __m512i q_hat_10 = _mm512_hexl_mulhi_epi<52>(c1_10, v_barr_lo); + __m512i q_hat_11 = _mm512_hexl_mulhi_epi<52>(c1_11, v_barr_lo); + __m512i q_hat_12 = _mm512_hexl_mulhi_epi<52>(c1_12, v_barr_lo); + __m512i q_hat_13 = _mm512_hexl_mulhi_epi<52>(c1_13, v_barr_lo); + __m512i q_hat_14 = _mm512_hexl_mulhi_epi<52>(c1_14, v_barr_lo); + __m512i q_hat_15 = _mm512_hexl_mulhi_epi<52>(c1_15, v_barr_lo); + __m512i q_hat_16 = _mm512_hexl_mulhi_epi<52>(c1_16, v_barr_lo); + + __m512i z_1 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_1, q_hat_1, v_neg_mod); + __m512i z_2 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_2, q_hat_2, v_neg_mod); + __m512i z_3 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_3, q_hat_3, v_neg_mod); + __m512i z_4 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_4, q_hat_4, v_neg_mod); + __m512i z_5 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_5, q_hat_5, v_neg_mod); + __m512i z_6 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_6, q_hat_6, v_neg_mod); + __m512i z_7 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_7, q_hat_7, v_neg_mod); + __m512i z_8 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_8, q_hat_8, v_neg_mod); + __m512i z_9 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_9, 
q_hat_9, v_neg_mod); + __m512i z_10 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_10, q_hat_10, v_neg_mod); + __m512i z_11 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_11, q_hat_11, v_neg_mod); + __m512i z_12 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_12, q_hat_12, v_neg_mod); + __m512i z_13 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_13, q_hat_13, v_neg_mod); + __m512i z_14 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_14, q_hat_14, v_neg_mod); + __m512i z_15 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_15, q_hat_15, v_neg_mod); + __m512i z_16 = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo_16, q_hat_16, v_neg_mod); + + __m512i v_result_1 = _mm512_hexl_small_mod_epu64<2>(z_1, v_modulus); + __m512i v_result_2 = _mm512_hexl_small_mod_epu64<2>(z_2, v_modulus); + __m512i v_result_3 = _mm512_hexl_small_mod_epu64<2>(z_3, v_modulus); + __m512i v_result_4 = _mm512_hexl_small_mod_epu64<2>(z_4, v_modulus); + __m512i v_result_5 = _mm512_hexl_small_mod_epu64<2>(z_5, v_modulus); + __m512i v_result_6 = _mm512_hexl_small_mod_epu64<2>(z_6, v_modulus); + __m512i v_result_7 = _mm512_hexl_small_mod_epu64<2>(z_7, v_modulus); + __m512i v_result_8 = _mm512_hexl_small_mod_epu64<2>(z_8, v_modulus); + __m512i v_result_9 = _mm512_hexl_small_mod_epu64<2>(z_9, v_modulus); + __m512i v_result_10 = _mm512_hexl_small_mod_epu64<2>(z_10, v_modulus); + __m512i v_result_11 = _mm512_hexl_small_mod_epu64<2>(z_11, v_modulus); + __m512i v_result_12 = _mm512_hexl_small_mod_epu64<2>(z_12, v_modulus); + __m512i v_result_13 = _mm512_hexl_small_mod_epu64<2>(z_13, v_modulus); + __m512i v_result_14 = _mm512_hexl_small_mod_epu64<2>(z_14, v_modulus); + __m512i v_result_15 = _mm512_hexl_small_mod_epu64<2>(z_15, v_modulus); + __m512i v_result_16 = _mm512_hexl_small_mod_epu64<2>(z_16, v_modulus); + + _mm512_storeu_si512(vp_result++, v_result_1); + _mm512_storeu_si512(vp_result++, v_result_2); + _mm512_storeu_si512(vp_result++, v_result_3); + _mm512_storeu_si512(vp_result++, v_result_4); + _mm512_storeu_si512(vp_result++, v_result_5); + _mm512_storeu_si512(vp_result++, v_result_6); + _mm512_storeu_si512(vp_result++, v_result_7); + _mm512_storeu_si512(vp_result++, v_result_8); + _mm512_storeu_si512(vp_result++, v_result_9); + _mm512_storeu_si512(vp_result++, v_result_10); + _mm512_storeu_si512(vp_result++, v_result_11); + _mm512_storeu_si512(vp_result++, v_result_12); + _mm512_storeu_si512(vp_result++, v_result_13); + _mm512_storeu_si512(vp_result++, v_result_14); + _mm512_storeu_si512(vp_result++, v_result_15); + _mm512_storeu_si512(vp_result++, v_result_16); + } +} + +// Algorithm 2 from https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModAVX512IFMAIntLoopDefault( + __m512i* vp_result, const __m512i* vp_operand1, const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, __m512i v_neg_mod, + __m512i v_twice_mod, uint64_t n) { + HEXL_UNUSED(v_twice_mod); + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op1 = _mm512_loadu_si512(vp_operand1); + v_op1 = _mm512_hexl_small_mod_epu64(v_op1, v_modulus, + &v_twice_mod); + + __m512i v_op2 = _mm512_loadu_si512(vp_operand2); + v_op2 = _mm512_hexl_small_mod_epu64(v_op2, v_modulus, + &v_twice_mod); + + // Compute product U + __m512i v_prod_hi = _mm512_hexl_mulhi_epi<52>(v_op1, v_op2); + __m512i v_prod_lo = _mm512_hexl_mullo_epi<52>(v_op1, v_op2); + + // c1 = floor(U / 2^{n + beta}) + __m512i c1_lo = + _mm512_srli_epi64(v_prod_lo, static_cast(ProdRightShift)); + __m512i c1_hi = _mm512_slli_epi64( + v_prod_hi, 
static_cast(52ULL - (ProdRightShift))); + __m512i c1 = _mm512_or_epi64(c1_lo, c1_hi); + + // alpha - beta == 52, so we only need high 52 bits + __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, v_barr_lo); + + // Z = prod_lo - (p * q_hat)_lo + __m512i v_result = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo, q_hat, v_neg_mod); + + // Reduce result to [0, q) + v_result = _mm512_hexl_small_mod_epu64<2>(v_result, v_modulus); + _mm512_storeu_si512(vp_result, v_result); + + ++vp_operand1; + ++vp_operand2; + ++vp_result; + } +} + +// Algorithm 2 from https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModAVX512IFMAIntLoopDefault( + __m512i* vp_result, const __m512i* vp_operand1, const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, __m512i v_neg_mod, + __m512i v_twice_mod, uint64_t n, uint64_t prod_right_shift) { + unsigned int low_shift = static_cast(prod_right_shift); + unsigned int high_shift = static_cast(52 - prod_right_shift); + + HEXL_UNUSED(v_twice_mod); + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_op1 = _mm512_loadu_si512(vp_operand1); + v_op1 = _mm512_hexl_small_mod_epu64(v_op1, v_modulus, + &v_twice_mod); + + __m512i v_op2 = _mm512_loadu_si512(vp_operand2); + v_op2 = _mm512_hexl_small_mod_epu64(v_op2, v_modulus, + &v_twice_mod); + + // Compute product + __m512i v_prod_hi = _mm512_hexl_mulhi_epi<52>(v_op1, v_op2); + __m512i v_prod_lo = _mm512_hexl_mullo_epi<52>(v_op1, v_op2); + + __m512i c1_lo = _mm512_srli_epi64(v_prod_lo, low_shift); + __m512i c1_hi = _mm512_slli_epi64(v_prod_hi, high_shift); + __m512i c1 = _mm512_or_epi64(c1_lo, c1_hi); + + // alpha - beta == 52, so we only need high 52 bits + __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, v_barr_lo); + + // z = prod_lo - (p * q_hat)_lo + __m512i v_result = + _mm512_hexl_mullo_add_lo_epi<52>(v_prod_lo, q_hat, v_neg_mod); + + // Reduce result to [0, q) + v_result = _mm512_hexl_small_mod_epu64<2>(v_result, v_modulus); + + _mm512_storeu_si512(vp_result, v_result); + + ++vp_operand1; + ++vp_operand2; + ++vp_result; + } +} + +template +void EltwiseMultModAVX512IFMAIntLoop(__m512i* vp_result, + const __m512i* vp_operand1, + const __m512i* vp_operand2, + __m512i v_barr_lo, __m512i v_modulus, + __m512i v_neg_mod, __m512i v_twice_mod, + uint64_t n) { + switch (n) { + case 1024: { + EltwiseMultModAVX512IFMAIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod); + break; + } + case 2048: { + EltwiseMultModAVX512IFMAIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod); + break; + } + case 4096: { + EltwiseMultModAVX512IFMAIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod); + break; + } + case 8192: { + EltwiseMultModAVX512IFMAIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod); + break; + } + case 16384: { + EltwiseMultModAVX512IFMAIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod); + break; + } + case 32768: { + EltwiseMultModAVX512IFMAIntLoopUnroll( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod); + break; + } + default: + EltwiseMultModAVX512IFMAIntLoopDefault( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, + v_twice_mod, n); + } +} + +#define ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(ProdRightShift, \ + InputModFactor) \ + case (ProdRightShift): { \ + 
EltwiseMultModAVX512IFMAIntLoop<(ProdRightShift), (InputModFactor)>( \ + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, v_neg_mod, \ + v_twice_mod, n); \ + break; \ + } + +// Algorithm 2 from https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModAVX512IFMAInt(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || InputModFactor == 4, + "Require InputModFactor = 1, 2, or 4") + HEXL_CHECK(modulus < (1ULL << 50), "Require modulus < (1ULL << 50)"); + HEXL_CHECK_BOUNDS(operand1, n, InputModFactor * modulus, + "operand1 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK_BOUNDS(operand2, n, InputModFactor * modulus, + "operand2 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseMultModNative(result, operand1, operand2, n_mod_8, + modulus); + operand1 += n_mod_8; + operand2 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + constexpr int64_t beta = -2; + HEXL_CHECK(beta <= -2, "beta must be <= -2 for correctness"); + constexpr int64_t alpha = 50; // ensures alpha - beta = 52 + uint64_t gamma = Log2(InputModFactor); + HEXL_UNUSED(gamma); + HEXL_CHECK(alpha >= gamma + 1, "alpha must be >= gamma + 1 for correctness"); + + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + + // Barrett factor "mu" + // TODO(fboemer): Allow MultiplyFactor to take bit shifts != 52 + HEXL_CHECK(ceil_log_mod + alpha >= 52, "ceil_log_mod + alpha < 52"); + uint64_t barr_lo = + MultiplyFactor((1ULL << (ceil_log_mod + alpha - 52)), 52, modulus) + .BarrettFactor(); + + __m512i v_barr_lo = _mm512_set1_epi64(static_cast(barr_lo)); + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(2 * modulus)); + __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(modulus)); + const __m512i* vp_operand1 = reinterpret_cast(operand1); + const __m512i* vp_operand2 = reinterpret_cast(operand2); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + + // Let d be the product operand1 * operand2. + // To ensure d >> prod_right_shift < (1ULL << 52), we need + // (input_mod_factor * modulus)^2 >> (prod_right_shift) < (1ULL << 52) + // This happens when 2*log_2(input_mod_factor) + ceil_log_mod - beta < 51 + // If not, we need to reduce the inputs to be less than modulus for + // correctness. This is less efficient, so we avoid it when possible. + bool reduce_mod = 2 * Log2(InputModFactor) + prod_right_shift - beta >= 51; + + if (reduce_mod) { + // Here, we assume beta = -2 + HEXL_CHECK(beta == -2, "beta != -2 may skip some cases"); + // This reduce_mod case happens only when + // prod_right_shift >= 51 - 2 * log2(input_mod_factor) >= 45. + // Additionally, modulus < (1ULL << 50) implies + // prod_right_shift <= 49. So N == 45, 46, 47, 48, 49 are the + // only cases here. + switch (prod_right_shift) { + // The template arguments are required for use of _mm512_hexl_shrdi_epi64, + // which requires a compile-time constant for the shift. 
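+      // Illustrative sketch (not part of the original source): for
+      // prod_right_shift == 45 the case macro below expands to roughly
+      //   case (45): {
+      //     EltwiseMultModAVX512IFMAIntLoop<45, InputModFactor>(
+      //         vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus,
+      //         v_neg_mod, v_twice_mod, n);
+      //     break;
+      //   }
+      // so the right shift becomes a compile-time constant for each case.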
+ ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(45, InputModFactor) + ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(46, InputModFactor) + ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(47, InputModFactor) + ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(48, InputModFactor) + ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(49, InputModFactor) + default: { + HEXL_CHECK(false, + "Bad value for prod_right_shift: " << prod_right_shift); + } + } + } else { + switch (prod_right_shift) { + // Smaller shifts are uncommon. + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(15, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(16, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(17, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(18, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(19, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(20, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(21, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(22, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(23, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(24, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(25, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(26, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(27, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(28, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(29, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(31, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(32, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(33, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(34, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(35, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(36, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(37, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(38, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(39, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(40, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(41, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(42, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(43, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(44, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(45, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(46, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(47, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(48, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(49, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(50, 1) + // ELTWISE_MULT_MOD_AVX512_IFMA_INT_PROD_RIGHT_SHIFT_CASE(51, 1) + default: { + EltwiseMultModAVX512IFMAIntLoopDefault<1>( + vp_result, vp_operand1, vp_operand2, v_barr_lo, v_modulus, + v_neg_mod, v_twice_mod, n, prod_right_shift); + } + } + } + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-mult-mod-internal.hpp b/hexl_ser/eltwise/eltwise-mult-mod-internal.hpp new file mode 100644 index 00000000..77705223 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-mult-mod-internal.hpp @@ -0,0 +1,104 @@ +// Copyright (C) 2020 Intel 
Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include + +#include "eltwise/eltwise-mult-mod-internal.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" + +namespace intel { +namespace hexl { + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. +/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +/// @details Algorithm 2 from +/// https://homes.esat.kuleuven.be/~fvercaut/papers/bar_mont.pdf +template +void EltwiseMultModNative(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || InputModFactor == 4, + "Require InputModFactor = 1, 2, or 4") + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 62), "Require modulus < (1ULL << 62)"); + HEXL_CHECK_BOUNDS(operand1, n, InputModFactor * modulus, + "operand1 exceeds bound " << (InputModFactor * modulus)); + HEXL_CHECK_BOUNDS(operand2, n, InputModFactor * modulus, + "operand2 exceeds bound " << (InputModFactor * modulus)); + + constexpr int64_t beta = -2; + HEXL_CHECK(beta <= -2, "beta must be <= -2 for correctness"); + + constexpr int64_t alpha = 62; // ensures alpha - beta = 64 + + uint64_t gamma = Log2(InputModFactor); + HEXL_UNUSED(gamma); + HEXL_CHECK(alpha >= gamma + 1, "alpha must be >= gamma + 1 for correctness"); + + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + + // Barrett factor "mu" + // TODO(fboemer): Allow MultiplyFactor to take bit shifts != 64 + HEXL_CHECK(ceil_log_mod + alpha >= 64, "ceil_log_mod + alpha < 64"); + uint64_t barr_lo = + MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - 64), 64, modulus) + .BarrettFactor(); + + const uint64_t twice_modulus = 2 * modulus; + + // HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < n; ++i) { + uint64_t prod_hi, prod_lo, c2_hi, c2_lo, Z; + + uint64_t x = ReduceMod(*operand1, modulus, &twice_modulus); + uint64_t y = ReduceMod(*operand2, modulus, &twice_modulus); + + // Multiply inputs + MultiplyUInt64(x, y, &prod_hi, &prod_lo); + + // floor(U / 2^{n + beta}) + uint64_t c1 = (prod_lo >> (prod_right_shift)) + + (prod_hi << (64 - (prod_right_shift))); + + // c2 = floor(U / 2^{n + beta}) * mu + MultiplyUInt64(c1, barr_lo, &c2_hi, &c2_lo); + + // alpha - beta == 64, so we only need high 64 bits + uint64_t q_hat = c2_hi; + + // only compute low bits, since we know high bits will be 0 + Z = prod_lo - q_hat * modulus; + + // Conditional subtraction + *result = (Z >= modulus) ? 
(Z - modulus) : Z; + + ++operand1; + ++operand2; + ++result; + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-mult-mod.cpp b/hexl_ser/eltwise/eltwise-mult-mod.cpp new file mode 100644 index 00000000..c4a423d2 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-mult-mod.cpp @@ -0,0 +1,85 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/eltwise/eltwise-mult-mod.hpp" + +#include "eltwise/eltwise-mult-mod-avx512.hpp" +#include "eltwise/eltwise-mult-mod-internal.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void EltwiseMultMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(input_mod_factor * modulus < (1ULL << 63), + "Require input_mod_factor * modulus < (1ULL << 63)"); + HEXL_CHECK( + input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4, + "Require input_mod_factor = 1, 2, or 4") + HEXL_CHECK_BOUNDS(operand1, n, input_mod_factor * modulus, + "operand1 exceeds bound " << (input_mod_factor * modulus)) + HEXL_CHECK_BOUNDS(operand2, n, input_mod_factor * modulus, + "operand2 exceeds bound " << (input_mod_factor * modulus)) + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + if (modulus < (1ULL << 50)) { + // EltwiseMultModAVX512IFMA has similar performance to + // EltwiseMultModAVX512Float, but requires the AVX512IFMA instruction set, + // so we prefer to use EltwiseMultModAVX512Float. 
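+      // Usage sketch (illustrative only; `out`, `a`, `b` and `q` are
+      // hypothetical caller-side names): with q below 2^50 and inputs
+      // already reduced below q, a call such as
+      //   intel::hexl::EltwiseMultMod(out.data(), a.data(), b.data(), n, q, 1);
+      // reaches this branch and dispatches to EltwiseMultModAVX512Float<1>.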
+ switch (input_mod_factor) { + case 1: + EltwiseMultModAVX512Float<1>(result, operand1, operand2, n, modulus); + break; + case 2: + EltwiseMultModAVX512Float<2>(result, operand1, operand2, n, modulus); + break; + case 4: + EltwiseMultModAVX512Float<4>(result, operand1, operand2, n, modulus); + break; + } + } else { + switch (input_mod_factor) { + case 1: + EltwiseMultModAVX512DQInt<1>(result, operand1, operand2, n, modulus); + break; + case 2: + EltwiseMultModAVX512DQInt<2>(result, operand1, operand2, n, modulus); + break; + case 4: + EltwiseMultModAVX512DQInt<4>(result, operand1, operand2, n, modulus); + break; + } + } + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseMultModNative"); + switch (input_mod_factor) { + case 1: + EltwiseMultModNative<1>(result, operand1, operand2, n, modulus); + break; + case 2: + EltwiseMultModNative<2>(result, operand1, operand2, n, modulus); + break; + case 4: + EltwiseMultModNative<4>(result, operand1, operand2, n, modulus); + break; + } + return; +} +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-reduce-mod-avx512.cpp b/hexl_ser/eltwise/eltwise-reduce-mod-avx512.cpp new file mode 100644 index 00000000..144e070b --- /dev/null +++ b/hexl_ser/eltwise/eltwise-reduce-mod-avx512.cpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-reduce-mod-avx512.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +template void EltwiseReduceModAVX512<64>(uint64_t* result, + const uint64_t* operand, uint64_t n, + uint64_t modulus, + uint64_t input_mod_factor, + uint64_t output_mod_factor); +#endif + +#ifdef HEXL_HAS_AVX512IFMA +template void EltwiseReduceModAVX512<52>(uint64_t* result, + const uint64_t* operand, uint64_t n, + uint64_t modulus, + uint64_t input_mod_factor, + uint64_t output_mod_factor); +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-reduce-mod-avx512.hpp b/hexl_ser/eltwise/eltwise-reduce-mod-avx512.hpp new file mode 100644 index 00000000..5374c9c8 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-reduce-mod-avx512.hpp @@ -0,0 +1,378 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "eltwise/eltwise-reduce-mod-avx512.hpp" +#include "eltwise/eltwise-reduce-mod-internal.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ +template +void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand, + uint64_t n, uint64_t modulus, + uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(operand != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(input_mod_factor == modulus || input_mod_factor == 2 || + input_mod_factor == 4, + "input_mod_factor must be modulus or 2 or 4" << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2 " << output_mod_factor); + HEXL_CHECK(input_mod_factor != output_mod_factor, + "input_mod_factor must not be equal to output_mod_factor "); + + uint64_t n_tmp = n; + + // Multi-word Barrett reduction precomputation + constexpr int64_t alpha = BitShift - 2; + constexpr int64_t beta = 
-2; + const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2 + uint64_t prod_right_shift = ceil_log_mod + beta; + __m512i v_neg_mod = _mm512_set1_epi64(-static_cast(modulus)); + + uint64_t barrett_factor = + MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - BitShift), BitShift, + modulus) + .BarrettFactor(); + + uint64_t barrett_factor_52 = MultiplyFactor(1, 52, modulus).BarrettFactor(); + + if (BitShift == 64) { + // Single-worded Barrett reduction. + barrett_factor = MultiplyFactor(1, 64, modulus).BarrettFactor(); + } + + __m512i v_bf = _mm512_set1_epi64(static_cast(barrett_factor)); + __m512i v_bf_52 = _mm512_set1_epi64(static_cast(barrett_factor_52)); + + // Deals with n not divisible by 8 + uint64_t n_mod_8 = n_tmp % 8; + if (n_mod_8 != 0) { + EltwiseReduceModNative(result, operand, n_mod_8, modulus, input_mod_factor, + output_mod_factor); + operand += n_mod_8; + result += n_mod_8; + n_tmp -= n_mod_8; + } + + uint64_t twice_mod = modulus << 1; + const __m512i* v_operand = reinterpret_cast(operand); + __m512i* v_result = reinterpret_cast<__m512i*>(result); + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(twice_mod)); + + if (input_mod_factor == modulus) { + if (output_mod_factor == 2) { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_barrett_reduce64( + v_op, v_modulus, v_bf, v_bf_52, prod_right_shift, v_neg_mod); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } else { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_barrett_reduce64( + v_op, v_modulus, v_bf, v_bf_52, prod_right_shift, v_neg_mod); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } + } + + if (input_mod_factor == 2) { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_small_mod_epu64(v_op, v_modulus); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } + + if (input_mod_factor == 4) { + if (output_mod_factor == 1) { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_small_mod_epu64(v_op, v_twice_mod); + v_op = _mm512_hexl_small_mod_epu64(v_op, v_modulus); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } + if (output_mod_factor == 2) { + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_op = _mm512_loadu_si512(v_operand); + v_op = _mm512_hexl_small_mod_epu64(v_op, v_twice_mod); + HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, twice_mod, + "v_op exceeds bound " << twice_mod); + _mm512_storeu_si512(v_result, v_op); + ++v_operand; + ++v_result; + } + } + } +} + +/// @brief Returns Montgomery form of modular product ab mod q, computed via the +/// REDC algorithm, also known as Montgomery reduction. +/// @tparam BitShift denotes the operational length, in bits, of the operands +/// and result values. +/// @tparam r defines the value of R, being R = 2^r. R > modulus. +/// @param[in] a input vector. 
T = ab in the range [0, Rq − 1]. +/// @param[in] b input vector. +/// @param[in] modulus such that gcd(R, modulus) = 1. +/// @param[in] neg_inv_mod in [0, R − 1] such that q*neg_inv_mod ≡ −1 mod R, +/// @param[in] n number of elements in input vector. +/// @param[out] result unsigned long int vector in the range [0, q − 1] such +/// that S ≡ TR^−1 mod q +template +void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a, + const uint64_t* b, uint64_t n, uint64_t modulus, + uint64_t neg_inv_mod) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(a != nullptr, "Require operand a != nullptr"); + HEXL_CHECK(b != nullptr, "Require operand b != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + + uint64_t R = (1ULL << r); + HEXL_CHECK(std::gcd(modulus, R) == 1, "gcd(modulus, R) != 1"); + HEXL_CHECK(R > modulus, "Needs R bigger than q."); + + // mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones + uint64_t mod_R_mask = R - 1; + uint64_t prod_rs; + if (BitShift == 64) { + HEXL_CHECK(r <= 62, "With r > 62 internal ops might overflow"); + prod_rs = (1ULL << 63) - 1; + } else { + prod_rs = (1ULL << (52 - r)); + } + uint64_t n_tmp = n; + + // Deals with n not divisible by 8 + uint64_t n_mod_8 = n_tmp % 8; + if (n_mod_8 != 0) { + for (size_t i = 0; i < n_mod_8; ++i) { + uint64_t T_hi; + uint64_t T_lo; + MultiplyUInt64(a[i], b[i], &T_hi, &T_lo); + result[i] = MontgomeryReduce(T_hi, T_lo, modulus, r, mod_R_mask, + neg_inv_mod); + } + a += n_mod_8; + b += n_mod_8; + result += n_mod_8; + n_tmp -= n_mod_8; + } + + const __m512i* v_a = reinterpret_cast(a); + const __m512i* v_b = reinterpret_cast(b); + __m512i* v_result = reinterpret_cast<__m512i*>(result); + __m512i v_modulus = _mm512_set1_epi64(modulus); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); + __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); + + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_a_op = _mm512_loadu_si512(v_a); + __m512i v_b_op = _mm512_loadu_si512(v_b); + __m512i v_T_hi = _mm512_hexl_mulhi_epi(v_a_op, v_b_op); + __m512i v_T_lo = _mm512_hexl_mullo_epi(v_a_op, v_b_op); + + // Convert to 63 bits to save intermediate carry + if (BitShift == 64) { + v_T_hi = _mm512_slli_epi64(v_T_hi, 1); + __m512i tmp = _mm512_srli_epi64(v_T_lo, 63); + v_T_hi = _mm512_add_epi64(v_T_hi, tmp); + v_T_lo = _mm512_and_epi64(v_T_lo, v_prod_rs); + } + + __m512i v_c = _mm512_hexl_montgomery_reduce( + v_T_hi, v_T_lo, v_modulus, v_neg_inv_mod, v_prod_rs); + HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_c); + ++v_a; + ++v_b; + ++v_result; + } +} + +/// @brief Returns Montgomery form of a mod q, computed via the REDC algorithm, +/// also known as Montgomery reduction. +/// @tparam BitShift denotes the operational length, in bits, of the operands +/// and result values. +/// @tparam r defines the value of R, being R = 2^r. R > modulus. +/// @param[in] a input vector. T = a(R^2 mod q) in the range [0, Rq − 1]. +/// @param[in] R2_mod_q R^2 mod q. +/// @param[in] modulus such that gcd(R, modulus) = 1. +/// @param[in] neg_inv_mod in [0, R − 1] such that q*neg_inv_mod ≡ −1 mod R, +/// @param[in] n number of elements in input vector. 
+/// @param[out] result unsigned long int vector in the range [0, q − 1] such +/// that S ≡ TR^−1 mod q +template +void EltwiseMontgomeryFormInAVX512(uint64_t* result, const uint64_t* a, + uint64_t R2_mod_q, uint64_t n, + uint64_t modulus, uint64_t neg_inv_mod) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(a != nullptr, "Require operand a != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + + uint64_t R = (1ULL << r); + HEXL_CHECK(std::gcd(modulus, R) == 1, "gcd(modulus, R) != 1"); + HEXL_CHECK(R > modulus, "Needs R bigger than q."); + + // mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones + uint64_t mod_R_mask = R - 1; + uint64_t prod_rs; + if (BitShift == 64) { + HEXL_CHECK(r <= 62, "With r > 62 internal ops might overflow"); + prod_rs = (1ULL << 63) - 1; + } else { + prod_rs = (1ULL << (52 - r)); + } + uint64_t n_tmp = n; + + // Deals with n not divisible by 8 + uint64_t n_mod_8 = n_tmp % 8; + if (n_mod_8 != 0) { + for (size_t i = 0; i < n_mod_8; ++i) { + uint64_t T_hi; + uint64_t T_lo; + MultiplyUInt64(a[i], R2_mod_q, &T_hi, &T_lo); + result[i] = MontgomeryReduce(T_hi, T_lo, modulus, r, mod_R_mask, + neg_inv_mod); + } + a += n_mod_8; + result += n_mod_8; + n_tmp -= n_mod_8; + } + + const __m512i* v_a = reinterpret_cast(a); + __m512i* v_result = reinterpret_cast<__m512i*>(result); + __m512i v_b = _mm512_set1_epi64(R2_mod_q); + __m512i v_modulus = _mm512_set1_epi64(modulus); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); + __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); + + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_a_op = _mm512_loadu_si512(v_a); + __m512i v_T_hi = _mm512_hexl_mulhi_epi(v_a_op, v_b); + __m512i v_T_lo = _mm512_hexl_mullo_epi(v_a_op, v_b); + + // Convert to 63 bits to save intermediate carry + if (BitShift == 64) { + v_T_hi = _mm512_slli_epi64(v_T_hi, 1); + __m512i tmp = _mm512_srli_epi64(v_T_lo, 63); + v_T_hi = _mm512_add_epi64(v_T_hi, tmp); + v_T_lo = _mm512_and_epi64(v_T_lo, v_prod_rs); + } + + __m512i v_c = _mm512_hexl_montgomery_reduce( + v_T_hi, v_T_lo, v_modulus, v_neg_inv_mod, v_prod_rs); + HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_c); + ++v_a; + ++v_result; + } +} + +/// @brief Convert out of the Montgomery Form computed via the REDC algorithm, +/// also known as Montgomery reduction. +/// @tparam BitShift denotes the operational length, in bits, of the operands +/// and result values. +/// @tparam r defines the value of R, being R = 2^r. R > modulus. +/// @param[in] a input vector in Montgomery Form. +/// @param[in] modulus such that gcd(R, modulus) = 1. +/// @param[in] neg_inv_mod in [0, R − 1] such that q*neg_inv_mod ≡ −1 mod R, +/// @param[in] n number of elements in input vector. 
+/// @param[out] result unsigned long int vector in the range [0, q − 1] such +/// that S ≡ TR^−1 mod q +template +void EltwiseMontgomeryFormOutAVX512(uint64_t* result, const uint64_t* a, + uint64_t n, uint64_t modulus, + uint64_t neg_inv_mod) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(a != nullptr, "Require operand a != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + + uint64_t R = (1ULL << r); + HEXL_CHECK(std::gcd(modulus, R) == 1, "gcd(modulus, R) != 1"); + HEXL_CHECK(R > modulus, "Needs R bigger than q."); + + // mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones + uint64_t mod_R_mask = R - 1; + uint64_t prod_rs; + if (BitShift == 64) { + HEXL_CHECK(r <= 62, "With r > 62 internal ops might overflow"); + prod_rs = (1ULL << 63) - 1; + } else { + prod_rs = (1ULL << (52 - r)); + } + uint64_t n_tmp = n; + + // Deals with n not divisible by 8 + uint64_t n_mod_8 = n_tmp % 8; + if (n_mod_8 != 0) { + for (size_t i = 0; i < n_mod_8; ++i) { + result[i] = MontgomeryReduce(0, a[i], modulus, r, mod_R_mask, + neg_inv_mod); + } + a += n_mod_8; + result += n_mod_8; + n_tmp -= n_mod_8; + } + + const __m512i* v_a = reinterpret_cast(a); + __m512i* v_result = reinterpret_cast<__m512i*>(result); + __m512i v_modulus = _mm512_set1_epi64(modulus); + __m512i v_neg_inv_mod = _mm512_set1_epi64(neg_inv_mod); + __m512i v_prod_rs = _mm512_set1_epi64(prod_rs); + __m512i v_T_hi = _mm512_set1_epi64(0); + + for (size_t i = 0; i < n_tmp; i += 8) { + __m512i v_T_lo = _mm512_loadu_si512(v_a); + __m512i v_c = _mm512_hexl_montgomery_reduce( + v_T_hi, v_T_lo, v_modulus, v_neg_inv_mod, v_prod_rs); + HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus, + "v_op exceeds bound " << modulus); + _mm512_storeu_si512(v_result, v_c); + ++v_a; + ++v_result; + } +} + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-reduce-mod-internal.hpp b/hexl_ser/eltwise/eltwise-reduce-mod-internal.hpp new file mode 100644 index 00000000..ce50f5e8 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-reduce-mod-internal.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +// @brief Performs elementwise modular reduction +// @param[out] result Stores result +// @param[in] operand Vector of elements +// @param[in] n Number of elements in operand +// @param[in] modulus Modulus with which to perform modular reduction +// @param[in] input_mod_factor Assumes input elements are in [0, +// input_mod_factor * p) Must be modulus, 2 or 4. input_mod_factor=modulus +// means, input range is [0, p * p]. Barrett reduction will be used in this case +// input_mod_factor > output_mod_factor +// @param[in] output_mod_factor output elements will be in [0, output_mod_factor +// * p) Must be 1 or 2. for input_mod_factor=0, output_mod_factor will be set +// to 1. 
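+// Worked example (illustrative only): with modulus p = 769,
+// input_mod_factor = 4 and output_mod_factor = 1, an input element
+// 3000 (< 4 * p = 3076) is reduced to 3000 mod 769 = 693; with
+// output_mod_factor = 2 it only needs to drop below 2 * p = 1538,
+// so 3000 maps to 3000 - 1538 = 1462.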
+void EltwiseReduceModNative(uint64_t* result, const uint64_t* operand, + uint64_t n, uint64_t modulus, + uint64_t input_mod_factor, + uint64_t output_mod_factor); +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-reduce-mod.cpp b/hexl_ser/eltwise/eltwise-reduce-mod.cpp new file mode 100644 index 00000000..accfe938 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-reduce-mod.cpp @@ -0,0 +1,125 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/eltwise/eltwise-reduce-mod.hpp" + +#include "eltwise/eltwise-reduce-mod-avx512.hpp" +#include "eltwise/eltwise-reduce-mod-internal.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void EltwiseReduceModNative(uint64_t* result, const uint64_t* operand, + uint64_t n, uint64_t modulus, + uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(operand != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(input_mod_factor == modulus || input_mod_factor == 2 || + input_mod_factor == 4, + "input_mod_factor must be modulus or 2 or 4" << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2 " << output_mod_factor); + HEXL_CHECK(input_mod_factor != output_mod_factor, + "input_mod_factor must not be equal to output_mod_factor "); + + uint64_t barrett_factor = MultiplyFactor(1, 64, modulus).BarrettFactor(); + + uint64_t twice_modulus = modulus << 1; + if (input_mod_factor == modulus) { + if (output_mod_factor == 2) { + for (size_t i = 0; i < n; ++i) { + if (operand[i] >= modulus) { + result[i] = BarrettReduce64<2>(operand[i], modulus, barrett_factor); + } else { + result[i] = operand[i]; + } + } + } else { + for (size_t i = 0; i < n; ++i) { + if (operand[i] >= modulus) { + result[i] = BarrettReduce64<1>(operand[i], modulus, barrett_factor); + } else { + result[i] = operand[i]; + } + } + + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); + } + } + + if (input_mod_factor == 2) { + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<2>(operand[i], modulus); + } + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); + } + + if (input_mod_factor == 4) { + if (output_mod_factor == 1) { + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<4>(operand[i], modulus, &twice_modulus); + } + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); + } + if (output_mod_factor == 2) { + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<2>(operand[i], twice_modulus); + } + HEXL_CHECK_BOUNDS(result, n, twice_modulus, + "result exceeds bound " << twice_modulus); + } + } +} + +void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, + uint64_t modulus, uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(operand != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(input_mod_factor == modulus || input_mod_factor == 2 || + input_mod_factor == 4, + "input_mod_factor must be modulus or 2 or 4" << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || 
output_mod_factor == 2, + "output_mod_factor must be 1 or 2 " << output_mod_factor); + + if (input_mod_factor == output_mod_factor && (operand != result)) { + for (size_t i = 0; i < n; ++i) { + result[i] = operand[i]; + } + return; + } + +#ifdef HEXL_HAS_AVX512IFMA + // Modulus can be 52 bits only if input mod factors <= 4 + // otherwise modulus should be 51 bits max to give correct results + if ((has_avx512ifma && modulus < (1ULL << 51)) || + (modulus < (1ULL << 52) && input_mod_factor <= 4)) { + EltwiseReduceModAVX512<52>(result, operand, n, modulus, input_mod_factor, + output_mod_factor); + return; + } +#endif + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseReduceModAVX512<64>(result, operand, n, modulus, input_mod_factor, + output_mod_factor); + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseReduceModNative"); + EltwiseReduceModNative(result, operand, n, modulus, input_mod_factor, + output_mod_factor); +} +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-sub-mod-avx512.cpp b/hexl_ser/eltwise/eltwise-sub-mod-avx512.cpp new file mode 100644 index 00000000..2039c917 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-sub-mod-avx512.cpp @@ -0,0 +1,108 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-sub-mod-avx512.hpp" + +#include +#include + +#include "eltwise/eltwise-sub-mod-internal.hpp" +#include "hexl/eltwise/eltwise-sub-mod.hpp" +#include "hexl/util/check.hpp" +#include "util/avx512-util.hpp" + +#ifdef HEXL_HAS_AVX512DQ + +namespace intel { +namespace hexl { + +void EltwiseSubModAVX512(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-sub value in operand1 exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(operand2, n, modulus, + "pre-sub value in operand2 exceeds bound " << modulus); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseSubModNative(result, operand1, operand2, n_mod_8, modulus); + operand1 += n_mod_8; + operand2 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + const __m512i* vp_operand1 = reinterpret_cast(operand1); + const __m512i* vp_operand2 = reinterpret_cast(operand2); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_operand1 = _mm512_loadu_si512(vp_operand1); + __m512i v_operand2 = _mm512_loadu_si512(vp_operand2); + + __m512i v_result = + _mm512_hexl_small_sub_mod_epi64(v_operand1, v_operand2, v_modulus); + + _mm512_storeu_si512(vp_result, v_result); + + ++vp_result; + ++vp_operand1; + ++vp_operand2; + } + + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +void EltwiseSubModAVX512(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 
63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-sub value in operand1 exceeds bound " << modulus); + HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + + uint64_t n_mod_8 = n % 8; + if (n_mod_8 != 0) { + EltwiseSubModNative(result, operand1, operand2, n_mod_8, modulus); + operand1 += n_mod_8; + result += n_mod_8; + n -= n_mod_8; + } + + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i* vp_result = reinterpret_cast<__m512i*>(result); + const __m512i* vp_operand1 = reinterpret_cast(operand1); + __m512i v_operand2 = _mm512_set1_epi64(static_cast(operand2)); + + HEXL_LOOP_UNROLL_4 + for (size_t i = n / 8; i > 0; --i) { + __m512i v_operand1 = _mm512_loadu_si512(vp_operand1); + + __m512i v_result = + _mm512_hexl_small_sub_mod_epi64(v_operand1, v_operand2, v_modulus); + + _mm512_storeu_si512(vp_result, v_result); + + ++vp_result; + ++vp_operand1; + } + + HEXL_CHECK_BOUNDS(result, n, modulus, "result exceeds bound " << modulus); +} + +} // namespace hexl +} // namespace intel + +#endif diff --git a/hexl_ser/eltwise/eltwise-sub-mod-avx512.hpp b/hexl_ser/eltwise/eltwise-sub-mod-avx512.hpp new file mode 100644 index 00000000..eab9772e --- /dev/null +++ b/hexl_ser/eltwise/eltwise-sub-mod-avx512.hpp @@ -0,0 +1,19 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +void EltwiseSubModAVX512(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +void EltwiseSubModAVX512(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-sub-mod-internal.hpp b/hexl_ser/eltwise/eltwise-sub-mod-internal.hpp new file mode 100644 index 00000000..7c05dfe9 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-sub-mod-internal.hpp @@ -0,0 +1,36 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +namespace intel { +namespace hexl { + +/// @brief Subtracts two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from +/// @param[in] operand2 Vector of elements to subtract +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseSubModNative(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus); + +/// @brief Subtracts a scalar from a vector elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. 
+void EltwiseSubModNative(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/eltwise/eltwise-sub-mod.cpp b/hexl_ser/eltwise/eltwise-sub-mod.cpp new file mode 100644 index 00000000..45ad1472 --- /dev/null +++ b/hexl_ser/eltwise/eltwise-sub-mod.cpp @@ -0,0 +1,112 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "eltwise/eltwise-sub-mod-avx512.hpp" +#include "eltwise/eltwise-sub-mod-internal.hpp" +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void EltwiseSubModNative(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-sub value in operand1 exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(operand2, n, modulus, + "pre-sub value in operand2 exceeds bound " << modulus); + + // HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < n; ++i) { + if (*operand1 >= *operand2) { + *result = *operand1 - *operand2; + } else { + *result = *operand1 + modulus - *operand2; + } + + ++operand1; + ++operand2; + ++result; + } +} + +void EltwiseSubModNative(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-sub value in operand1 exceeds bound " << modulus); + HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + + // HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < n; ++i) { + if (*operand1 >= operand2) { + *result = *operand1 - operand2; + } else { + *result = *operand1 + modulus - operand2; + } + + ++operand1; + ++result; + } +} + +void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-sub value in operand1 exceeds bound " << modulus); + HEXL_CHECK_BOUNDS(operand2, n, modulus, + "pre-sub value in operand2 exceeds bound " << modulus); + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseSubModAVX512(result, operand1, operand2, n, modulus); + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseSubModNative"); + EltwiseSubModNative(result, operand1, operand2, n, modulus); +} + +void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus) { + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + 
HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(modulus > 1, "Require modulus > 1"); + HEXL_CHECK(modulus < (1ULL << 63), "Require modulus < 2**63"); + HEXL_CHECK_BOUNDS(operand1, n, modulus, + "pre-sub value in operand1 exceeds bound " << modulus); + HEXL_CHECK(operand2 < modulus, "Require operand2 < modulus"); + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq) { + EltwiseSubModAVX512(result, operand1, operand2, n, modulus); + return; + } +#endif + + HEXL_VLOG(3, "Calling EltwiseSubModNative"); + EltwiseSubModNative(result, operand1, operand2, n, modulus); +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/experimental/fft-like/fft-like-native.cpp b/hexl_ser/experimental/fft-like/fft-like-native.cpp new file mode 100644 index 00000000..0c84699d --- /dev/null +++ b/hexl_ser/experimental/fft-like/fft-like-native.cpp @@ -0,0 +1,423 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/experimental/fft-like/fft-like-native.hpp" + +#include + +#include "hexl/logging/logging.hpp" +#include "util/util-internal.hpp" + +namespace intel { +namespace hexl { + +inline void ComplexFwdButterflyRadix2(std::complex* X_r, + std::complex* Y_r, + const std::complex* X_op, + const std::complex* Y_op, + const std::complex W) { + HEXL_VLOG(5, "ComplexFwdButterflyRadix2"); + HEXL_VLOG(5, "Inputs: X_op " << *X_op << ", Y_op " << *Y_op << ", W " << W); + std::complex U = *X_op; + std::complex V = *Y_op * W; + *X_r = U + V; + *Y_r = U - V; + HEXL_VLOG(5, "Output X " << *X_r << ", Y " << *Y_r); +} + +inline void ComplexInvButterflyRadix2(std::complex* X_r, + std::complex* Y_r, + const std::complex* X_op, + const std::complex* Y_op, + const std::complex W) { + HEXL_VLOG(5, "ComplexInvButterflyRadix2"); + HEXL_VLOG(5, "Inputs: X_op " << *X_op << ", Y_op " << *Y_op << ", W " << W); + std::complex U = *X_op; + *X_r = U + *Y_op; + *Y_r = (U - *Y_op) * W; + HEXL_VLOG(5, "Output X " << *X_r << ", Y " << *Y_r); +} + +inline void ScaledComplexInvButterflyRadix2(std::complex* X_r, + std::complex* Y_r, + const std::complex* X_op, + const std::complex* Y_op, + const std::complex W, + const double* scalar) { + HEXL_VLOG(5, "ScaledComplexInvButterflyRadix2"); + HEXL_VLOG(5, "Inputs: X_op " << *X_op << ", Y_op " << *Y_op << ", W " << W + << ", scalar " << *scalar); + std::complex U = *X_op; + *X_r = (U + *Y_op) * (*scalar); + *Y_r = (U - *Y_op) * W; + HEXL_VLOG(5, "Output X " << *X_r << ", Y " << *Y_r); +} + +void Forward_FFTLike_ToBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* root_of_unity_powers, const uint64_t n, + const double* scalar) { + HEXL_CHECK(IsPowerOfTwo(n), "degree " << n << " is not a power of 2"); + HEXL_CHECK(root_of_unity_powers != nullptr, + "root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + HEXL_CHECK(result != nullptr, "result == nullptr"); + + size_t gap = (n >> 1); + + // In case of out-of-place operation do first pass and convert to in-place + { + const std::complex W = root_of_unity_powers[1]; + std::complex* X_r = result; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = operand; + const std::complex* Y_op = X_op + gap; + + // First pass for out-of-order case + switch (gap) { + case 8: { + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + 
ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + break; + } + case 4: { + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + break; + } + case 2: { + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + break; + } + case 1: { + std::complex scaled_W = W; + if (scalar != nullptr) scaled_W = W * *scalar; + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + break; + } + default: { + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < gap; j += 8) { + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + } + } + } + gap >>= 1; + } + + // Continue with in-place operation + for (size_t m = 2; m < n; m <<= 1) { + size_t offset = 0; + switch (gap) { + case 8: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = root_of_unity_powers[m + i]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + case 4: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = root_of_unity_powers[m + i]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + case 2: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = root_of_unity_powers[m + i]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + case 1: { + if (scalar == nullptr) { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = root_of_unity_powers[m + i]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = 
X_r; + const std::complex* Y_op = Y_r; + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + } else { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = + *scalar * root_of_unity_powers[m + i]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + *X_r = (*scalar) * (*X_r); + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexFwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + } + break; + } + default: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = root_of_unity_powers[m + i]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < gap; j += 8) { + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexFwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + } + } + } + } + gap >>= 1; + } +} + +void Inverse_FFTLike_FromBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* inv_root_of_unity_powers, const uint64_t n, + const double* scalar) { + HEXL_CHECK(IsPowerOfTwo(n), "degree " << n << " is not a power of 2"); + HEXL_CHECK(inv_root_of_unity_powers != nullptr, + "inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + HEXL_CHECK(result != nullptr, "result == nullptr"); + + uint64_t n_div_2 = (n >> 1); + size_t gap = 1; + size_t root_index = 1; + + size_t stop_loop = (scalar == nullptr) ? 
0 : 1; + size_t m = n_div_2; + for (; m > stop_loop; m >>= 1) { + size_t offset = 0; + + switch (gap) { + case 1: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = inv_root_of_unity_powers[root_index]; + + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = operand + offset; + const std::complex* Y_op = X_op + gap; + ComplexInvButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + case 2: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = inv_root_of_unity_powers[root_index]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + case 4: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = inv_root_of_unity_powers[root_index]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + case 8: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = inv_root_of_unity_powers[root_index]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r, Y_r, X_op, Y_op, W); + } + break; + } + default: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (gap << 1); + } + const std::complex W = inv_root_of_unity_powers[root_index]; + std::complex* X_r = result + offset; + std::complex* Y_r = X_r + gap; + const std::complex* X_op = X_r; + const std::complex* Y_op = Y_r; + + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < gap; j += 8) { + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + ComplexInvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W); + } + } + } + } + gap <<= 1; + } + + if (m > 0) { + const std::complex W = + *scalar * inv_root_of_unity_powers[root_index]; + std::complex* X_r = result; + std::complex* Y_r = X_r + gap; + const std::complex* X_o = X_r; + const std::complex* Y_o = Y_r; + + switch (gap) { + case 1: { + ScaledComplexInvButterflyRadix2(X_r, Y_r, X_o, Y_o, W, scalar); + break; + 
} + case 2: { + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r, Y_r, X_o, Y_o, W, scalar); + break; + } + case 4: { + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r, Y_r, X_o, Y_o, W, scalar); + break; + } + case 8: { + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, scalar); + ScaledComplexInvButterflyRadix2(X_r, Y_r, X_o, Y_o, W, scalar); + break; + } + default: { + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < gap; j += 8) { + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + ScaledComplexInvButterflyRadix2(X_r++, Y_r++, X_o++, Y_o++, W, + scalar); + } + } + } + } + + // When M is too short it only needs the final stage butterfly. Copying here + // in the case of out-of-place. 
+ if (result != operand && n == 2) { + std::memcpy(result, operand, n * sizeof(std::complex)); + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/experimental/fft-like/fft-like.cpp b/hexl_ser/experimental/fft-like/fft-like.cpp new file mode 100644 index 00000000..b6a7759d --- /dev/null +++ b/hexl_ser/experimental/fft-like/fft-like.cpp @@ -0,0 +1,161 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/experimental/fft-like/fft-like.hpp" + +#include "hexl/experimental/fft-like/fft-like-native.hpp" +#include "hexl/logging/logging.hpp" + +namespace intel { +namespace hexl { + +FFTLike::FFTLike(uint64_t degree, double* in_scalar, + std::shared_ptr alloc_ptr) + : m_degree(degree), + scalar(in_scalar), + m_alloc(alloc_ptr), + m_aligned_alloc(AlignedAllocator(m_alloc)), + m_complex_roots_of_unity(m_aligned_alloc) { + HEXL_CHECK(IsPowerOfTwo(degree), + "degree " << degree << " is not a power of 2"); + HEXL_CHECK(degree > 8, "degree should be bigger than 8"); + + m_degree_bits = Log2(m_degree); + ComputeComplexRootsOfUnity(); + + if (scalar != nullptr) { + scale = *scalar / static_cast(degree); + inv_scale = 1.0 / *scalar; + } +} + +inline std::complex swap_real_imag(std::complex c) { + return std::complex(c.imag(), c.real()); +} + +void FFTLike::ComputeComplexRootsOfUnity() { + AlignedVector64> roots_of_unity(m_degree, 0, + m_aligned_alloc); + AlignedVector64> roots_in_bit_reverse(m_degree, 0, + m_aligned_alloc); + AlignedVector64> inv_roots_in_bit_reverse( + m_degree, 0, m_aligned_alloc); + uint64_t roots_degree = static_cast(m_degree) << 1; // degree > 2 + + // PI value used to calculate the roots of unity + static constexpr double PI_ = 3.1415926535897932384626433832795028842; + + // Generate 1/8 of all roots first. 
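+  // The remaining roots follow from symmetries of e^{2*pi*i*k/roots_degree}:
+  // swapping real and imaginary parts reflects the angle about pi/4, and
+  // -conj() reflects it about pi/2, so only the first octant needs to be
+  // computed with std::polar.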
+ size_t i = 0; + for (; i <= roots_degree / 8; i++) { + roots_of_unity[i] = + std::polar(1.0, 2 * PI_ * static_cast(i) / + static_cast(roots_degree)); + } + // Complete first 4th + for (; i <= roots_degree / 4; i++) { + roots_of_unity[i] = swap_real_imag(roots_of_unity[roots_degree / 4 - i]); + } + // Get second 4th + for (; i < roots_degree / 2; i++) { + roots_of_unity[i] = -std::conj(roots_of_unity[roots_degree / 2 - i]); + } + // Put in bit reverse and get inv roots + for (i = 1; i < m_degree; i++) { + roots_in_bit_reverse[i] = roots_of_unity[ReverseBits(i, m_degree_bits)]; + inv_roots_in_bit_reverse[i] = + std::conj(roots_of_unity[ReverseBits(i - 1, m_degree_bits) + 1]); + } + m_complex_roots_of_unity = roots_in_bit_reverse; + m_inv_complex_roots_of_unity = inv_roots_in_bit_reverse; +} + +void FFTLike::ComputeForwardFFTLike(std::complex* result, + const std::complex* operand, + const double* in_scale) { + HEXL_CHECK(result != nullptr, "result == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + + const double* out_scale = nullptr; + if (scalar != nullptr) { + out_scale = &inv_scale; + } else if (in_scale != nullptr) { + out_scale = in_scale; + } + +#ifdef HEXL_HAS_AVX512DQ + HEXL_VLOG(3, "Calling 64-bit AVX512-DQ FwdFFTLike"); + + Forward_FFTLike_ToBitReverseAVX512( + &(reinterpret_cast(result[0]))[0], + &(reinterpret_cast(operand[0]))[0], + &(reinterpret_cast(m_complex_roots_of_unity[0]))[0], + m_degree, out_scale); + return; +#else + HEXL_VLOG(3, "Calling Native FwdFFTLike"); + Forward_FFTLike_ToBitReverseRadix2( + result, operand, m_complex_roots_of_unity.data(), m_degree, out_scale); + return; +#endif +} + +void FFTLike::ComputeInverseFFTLike(std::complex* result, + const std::complex* operand, + const double* in_scale) { + HEXL_CHECK(result != nullptr, "result==nullptr"); + HEXL_CHECK(operand != nullptr, "operand==nullptr"); + + const double* out_scale = nullptr; + if (scalar != nullptr) { + out_scale = &scale; + } else if (in_scale != nullptr) { + out_scale = in_scale; + } + +#ifdef HEXL_HAS_AVX512DQ + HEXL_VLOG(3, "Calling 64-bit AVX512-DQ InvFFTLike"); + + Inverse_FFTLike_FromBitReverseAVX512( + &(reinterpret_cast(result[0]))[0], + &(reinterpret_cast(operand[0]))[0], + &(reinterpret_cast( + m_inv_complex_roots_of_unity[0]))[0], + m_degree, out_scale); + + return; +#else + HEXL_VLOG(3, "Calling Native InvFFTLike"); + Inverse_FFTLike_FromBitReverseRadix2(result, operand, + m_inv_complex_roots_of_unity.data(), + m_degree, out_scale); + return; +#endif +} + +void FFTLike::BuildFloatingPoints(std::complex* res, + const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double in_inv_scale, size_t mod_size, + size_t coeff_count) { + HEXL_UNUSED(res); + HEXL_UNUSED(plain); + HEXL_UNUSED(threshold); + HEXL_UNUSED(decryption_modulus); + HEXL_UNUSED(in_inv_scale); + HEXL_UNUSED(mod_size); + HEXL_UNUSED(coeff_count); + +#ifdef HEXL_HAS_AVX512DQ + HEXL_VLOG(3, "Calling 64-bit AVX512-DQ BuildFloatingPoints"); + + BuildFloatingPointsAVX512(&(reinterpret_cast(res[0]))[0], plain, + threshold, decryption_modulus, in_inv_scale, + mod_size, coeff_count); + return; +#endif +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/experimental/fft-like/fwd-fft-like-avx512.cpp b/hexl_ser/experimental/fft-like/fwd-fft-like-avx512.cpp new file mode 100644 index 00000000..50db1147 --- /dev/null +++ b/hexl_ser/experimental/fft-like/fwd-fft-like-avx512.cpp @@ -0,0 +1,482 @@ +// Copyright (C) 2022 Intel Corporation +// 
SPDX-License-Identifier: Apache-2.0 + +#include "hexl/experimental/fft-like/fwd-fft-like-avx512.hpp" + +#include "hexl/experimental/fft-like/fft-like-avx512-util.hpp" +#include "hexl/logging/logging.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief Final butterfly step for the Forward FFT like. +/// @param[in,out] X_real Double precision (DP) values in SIMD form representing +/// the real part of 8 complex numbers. +/// @param[in,out] X_imag DP values in SIMD form representing the +/// imaginary part of the forementioned complex numbers. +/// @param[in,out] Y_real DP values in SIMD form representing the +/// real part of 8 complex numbers. +/// @param[in,out] Y_imag DP values in SIMD form representing the +/// imaginary part of the forementioned complex numbers. +/// @param[in] W_real DP values in SIMD form representing the real part of the +/// Complex Roots of unity. +/// @param[in] W_imag DP values in SIMD form representing the imaginary part of +/// the Complex Roots of unity. +void ComplexFwdButterfly(__m512d* X_real, __m512d* X_imag, __m512d* Y_real, + __m512d* Y_imag, __m512d W_real, __m512d W_imag) { + // U = X + __m512d U_real = *X_real; + __m512d U_imag = *X_imag; + + // V = Y*W. Complex multiplication: + // (y_r + iy_b)*(w_a + iw_b) = (y_a*w_a - y_b*w_b) + i(y_a*w_b + y_b*w_a) + __m512d V_real = _mm512_mul_pd(*Y_real, W_real); + __m512d tmp = _mm512_mul_pd(*Y_imag, W_imag); + V_real = _mm512_sub_pd(V_real, tmp); + + __m512d V_imag = _mm512_mul_pd(*Y_real, W_imag); + tmp = _mm512_mul_pd(*Y_imag, W_real); + V_imag = _mm512_add_pd(V_imag, tmp); + + // X = U + V + *X_real = _mm512_add_pd(U_real, V_real); + *X_imag = _mm512_add_pd(U_imag, V_imag); + // Y = U - V + *Y_real = _mm512_sub_pd(U_real, V_real); + *Y_imag = _mm512_sub_pd(U_imag, V_imag); +} + +// Takes operand as 8 complex interleaved: This is 8 real parts followed by +// its 8 imaginary parts. +// Returns operand as 1 complex interleaved: One real part followed by its +// imaginary part. 
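+// Layout of one group of 8 complex values in each convention:
+//   8 complex interleaved: r0 r1 r2 r3 r4 r5 r6 r7 i0 i1 i2 i3 i4 i5 i6 i7
+//   1 complex interleaved: r0 i0 r1 i1 r2 i2 r3 i3 r4 i4 r5 i5 r6 i6 r7 i7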
+void ComplexFwdT1(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t m, const double* scalar = nullptr) { + size_t offset = 0; + + __m512d v_scalar; + if (scalar != nullptr) { + v_scalar = _mm512_set1_pd(*scalar); + } + + // 8 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_8 + for (size_t i = 0; i < (m >> 1); i += 8) { + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + __m512d* v_out_pt = reinterpret_cast<__m512d*>(X_real); + + __m512d v_X_real; + __m512d v_X_imag; + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadFwdInterleavedT1(X_real, &v_X_real, &v_Y_real); + ComplexLoadFwdInterleavedT1(X_imag, &v_X_imag, &v_Y_imag); + + // Weights + __m512d v_W_real = _mm512_set_pd( + W_1C_intrlvd[14], W_1C_intrlvd[12], W_1C_intrlvd[10], W_1C_intrlvd[8], + W_1C_intrlvd[6], W_1C_intrlvd[4], W_1C_intrlvd[2], W_1C_intrlvd[0]); + __m512d v_W_imag = _mm512_set_pd( + W_1C_intrlvd[15], W_1C_intrlvd[13], W_1C_intrlvd[11], W_1C_intrlvd[9], + W_1C_intrlvd[7], W_1C_intrlvd[5], W_1C_intrlvd[3], W_1C_intrlvd[1]); + W_1C_intrlvd += 16; + + if (scalar != nullptr) { + v_W_real = _mm512_mul_pd(v_W_real, v_scalar); + v_W_imag = _mm512_mul_pd(v_W_imag, v_scalar); + v_X_real = _mm512_mul_pd(v_X_real, v_scalar); + v_X_imag = _mm512_mul_pd(v_X_imag, v_scalar); + } + + ComplexFwdButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + ComplexWriteFwdInterleavedT1(v_X_real, v_Y_real, v_X_imag, v_Y_imag, + v_out_pt); + + offset += 32; + } +} + +void ComplexFwdT2(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t m) { + size_t offset = 0; + + // 4 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i += 4) { + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d v_X_real; + __m512d v_X_imag; + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadFwdInterleavedT2(X_real, &v_X_real, &v_Y_real); + ComplexLoadFwdInterleavedT2(X_imag, &v_X_imag, &v_Y_imag); + + // Weights + __m512d v_W_real = _mm512_set_pd( + W_1C_intrlvd[6], W_1C_intrlvd[6], W_1C_intrlvd[4], W_1C_intrlvd[4], + W_1C_intrlvd[2], W_1C_intrlvd[2], W_1C_intrlvd[0], W_1C_intrlvd[0]); + __m512d v_W_imag = _mm512_set_pd( + W_1C_intrlvd[7], W_1C_intrlvd[7], W_1C_intrlvd[5], W_1C_intrlvd[5], + W_1C_intrlvd[3], W_1C_intrlvd[3], W_1C_intrlvd[1], W_1C_intrlvd[1]); + W_1C_intrlvd += 8; + + ComplexFwdButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_pt_real, v_X_real); + _mm512_storeu_pd(v_X_pt_imag, v_X_imag); + v_X_pt_real += 2; + v_X_pt_imag += 2; + _mm512_storeu_pd(v_X_pt_real, v_Y_real); + _mm512_storeu_pd(v_X_pt_imag, v_Y_imag); + + offset += 32; + } +} + +void ComplexFwdT4(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t m) { + size_t offset = 0; + + // 2 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i += 2) { + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d v_X_real; + __m512d v_X_imag; + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadFwdInterleavedT4(X_real, &v_X_real, &v_Y_real); + ComplexLoadFwdInterleavedT4(X_imag, &v_X_imag, &v_Y_imag); + + // Weights + // x = (11, 10, 9, 8, 3, 2, 1, 0) + // y = (15, 14, 13, 12, 7, 6, 
5, 4) + __m512d v_W_real = _mm512_set_pd( + W_1C_intrlvd[2], W_1C_intrlvd[2], W_1C_intrlvd[2], W_1C_intrlvd[2], + W_1C_intrlvd[0], W_1C_intrlvd[0], W_1C_intrlvd[0], W_1C_intrlvd[0]); + __m512d v_W_imag = _mm512_set_pd( + W_1C_intrlvd[3], W_1C_intrlvd[3], W_1C_intrlvd[3], W_1C_intrlvd[3], + W_1C_intrlvd[1], W_1C_intrlvd[1], W_1C_intrlvd[1], W_1C_intrlvd[1]); + + W_1C_intrlvd += 4; + + ComplexFwdButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_pt_real, v_X_real); + _mm512_storeu_pd(v_X_pt_imag, v_X_imag); + v_X_pt_real += 2; + v_X_pt_imag += 2; + _mm512_storeu_pd(v_X_pt_real, v_Y_real); + _mm512_storeu_pd(v_X_pt_imag, v_Y_imag); + + offset += 32; + } +} + +void ComplexFwdT8(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t gap, uint64_t m) { + size_t offset = 0; + + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i++) { + // Referencing operand + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + double* Y_real = X_real + gap; + double* Y_imag = X_imag + gap; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d* v_Y_pt_real = reinterpret_cast<__m512d*>(Y_real); + __m512d* v_Y_pt_imag = reinterpret_cast<__m512d*>(Y_imag); + + // Weights + __m512d v_W_real = _mm512_set1_pd(*W_1C_intrlvd++); + __m512d v_W_imag = _mm512_set1_pd(*W_1C_intrlvd++); + + // assume 8 | t + for (size_t j = 0; j < gap; j += 16) { + __m512d v_X_real = _mm512_loadu_pd(v_X_pt_real); + __m512d v_X_imag = _mm512_loadu_pd(v_X_pt_imag); + + __m512d v_Y_real = _mm512_loadu_pd(v_Y_pt_real); + __m512d v_Y_imag = _mm512_loadu_pd(v_Y_pt_imag); + + ComplexFwdButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_pt_real, v_X_real); + _mm512_storeu_pd(v_X_pt_imag, v_X_imag); + + _mm512_storeu_pd(v_Y_pt_real, v_Y_real); + _mm512_storeu_pd(v_Y_pt_imag, v_Y_imag); + + // Increase pointers + v_X_pt_real += 2; + v_X_pt_imag += 2; + v_Y_pt_real += 2; + v_Y_pt_imag += 2; + } + offset += (gap << 1); + } +} + +void ComplexStartFwdT8(double* result_8C_intrlvd, + const double* operand_1C_intrlvd, + const double* W_1C_intrlvd, uint64_t gap, uint64_t m) { + size_t offset = 0; + + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i++) { + // Referencing operand + const double* X_op = operand_1C_intrlvd + offset; + const double* Y_op = X_op + gap; + const __m512d* v_X_op_pt = reinterpret_cast(X_op); + const __m512d* v_Y_op_pt = reinterpret_cast(Y_op); + + // Referencing result + double* X_r_real = result_8C_intrlvd + offset; + double* X_r_imag = X_r_real + 8; + double* Y_r_real = X_r_real + gap; + double* Y_r_imag = X_r_imag + gap; + __m512d* v_X_r_pt_real = reinterpret_cast<__m512d*>(X_r_real); + __m512d* v_X_r_pt_imag = reinterpret_cast<__m512d*>(X_r_imag); + __m512d* v_Y_r_pt_real = reinterpret_cast<__m512d*>(Y_r_real); + __m512d* v_Y_r_pt_imag = reinterpret_cast<__m512d*>(Y_r_imag); + + // Weights + __m512d v_W_real = _mm512_set1_pd(*W_1C_intrlvd++); + __m512d v_W_imag = _mm512_set1_pd(*W_1C_intrlvd++); + + // assume 8 | t + for (size_t j = 0; j < gap; j += 16) { + __m512d v_X_real; + __m512d v_X_imag; + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadFwdInterleavedT8(v_X_op_pt, v_Y_op_pt, &v_X_real, &v_X_imag, + &v_Y_real, &v_Y_imag); + + ComplexFwdButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_r_pt_real, v_X_real); + _mm512_storeu_pd(v_X_r_pt_imag, 
v_X_imag); + + _mm512_storeu_pd(v_Y_r_pt_real, v_Y_real); + _mm512_storeu_pd(v_Y_r_pt_imag, v_Y_imag); + + // Increase operand & result pointers + v_X_op_pt += 2; + v_Y_op_pt += 2; + v_X_r_pt_real += 2; + v_X_r_pt_imag += 2; + v_Y_r_pt_real += 2; + v_Y_r_pt_imag += 2; + } + offset += (gap << 1); + } +} + +void Forward_FFTLike_ToBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* root_of_unity_powers_cmplx_intrlvd, const uint64_t n, + const double* scale, uint64_t recursion_depth, uint64_t recursion_half) { + HEXL_CHECK(IsPowerOfTwo(n), "n " << n << " is not a power of 2"); + HEXL_CHECK(n >= 16, + "Don't support small transforms. Need n >= 16, got n = " << n); + HEXL_VLOG(5, "root_of_unity_powers_cmplx_intrlvd " + << std::vector>( + root_of_unity_powers_cmplx_intrlvd, + root_of_unity_powers_cmplx_intrlvd + 2 * n)); + HEXL_VLOG(5, "operand_cmplx_intrlvd " << std::vector>( + operand_cmplx_intrlvd, operand_cmplx_intrlvd + 2 * n)); + + static const size_t base_fft_like_size = 1024; + + if (n <= base_fft_like_size) { // Perform breadth-first FFT like + size_t gap = n; // (2*n >> 1) Interleaved complex numbers + size_t m = 2; // require twice the size + size_t W_idx = (m << recursion_depth) + (recursion_half * m); + + // First pass in case of out of place + if (recursion_depth == 0 && gap >= 16) { + const double* W_cmplx_intrlvd = + &root_of_unity_powers_cmplx_intrlvd[W_idx]; + ComplexStartFwdT8(result_cmplx_intrlvd, operand_cmplx_intrlvd, + W_cmplx_intrlvd, gap, m); + m <<= 1; + W_idx <<= 1; + gap >>= 1; + } + + for (; gap >= 16; gap >>= 1) { + const double* W_cmplx_intrlvd = + &root_of_unity_powers_cmplx_intrlvd[W_idx]; + ComplexFwdT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, m); + m <<= 1; + W_idx <<= 1; + } + + { + // T4 + const double* W_cmplx_intrlvd = + &root_of_unity_powers_cmplx_intrlvd[W_idx]; + ComplexFwdT4(result_cmplx_intrlvd, W_cmplx_intrlvd, m); + m <<= 1; + W_idx <<= 1; + + // T2 + W_cmplx_intrlvd = &root_of_unity_powers_cmplx_intrlvd[W_idx]; + ComplexFwdT2(result_cmplx_intrlvd, W_cmplx_intrlvd, m); + m <<= 1; + W_idx <<= 1; + + // T1 + W_cmplx_intrlvd = &root_of_unity_powers_cmplx_intrlvd[W_idx]; + ComplexFwdT1(result_cmplx_intrlvd, W_cmplx_intrlvd, m, scale); + m <<= 1; + W_idx <<= 1; + } + } else { + // Perform depth-first FFT like via recursive call + size_t gap = n; + size_t W_idx = (2ULL << recursion_depth) + (recursion_half << 1); + const double* W_cmplx_intrlvd = &root_of_unity_powers_cmplx_intrlvd[W_idx]; + + if (recursion_depth == 0) { + ComplexStartFwdT8(result_cmplx_intrlvd, operand_cmplx_intrlvd, + W_cmplx_intrlvd, gap, 2); + } else { + ComplexFwdT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, 2); + } + + Forward_FFTLike_ToBitReverseAVX512( + result_cmplx_intrlvd, result_cmplx_intrlvd, + root_of_unity_powers_cmplx_intrlvd, n / 2, scale, recursion_depth + 1, + recursion_half * 2); + + Forward_FFTLike_ToBitReverseAVX512( + &result_cmplx_intrlvd[n], &result_cmplx_intrlvd[n], + root_of_unity_powers_cmplx_intrlvd, n / 2, scale, recursion_depth + 1, + recursion_half * 2 + 1); + } + if (recursion_depth == 0) { + HEXL_VLOG(5, + "AVX512 returning FWD FFT like result " + << std::vector>( + result_cmplx_intrlvd, result_cmplx_intrlvd + 2 * n)); + } +} + +void BuildFloatingPointsAVX512(double* res_cmplx_intrlvd, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, const size_t mod_size, + const size_t coeff_count) { + const __m512i v_perm = _mm512_set_epi64(7, 3, 6, 
2, 5, 1, 4, 0); + __m512d v_res_imag = _mm512_setzero_pd(); + __m512d* v_res_pt = reinterpret_cast<__m512d*>(res_cmplx_intrlvd); + double two_pow_64 = std::pow(2.0, 64); + + for (size_t i = 0; i < coeff_count; i += 8) { + __mmask8 zeros = 0xff; + __mmask8 cond_lt_thr = 0; + + for (int32_t j = static_cast(mod_size) - 1; zeros && (j >= 0); + j--) { + const uint64_t* base = plain + j; + __m512i v_thrld = _mm512_set1_epi64(*(threshold + j)); + __m512i v_plain = _mm512_set_epi64( + *(base + (i + 7) * mod_size), *(base + (i + 6) * mod_size), + *(base + (i + 5) * mod_size), *(base + (i + 4) * mod_size), + *(base + (i + 3) * mod_size), *(base + (i + 2) * mod_size), + *(base + (i + 1) * mod_size), *(base + (i + 0) * mod_size)); + + cond_lt_thr = static_cast(cond_lt_thr) | + static_cast( + _mm512_mask_cmplt_epu64_mask(zeros, v_plain, v_thrld)); + zeros = _mm512_mask_cmpeq_epu64_mask(zeros, v_plain, v_thrld); + } + + __mmask8 cond_ge_thr = static_cast(~cond_lt_thr); + double scaled_two_pow_64 = inv_scale; + __m512d v_zeros = _mm512_setzero_pd(); + __m512d v_res_real = _mm512_setzero_pd(); + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < mod_size; j++, scaled_two_pow_64 *= two_pow_64) { + const uint64_t* base = plain + j; + __m512d v_scaled_p64 = _mm512_set1_pd(scaled_two_pow_64); + __m512i v_dec_moduli = _mm512_set1_epi64(*(decryption_modulus + j)); + __m512i v_curr_coeff = _mm512_set_epi64( + *(base + (i + 7) * mod_size), *(base + (i + 6) * mod_size), + *(base + (i + 5) * mod_size), *(base + (i + 4) * mod_size), + *(base + (i + 3) * mod_size), *(base + (i + 2) * mod_size), + *(base + (i + 1) * mod_size), *(base + (i + 0) * mod_size)); + + __mmask8 cond_gt_dec_mod = + _mm512_mask_cmpgt_epu64_mask(cond_ge_thr, v_curr_coeff, v_dec_moduli); + __mmask8 cond_le_dec_mod = cond_gt_dec_mod ^ cond_ge_thr; + + __m512i v_diff = _mm512_mask_sub_epi64(v_curr_coeff, cond_gt_dec_mod, + v_curr_coeff, v_dec_moduli); + v_diff = _mm512_mask_sub_epi64(v_diff, cond_le_dec_mod, v_dec_moduli, + v_curr_coeff); + + // __m512d v_scaled_diff = _mm512_castsi512_pd(v_diff); does not work + uint64_t tmp_v_ui[8]; + __m512i* tmp_v_ui_pt = reinterpret_cast<__m512i*>(tmp_v_ui); + double tmp_v_pd[8]; + _mm512_storeu_si512(tmp_v_ui_pt, v_diff); + HEXL_LOOP_UNROLL_8 + for (size_t t = 0; t < 8; t++) { + tmp_v_pd[t] = static_cast(tmp_v_ui[t]); + } + + __m512d v_casted_diff = _mm512_loadu_pd(tmp_v_pd); + // This mask avoids multiplying by inf when diff is already zero + __mmask8 cond_no_zero = _mm512_cmpneq_pd_mask(v_casted_diff, v_zeros); + __m512d v_scaled_diff = _mm512_mask_mul_pd(v_casted_diff, cond_no_zero, + v_casted_diff, v_scaled_p64); + v_res_real = _mm512_mask_add_pd(v_res_real, cond_gt_dec_mod | cond_lt_thr, + v_res_real, v_scaled_diff); + v_res_real = _mm512_mask_sub_pd(v_res_real, cond_le_dec_mod, v_res_real, + v_scaled_diff); + } + + // Make res 1 complex interleaved + v_res_real = _mm512_permutexvar_pd(v_perm, v_res_real); + __m512d v_res1 = _mm512_shuffle_pd(v_res_real, v_res_imag, 0x00); + __m512d v_res2 = _mm512_shuffle_pd(v_res_real, v_res_imag, 0xff); + _mm512_storeu_pd(v_res_pt++, v_res1); + _mm512_storeu_pd(v_res_pt++, v_res2); + } +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/experimental/fft-like/inv-fft-like-avx512.cpp b/hexl_ser/experimental/fft-like/inv-fft-like-avx512.cpp new file mode 100644 index 00000000..feb353a2 --- /dev/null +++ b/hexl_ser/experimental/fft-like/inv-fft-like-avx512.cpp @@ -0,0 +1,411 @@ +// Copyright (C) 2022 Intel Corporation +// 
SPDX-License-Identifier: Apache-2.0 +#include "hexl/experimental/fft-like/inv-fft-like-avx512.hpp" + +#include "hexl/experimental/fft-like/fft-like-avx512-util.hpp" +#include "hexl/logging/logging.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief Final butterfly step for the Inverse FFT like. +/// @param[in,out] X_real Double precision (DP) values in SIMD form representing +/// the real part of 8 complex numbers. +/// @param[in,out] X_imag DP values in SIMD form representing the +/// imaginary part of the forementioned complex numbers. +/// @param[in,out] Y_real DP values in SIMD form representing the +/// real part of 8 complex numbers. +/// @param[in,out] Y_imag DP values in SIMD form representing the +/// imaginary part of the forementioned complex numbers. +/// @param[in] W_real DP values in SIMD form representing the real part of the +/// Inverse Complex Roots of unity. +/// @param[in] W_imag DP values in SIMD form representing the imaginary part of +/// the Inverse Complex Roots of unity. +void ComplexInvButterfly(__m512d* X_real, __m512d* X_imag, __m512d* Y_real, + __m512d* Y_imag, __m512d W_real, __m512d W_imag, + const double* scalar = nullptr) { + // U = X, + __m512d U_real = *X_real; + __m512d U_imag = *X_imag; + + // X = U + Y + *X_real = _mm512_add_pd(U_real, *Y_real); + *X_imag = _mm512_add_pd(U_imag, *Y_imag); + + if (scalar != nullptr) { + __m512d v_scalar = _mm512_set1_pd(*scalar); + *X_real = _mm512_mul_pd(*X_real, v_scalar); + *X_imag = _mm512_mul_pd(*X_imag, v_scalar); + } + + // V = U - Y + __m512d V_real = _mm512_sub_pd(U_real, *Y_real); + __m512d V_imag = _mm512_sub_pd(U_imag, *Y_imag); + + // Y = V*W. Complex multiplication: + // (v_r + iv_b)*(w_a + iw_b) = (v_a*w_a - v_b*w_b) + i(v_a*w_b + v_b*w_a) + *Y_real = _mm512_mul_pd(V_real, W_real); + __m512d tmp = _mm512_mul_pd(V_imag, W_imag); + *Y_real = _mm512_sub_pd(*Y_real, tmp); + + *Y_imag = _mm512_mul_pd(V_real, W_imag); + tmp = _mm512_mul_pd(V_imag, W_real); + *Y_imag = _mm512_add_pd(*Y_imag, tmp); +} + +void ComplexInvT1(double* result_8C_intrlvd, const double* operand_1C_intrlvd, + const double* W_1C_intrlvd, uint64_t m) { + size_t offset = 0; + + // 8 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_8 + for (size_t i = 0; i < (m >> 1); i += 8) { + // Referencing operand + const double* X_op_real = operand_1C_intrlvd + offset; + + // Referencing result + double* X_r_real = result_8C_intrlvd + offset; + double* X_r_imag = X_r_real + 8; + __m512d* v_X_r_pt_real = reinterpret_cast<__m512d*>(X_r_real); + __m512d* v_X_r_pt_imag = reinterpret_cast<__m512d*>(X_r_imag); + + __m512d v_X_real; + __m512d v_X_imag; + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadInvInterleavedT1(X_op_real, &v_X_real, &v_X_imag, &v_Y_real, + &v_Y_imag); + + // Weights + // x = (14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r); + // y = (15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r); + __m512d v_W_real = _mm512_set_pd( + W_1C_intrlvd[14], W_1C_intrlvd[10], W_1C_intrlvd[6], W_1C_intrlvd[2], + W_1C_intrlvd[12], W_1C_intrlvd[8], W_1C_intrlvd[4], W_1C_intrlvd[0]); + __m512d v_W_imag = _mm512_set_pd( + W_1C_intrlvd[15], W_1C_intrlvd[11], W_1C_intrlvd[7], W_1C_intrlvd[3], + W_1C_intrlvd[13], W_1C_intrlvd[9], W_1C_intrlvd[5], W_1C_intrlvd[1]); + W_1C_intrlvd += 16; + + ComplexInvButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_r_pt_real, v_X_real); + _mm512_storeu_pd(v_X_r_pt_imag, v_X_imag); + v_X_r_pt_real += 2; + v_X_r_pt_imag += 2; + _mm512_storeu_pd(v_X_r_pt_real, v_Y_real); + 
_mm512_storeu_pd(v_X_r_pt_imag, v_Y_imag); + + offset += 32; + } +} + +void ComplexInvT2(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t m) { + size_t offset = 0; + + // 4 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i += 4) { + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d v_X_real; + __m512d v_X_imag; + + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadInvInterleavedT2(X_real, &v_X_real, &v_Y_real); + ComplexLoadInvInterleavedT2(X_imag, &v_X_imag, &v_Y_imag); + + // Weights + // x = (13, 9, 5, 1, 12, 8, 4, 0) + // y = (15, 11, 7, 3, 14, 10, 6, 2) + __m512d v_W_real = _mm512_set_pd( + W_1C_intrlvd[6], W_1C_intrlvd[4], W_1C_intrlvd[2], W_1C_intrlvd[0], + W_1C_intrlvd[6], W_1C_intrlvd[4], W_1C_intrlvd[2], W_1C_intrlvd[0]); + __m512d v_W_imag = _mm512_set_pd( + W_1C_intrlvd[7], W_1C_intrlvd[5], W_1C_intrlvd[3], W_1C_intrlvd[1], + W_1C_intrlvd[7], W_1C_intrlvd[5], W_1C_intrlvd[3], W_1C_intrlvd[1]); + W_1C_intrlvd += 8; + + ComplexInvButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_pt_real, v_X_real); + _mm512_storeu_pd(v_X_pt_imag, v_X_imag); + v_X_pt_real += 2; + v_X_pt_imag += 2; + _mm512_storeu_pd(v_X_pt_real, v_Y_real); + _mm512_storeu_pd(v_X_pt_imag, v_Y_imag); + + offset += 32; + } +} + +void ComplexInvT4(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t m) { + size_t offset = 0; + + // 2 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i += 2) { + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d v_X_real; + __m512d v_X_imag; + __m512d v_Y_real; + __m512d v_Y_imag; + + ComplexLoadInvInterleavedT4(X_real, &v_X_real, &v_Y_real); + ComplexLoadInvInterleavedT4(X_imag, &v_X_imag, &v_Y_imag); + + // Weights + // x = (11, 9, 3, 1, 10, 8, 2, 0) + // y = (15, 13, 7, 5, 14, 12, 6, 4) + __m512d v_W_real = _mm512_set_pd( + W_1C_intrlvd[2], W_1C_intrlvd[2], W_1C_intrlvd[0], W_1C_intrlvd[0], + W_1C_intrlvd[2], W_1C_intrlvd[2], W_1C_intrlvd[0], W_1C_intrlvd[0]); + __m512d v_W_imag = _mm512_set_pd( + W_1C_intrlvd[3], W_1C_intrlvd[3], W_1C_intrlvd[1], W_1C_intrlvd[1], + W_1C_intrlvd[3], W_1C_intrlvd[3], W_1C_intrlvd[1], W_1C_intrlvd[1]); + + W_1C_intrlvd += 4; + + ComplexInvButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + ComplexWriteInvInterleavedT4(v_X_real, v_Y_real, v_X_pt_real); + ComplexWriteInvInterleavedT4(v_X_imag, v_Y_imag, v_X_pt_imag); + + offset += 32; + } +} + +void ComplexInvT8(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t gap, uint64_t m) { + size_t offset = 0; + + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i++) { + // Referencing operand + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + double* Y_real = X_real + gap; + double* Y_imag = X_imag + gap; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d* v_Y_pt_real = reinterpret_cast<__m512d*>(Y_real); + __m512d* v_Y_pt_imag = reinterpret_cast<__m512d*>(Y_imag); + + // Weights + __m512d v_W_real = _mm512_set1_pd(*W_1C_intrlvd++); + __m512d v_W_imag = 
_mm512_set1_pd(*W_1C_intrlvd++); + + // assume 8 | t + for (size_t j = 0; j < gap; j += 16) { + __m512d v_X_real = _mm512_loadu_pd(v_X_pt_real); + __m512d v_X_imag = _mm512_loadu_pd(v_X_pt_imag); + + __m512d v_Y_real = _mm512_loadu_pd(v_Y_pt_real); + __m512d v_Y_imag = _mm512_loadu_pd(v_Y_pt_imag); + + ComplexInvButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag); + + _mm512_storeu_pd(v_X_pt_real, v_X_real); + _mm512_storeu_pd(v_X_pt_imag, v_X_imag); + + _mm512_storeu_pd(v_Y_pt_real, v_Y_real); + _mm512_storeu_pd(v_Y_pt_imag, v_Y_imag); + + // Increase operand & result pointers + v_X_pt_real += 2; + v_X_pt_imag += 2; + v_Y_pt_real += 2; + v_Y_pt_imag += 2; + } + offset += (gap << 1); + } +} + +// Takes operand as 8 complex interleaved: This is 8 real parts followed by +// its 8 imaginary parts. +// Returns operand as 1 complex interleaved: One real part followed by its +// imaginary part. +void ComplexFinalInvT8(double* operand_8C_intrlvd, const double* W_1C_intrlvd, + uint64_t gap, uint64_t m, + const double* scalar = nullptr) { + size_t offset = 0; + + __m512d v_scalar; + if (scalar != nullptr) { + v_scalar = _mm512_set1_pd(*scalar); + } + + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < (m >> 1); i++, offset += (gap << 1)) { + // Referencing operand + double* X_real = operand_8C_intrlvd + offset; + double* X_imag = X_real + 8; + + double* Y_real = X_real + gap; + double* Y_imag = X_imag + gap; + + __m512d* v_X_pt_real = reinterpret_cast<__m512d*>(X_real); + __m512d* v_X_pt_imag = reinterpret_cast<__m512d*>(X_imag); + + __m512d* v_Y_pt_real = reinterpret_cast<__m512d*>(Y_real); + __m512d* v_Y_pt_imag = reinterpret_cast<__m512d*>(Y_imag); + + // Weights + __m512d v_W_real = _mm512_set1_pd(*W_1C_intrlvd++); + __m512d v_W_imag = _mm512_set1_pd(*W_1C_intrlvd++); + + if (scalar != nullptr) { + v_W_real = _mm512_mul_pd(v_W_real, v_scalar); + v_W_imag = _mm512_mul_pd(v_W_imag, v_scalar); + } + + // assume 8 | t + for (size_t j = 0; j < gap; j += 16) { + __m512d v_X_real = _mm512_loadu_pd(v_X_pt_real); + __m512d v_X_imag = _mm512_loadu_pd(v_X_pt_imag); + __m512d v_Y_real = _mm512_loadu_pd(v_Y_pt_real); + __m512d v_Y_imag = _mm512_loadu_pd(v_Y_pt_imag); + + ComplexInvButterfly(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, v_W_real, + v_W_imag, scalar); + + ComplexWriteInvInterleavedT8(&v_X_real, &v_X_imag, &v_Y_real, &v_Y_imag, + v_X_pt_real, v_Y_pt_real); + + // Increase operand & result pointers + v_X_pt_real += 2; + v_X_pt_imag += 2; + v_Y_pt_real += 2; + v_Y_pt_imag += 2; + } + } +} + +void Inverse_FFTLike_FromBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* inv_root_of_unity_cmplx_intrlvd, const uint64_t n, + const double* scale, uint64_t recursion_depth, uint64_t recursion_half) { + HEXL_CHECK(IsPowerOfTwo(n), "n " << n << " is not a power of 2"); + HEXL_CHECK(n >= 16, + "Don't support small transforms. 
Need n >= 16, got n = " << n); + HEXL_VLOG(5, "inv_root_of_unity_cmplx_intrlvd " + << std::vector>( + inv_root_of_unity_cmplx_intrlvd, + inv_root_of_unity_cmplx_intrlvd + 2 * n)); + HEXL_VLOG(5, "operand_cmplx_intrlvd " << std::vector>( + operand_cmplx_intrlvd, operand_cmplx_intrlvd + 2 * n)); + size_t gap = 2; // Interleaved complex values requires twice the size + size_t m = n; // (2*n >> 1); + size_t W_idx = 2 + m * recursion_half; // 2*1 + + static const size_t base_fft_like_size = 1024; + + if (n <= base_fft_like_size) { // Perform breadth-first InvFFT like + // T1 + const double* W_cmplx_intrlvd = &inv_root_of_unity_cmplx_intrlvd[W_idx]; + ComplexInvT1(result_cmplx_intrlvd, operand_cmplx_intrlvd, W_cmplx_intrlvd, + m); + gap <<= 1; + m >>= 1; + uint64_t W_idx_delta = + m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_idx_delta; + + // T2 + W_cmplx_intrlvd = &inv_root_of_unity_cmplx_intrlvd[W_idx]; + ComplexInvT2(result_cmplx_intrlvd, W_cmplx_intrlvd, m); + gap <<= 1; + m >>= 1; + W_idx_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_idx_delta; + + // T4 + W_cmplx_intrlvd = &inv_root_of_unity_cmplx_intrlvd[W_idx]; + ComplexInvT4(result_cmplx_intrlvd, W_cmplx_intrlvd, m); + gap <<= 1; + m >>= 1; + W_idx_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_idx_delta; + + while (m > 2) { + W_cmplx_intrlvd = &inv_root_of_unity_cmplx_intrlvd[W_idx]; + ComplexInvT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, m); + gap <<= 1; + m >>= 1; + W_idx_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_idx_delta; + } + + W_cmplx_intrlvd = &inv_root_of_unity_cmplx_intrlvd[W_idx]; + if (recursion_depth == 0) { + ComplexFinalInvT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, m, scale); + HEXL_VLOG(5, + "AVX512 returning INV FFT like result " + << std::vector>( + result_cmplx_intrlvd, result_cmplx_intrlvd + 2 * n)); + } else { + ComplexInvT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, m); + } + gap <<= 1; + m >>= 1; + W_idx_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_idx_delta; + } else { + Inverse_FFTLike_FromBitReverseAVX512( + result_cmplx_intrlvd, operand_cmplx_intrlvd, + inv_root_of_unity_cmplx_intrlvd, n / 2, scale, recursion_depth + 1, + 2 * recursion_half); + Inverse_FFTLike_FromBitReverseAVX512( + &result_cmplx_intrlvd[n], &operand_cmplx_intrlvd[n], + inv_root_of_unity_cmplx_intrlvd, n / 2, scale, recursion_depth + 1, + 2 * recursion_half + 1); + uint64_t W_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); + for (; m > 2; m >>= 1) { + gap <<= 1; + W_delta >>= 1; + W_idx += W_delta; + } + const double* W_cmplx_intrlvd = &inv_root_of_unity_cmplx_intrlvd[W_idx]; + if (recursion_depth == 0) { + ComplexFinalInvT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, m, scale); + HEXL_VLOG(5, + "AVX512 returning INV FFT like result " + << std::vector>( + result_cmplx_intrlvd, result_cmplx_intrlvd + 2 * n)); + } else { + ComplexInvT8(result_cmplx_intrlvd, W_cmplx_intrlvd, gap, m); + } + gap <<= 1; + m >>= 1; + W_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_delta; + } +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/experimental/misc/lr-mat-vec-mult.cpp b/hexl_ser/experimental/misc/lr-mat-vec-mult.cpp new file mode 100644 index 00000000..729ae160 --- /dev/null +++ b/hexl_ser/experimental/misc/lr-mat-vec-mult.cpp @@ -0,0 +1,149 @@ +// Copyright (C) 2020 Intel Corporation +// 
SPDX-License-Identifier: Apache-2.0 + +#include "hexl/experimental/misc/lr-mat-vec-mult.hpp" + +#include + +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +// operand1: num_weights x 2 x n x num_moduli +// operand2: num_weights x 2 x n x num_moduli +// +// results: num_weights x 3 x n x num_moduli +// [num_weights x {x[0].*y[0], x[0].*y[1]+x[1].*y[0], x[1].*y[1]} x num_moduli]. +// TODO(@fdiasmor): Ideally, the size of results can be optimized to [3 x n x +// num_moduli]. +void LinRegMatrixVectorMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli, + uint64_t num_weights) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(moduli != nullptr, "Require moduli != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + HEXL_CHECK(num_weights != 0, "Require n != 0"); + + // pointer increment to switch to a next polynomial + size_t poly_size = n * num_moduli; + + // ciphertext increment to switch to the next ciphertext + size_t cipher_size = 2 * poly_size; + + // ciphertext output increment to switch to the next output + size_t output_size = 3 * poly_size; + + AlignedVector64 temp(n, 0); + + for (size_t r = 0; r < num_weights; r++) { + size_t next_output = r * output_size; + size_t next_poly_pair = r * cipher_size; + uint64_t* cipher2 = result + next_output; + const uint64_t* cipher0 = operand1 + next_poly_pair; + const uint64_t* cipher1 = operand2 + next_poly_pair; + + for (size_t i = 0; i < num_moduli; i++) { + size_t i_times_n = i * n; + size_t poly0_offset = i_times_n; + size_t poly1_offset = poly0_offset + poly_size; + size_t poly2_offset = poly0_offset + 2 * poly_size; + + // Output ciphertext has 3 polynomials, where x, y are the input + // ciphertexts: (x[0] * y[0], x[0] * y[1] + x[1] * y[0], x[1] * y[1]) + + // Compute third output polynomial + // Output written directly to result rather than temporary buffer + // result[2] = x[1] * y[1] + intel::hexl::EltwiseMultMod(cipher2 + poly2_offset, + cipher0 + poly1_offset, + cipher1 + poly1_offset, n, moduli[i], 1); + + // Compute second output polynomial + // result[1] = x[1] * y[0] + intel::hexl::EltwiseMultMod(cipher2 + poly1_offset, + cipher0 + poly1_offset, + cipher1 + poly0_offset, n, moduli[i], 1); + + // result[1] = x[0] * y[1] + intel::hexl::EltwiseMultMod(temp.data(), cipher0 + poly0_offset, + cipher1 + poly1_offset, n, moduli[i], 1); + // result[1] += temp_poly + intel::hexl::EltwiseAddMod(cipher2 + poly1_offset, cipher2 + poly1_offset, + temp.data(), n, moduli[i]); + + // Compute first output polynomial + // result[0] = x[0] * y[0] + intel::hexl::EltwiseMultMod(cipher2 + poly0_offset, + cipher0 + poly0_offset, + cipher1 + poly0_offset, n, moduli[i], 1); + } + } + + const bool USE_ADDER_TREE = true; + if (USE_ADDER_TREE) { + // Accumulate with the adder-tree algorithm in O(logn) + for (size_t dist = 1; dist < num_weights; dist += dist) { + size_t step = dist * 2; + size_t neighbor_cipher_incr = dist * output_size; + // This loop can leverage parallelism using #pragma unroll + for (size_t s = 0; s < num_weights; s += step) { + size_t next_cipher_pair_incr = s 
* output_size; + uint64_t* left_cipher = result + next_cipher_pair_incr; + uint64_t* right_cipher = left_cipher + neighbor_cipher_incr; + + // This loop can leverage parallelism using #pragma unroll + for (size_t i = 0; i < num_moduli; i++) { + size_t i_times_n = i * n; + size_t poly0_offset = i_times_n; + size_t poly1_offset = poly0_offset + poly_size; + size_t poly2_offset = poly0_offset + 2 * poly_size; + + // All EltwiseAddMod below can run in parallel + intel::hexl::EltwiseAddMod(left_cipher + poly0_offset, + right_cipher + poly0_offset, + left_cipher + poly0_offset, n, moduli[i]); + intel::hexl::EltwiseAddMod(left_cipher + poly1_offset, + right_cipher + poly1_offset, + left_cipher + poly1_offset, n, moduli[i]); + intel::hexl::EltwiseAddMod(left_cipher + poly2_offset, + right_cipher + poly2_offset, + left_cipher + poly2_offset, n, moduli[i]); + } + } + } + } else { + // Accumulate all rows in sequence + uint64_t* acc = result; + for (size_t r = 1; r < num_weights; r++) { + size_t next_cipher = r * output_size; + acc += next_cipher; + for (size_t i = 0; i < num_moduli; i++) { + size_t i_times_n = i * n; + size_t poly0_offset = i_times_n; + size_t poly1_offset = poly0_offset + poly_size; + size_t poly2_offset = poly0_offset + 2 * poly_size; + + // All EltwiseAddMod below can run in parallel + + intel::hexl::EltwiseAddMod(result + poly0_offset, result + poly0_offset, + acc + poly0_offset, n, moduli[i]); + intel::hexl::EltwiseAddMod(result + poly1_offset, result + poly1_offset, + acc + poly1_offset, n, moduli[i]); + intel::hexl::EltwiseAddMod(result + poly2_offset, result + poly2_offset, + acc + poly2_offset, n, moduli[i]); + } + } + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/experimental/seal/dyadic-multiply-internal.cpp b/hexl_ser/experimental/seal/dyadic-multiply-internal.cpp new file mode 100644 index 00000000..e321ec4a --- /dev/null +++ b/hexl_ser/experimental/seal/dyadic-multiply-internal.cpp @@ -0,0 +1,78 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/experimental/seal/dyadic-multiply-internal.hpp" + +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { +namespace internal { + +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli) { + HEXL_CHECK(result != nullptr, "Require result != nullptr"); + HEXL_CHECK(operand1 != nullptr, "Require operand1 != nullptr"); + HEXL_CHECK(operand2 != nullptr, "Require operand2 != nullptr"); + HEXL_CHECK(moduli != nullptr, "Require moduli != nullptr"); + HEXL_CHECK(n != 0, "Require n != 0"); + + // pointer increment to switch to a next polynomial + size_t poly_size = n * num_moduli; + + // Output ciphertext has 3 polynomials, where x, y are the input + // ciphertexts: (x[0] * y[0], x[0] * y[1] + x[1] * y[0], x[1] * y[1]) + + // TODO(fboemer): Determine based on cpu cache size + size_t tile_size = std::min(n, uint64_t(512)); + size_t num_tiles = n / tile_size; + + AlignedVector64 temp(tile_size, 0); + + // Modulus by modulus + for (size_t i = 0; i < num_moduli; i++) { + // Split by tiles for better caching + size_t i_times_n = i * n; + for (size_t tile = 0; tile < num_tiles; ++tile) { + size_t poly0_offset = i_times_n + tile_size * tile; + 
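+      // Adding poly_size (n * num_moduli) moves to the same (modulus, tile)
+      // position in the next polynomial of the ciphertext.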
size_t poly1_offset = poly0_offset + poly_size; + size_t poly2_offset = poly0_offset + 2 * poly_size; + + // Compute third output polynomial + // Output written directly to result rather than temporary buffer + // result[2] = x[1] * y[1] + intel::hexl::EltwiseMultMod( + &result[poly2_offset], operand1 + poly1_offset, + operand2 + poly1_offset, tile_size, moduli[i], 1); + + // Compute second output polynomial + // result[1] = x[1] * y[0] + intel::hexl::EltwiseMultMod(temp.data(), operand1 + poly1_offset, + operand2 + poly0_offset, tile_size, moduli[i], + 1); + // result[1] = x[0] * y[1] + intel::hexl::EltwiseMultMod( + &result[poly1_offset], operand1 + poly0_offset, + operand2 + poly1_offset, tile_size, moduli[i], 1); + // result[1] += temp_poly + intel::hexl::EltwiseAddMod(&result[poly1_offset], temp.data(), + &result[poly1_offset], tile_size, moduli[i]); + + // Compute first output polynomial + // result[0] = x[0] * y[0] + intel::hexl::EltwiseMultMod( + &result[poly0_offset], operand1 + poly0_offset, + operand2 + poly0_offset, tile_size, moduli[i], 1); + } + } +} + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/experimental/seal/dyadic-multiply.cpp b/hexl_ser/experimental/seal/dyadic-multiply.cpp new file mode 100644 index 00000000..e3306530 --- /dev/null +++ b/hexl_ser/experimental/seal/dyadic-multiply.cpp @@ -0,0 +1,22 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#ifndef HEXL_FPGA_COMPATIBLE_DYADIC_MULTIPLY + +#include "hexl/experimental/seal/dyadic-multiply.hpp" + +#include "hexl/experimental/seal/dyadic-multiply-internal.hpp" + +namespace intel { +namespace hexl { + +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli) { + intel::hexl::internal::DyadicMultiply(result, operand1, operand2, n, moduli, + num_moduli); +} + +} // namespace hexl +} // namespace intel +#endif diff --git a/hexl_ser/experimental/seal/key-switch-internal.cpp b/hexl_ser/experimental/seal/key-switch-internal.cpp new file mode 100644 index 00000000..15edb9a8 --- /dev/null +++ b/hexl_ser/experimental/seal/key-switch-internal.cpp @@ -0,0 +1,205 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/experimental/seal/key-switch-internal.hpp" + +#include +#include +#include + +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/eltwise/eltwise-fma-mod.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/experimental/seal/ntt-cache.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { +namespace internal { + +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr) { + if (root_of_unity_powers_ptr != nullptr) { + throw std::invalid_argument( + "Parameter root_of_unity_powers_ptr is not supported yet."); + } + + uint64_t coeff_count = n; + + // Create a copy of target_iter + std::vector t_target( + t_target_iter_ptr, + t_target_iter_ptr + (coeff_count * 
decomp_modulus_size)); + uint64_t* t_target_ptr = t_target.data(); + + // Simplified implementation, where we assume no modular reduction is required + // for intermediate additions + std::vector t_ntt(coeff_count, 0); + uint64_t* t_ntt_ptr = t_ntt.data(); + + // In CKKS t_target is in NTT form; switch + // back to normal form + for (size_t j = 0; j < decomp_modulus_size; ++j) { + GetNTT(n, moduli[j]) + .ComputeInverse(&t_target_ptr[j * coeff_count], + &t_target_ptr[j * coeff_count], 2, 1); + } + + std::vector t_poly_prod( + key_component_count * coeff_count * rns_modulus_size, 0); + + for (size_t i = 0; i < rns_modulus_size; ++i) { + size_t key_index = (i == decomp_modulus_size ? key_modulus_size - 1 : i); + + // Allocate memory for a lazy accumulator (128-bit coefficients) + std::vector t_poly_lazy(key_component_count * coeff_count * 2, 0); + uint64_t* t_poly_lazy_ptr = &t_poly_lazy[0]; + uint64_t* accumulator_ptr = &t_poly_lazy[0]; + + for (size_t j = 0; j < decomp_modulus_size; ++j) { + const uint64_t* t_operand; + // assume scheme == scheme_type::ckks + if (i == j) { + t_operand = &t_target_iter_ptr[j * coeff_count]; + } else { + // Perform RNS-NTT conversion + // No need to perform RNS conversion (modular reduction) + if (moduli[j] <= moduli[key_index]) { + for (size_t l = 0; l < coeff_count; ++l) { + t_ntt_ptr[l] = t_target_ptr[j * coeff_count + l]; + } + } else { + // Perform RNS conversion (modular reduction) + intel::hexl::EltwiseReduceMod( + t_ntt_ptr, &t_target_ptr[j * coeff_count], coeff_count, + moduli[key_index], moduli[key_index], 1); + } + + // NTT conversion lazy outputs in [0, 4q) + GetNTT(n, moduli[key_index]).ComputeForward(t_ntt_ptr, t_ntt_ptr, 4, 4); + t_operand = t_ntt_ptr; + } + + // Multiply with keys and modular accumulate products in a lazy fashion + for (size_t k = 0; k < key_component_count; ++k) { + // No reduction used; assume intermediate results don't overflow + for (size_t l = 0; l < coeff_count; ++l) { + uint64_t t_poly_idx = 2 * (k * coeff_count + l); + + uint64_t mult_op2_idx = + coeff_count * key_index + k * key_modulus_size * coeff_count + l; + + uint128_t prod = + MultiplyUInt64(t_operand[l], k_switch_keys[j][mult_op2_idx]); + + // TODO(fboemer): add uint128 + uint128_t low = t_poly_lazy_ptr[t_poly_idx]; + uint128_t hi = t_poly_lazy_ptr[t_poly_idx + 1]; + uint128_t x = (hi << 64) + low; + uint128_t sum = prod + x; + uint64_t sum_hi = static_cast(sum >> 64); + uint64_t sum_lo = static_cast(sum); + t_poly_lazy_ptr[t_poly_idx] = sum_lo; + t_poly_lazy_ptr[t_poly_idx + 1] = sum_hi; + } + } + } + + // PolyIter pointing to the destination t_poly_prod, shifted to the + // appropriate modulus + uint64_t* t_poly_prod_iter_ptr = &t_poly_prod[i * coeff_count]; + + // Final modular reduction + for (size_t k = 0; k < key_component_count; ++k) { + for (size_t l = 0; l < coeff_count; ++l) { + uint64_t accumulator_idx = 2 * coeff_count * k + 2 * l; + uint64_t poly_iter_idx = coeff_count * rns_modulus_size * k + l; + + t_poly_prod_iter_ptr[poly_iter_idx] = BarrettReduce128( + accumulator_ptr[accumulator_idx + 1], + accumulator_ptr[accumulator_idx], moduli[key_index]); + } + } + } + + uint64_t* data_array = result; + for (size_t key_component = 0; key_component < key_component_count; + ++key_component) { + uint64_t* t_poly_prod_it = + &t_poly_prod[key_component * coeff_count * rns_modulus_size]; + uint64_t* t_last = &t_poly_prod_it[decomp_modulus_size * coeff_count]; + + GetNTT(n, moduli[key_modulus_size - 1]) + .ComputeInverse(t_last, t_last, 2, 2); + + uint64_t qk = 
moduli[key_modulus_size - 1]; + uint64_t qk_half = qk >> 1; + + for (size_t i = 0; i < coeff_count; ++i) { + uint64_t barrett_factor = + MultiplyFactor(1, 64, moduli[key_modulus_size - 1]).BarrettFactor(); + t_last[i] = BarrettReduce64(t_last[i] + qk_half, + moduli[key_modulus_size - 1], barrett_factor); + } + + for (size_t i = 0; i < decomp_modulus_size; ++i) { + // (ct mod 4qk) mod qi + uint64_t qi = moduli[i]; + + // TODO(fboemer): Use input_mod_factor != 0 when qk / qi < 4 + // TODO(fboemer): Use output_mod_factor == 4? + uint64_t input_mod_factor = (qk > qi) ? moduli[i] : 2; + if (qk > qi) { + intel::hexl::EltwiseReduceMod(t_ntt_ptr, t_last, coeff_count, moduli[i], + input_mod_factor, 1); + } else { + for (size_t coeff_idx = 0; coeff_idx < coeff_count; ++coeff_idx) { + t_ntt_ptr[coeff_idx] = t_last[coeff_idx]; + } + } + + // Lazy subtraction, results in [0, 2*qi), since fix is in [0, qi]. + uint64_t barrett_factor = + MultiplyFactor(1, 64, moduli[i]).BarrettFactor(); + uint64_t fix = qi - BarrettReduce64(qk_half, moduli[i], barrett_factor); + for (size_t l = 0; l < coeff_count; ++l) { + t_ntt_ptr[l] += fix; + } + + uint64_t qi_lazy = qi << 1; // some multiples of qi + GetNTT(n, moduli[i]).ComputeForward(t_ntt_ptr, t_ntt_ptr, 4, 4); + // Since SEAL uses at most 60bit moduli, 8*qi < 2^63. + qi_lazy = qi << 2; + + // ((ct mod qi) - (ct mod qk)) mod qi + uint64_t* t_ith_poly = &t_poly_prod_it[i * coeff_count]; + for (size_t k = 0; k < coeff_count; ++k) { + t_ith_poly[k] = t_ith_poly[k] + qi_lazy - t_ntt[k]; + } + + // qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi + intel::hexl::EltwiseFMAMod(t_ith_poly, t_ith_poly, modswitch_factors[i], + nullptr, coeff_count, moduli[i], 8); + uint64_t data_ptr_offset = + coeff_count * (decomp_modulus_size * key_component + i); + + uint64_t* data_ptr = &data_array[data_ptr_offset]; + intel::hexl::EltwiseAddMod(data_ptr, data_ptr, t_ith_poly, coeff_count, + moduli[i]); + } + } + return; +} + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/experimental/seal/key-switch.cpp b/hexl_ser/experimental/seal/key-switch.cpp new file mode 100644 index 00000000..f006a47a --- /dev/null +++ b/hexl_ser/experimental/seal/key-switch.cpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#ifndef HEXL_FPGA_COMPATIBLE_KEYSWITCH + +#include "hexl/experimental/seal/key-switch.hpp" + +#include "hexl/experimental/seal/key-switch-internal.hpp" + +namespace intel { +namespace hexl { + +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr) { + intel::hexl::internal::KeySwitch( + result, t_target_iter_ptr, n, decomp_modulus_size, key_modulus_size, + rns_modulus_size, key_component_count, moduli, k_switch_keys, + modswitch_factors, root_of_unity_powers_ptr); +} + +} // namespace hexl +} // namespace intel +#endif diff --git a/hexl_ser/include/hexl/eltwise/eltwise-add-mod.hpp b/hexl_ser/include/hexl/eltwise/eltwise-add-mod.hpp new file mode 100644 index 00000000..cb2df110 --- /dev/null +++ b/hexl_ser/include/hexl/eltwise/eltwise-add-mod.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief 
Adds two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] operand2 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus); + +/// @brief Adds a vector and scalar elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] operand2 Scalar to add. Must be less +/// than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/eltwise/eltwise-cmp-add.hpp b/hexl_ser/include/hexl/eltwise/eltwise-cmp-add.hpp new file mode 100644 index 00000000..27e514ff --- /dev/null +++ b/hexl_ser/include/hexl/eltwise/eltwise-cmp-add.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional addition. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare; stores result +/// @param[in] n Number of elements in \p operand1 +/// @param[in] cmp Comparison operation +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to conditionally add +/// @details Computes result[i] = cmp(operand1[i], bound) ? operand1[i] + +/// diff : operand1[i] for all \f$i=0, ..., n-1\f$. +void EltwiseCmpAdd(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp b/hexl_ser/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp new file mode 100644 index 00000000..07ba3d23 --- /dev/null +++ b/hexl_ser/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional modular subtraction. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] modulus Modulus to reduce by +/// @param[in] cmp Comparison function +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to subtract by +/// @details Computes \p operand1[i] = (\p cmp(\p operand1, \p bound)) ? 
(\p +/// operand1 - \p diff) mod \p modulus : \p operand1 mod \p modulus for all i=0, +/// ..., n-1 +void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, uint64_t bound, + uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/eltwise/eltwise-fma-mod.hpp b/hexl_ser/include/hexl/eltwise/eltwise-fma-mod.hpp new file mode 100644 index 00000000..03651a42 --- /dev/null +++ b/hexl_ser/include/hexl/eltwise/eltwise-fma-mod.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes fused multiply-add (\p arg1 * \p arg2 + \p arg3) mod \p +/// modulus element-wise, broadcasting scalars to vectors. +/// @param[out] result Stores the result +/// @param[in] arg1 Vector to multiply +/// @param[in] arg2 Scalar to multiply +/// @param[in] arg3 Vector to add. Will not add if \p arg3 == nullptr +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$ [2, 2^{61} - 1]\f$ +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * modulus). Must be 1, 2, 4, or 8. +void EltwiseFMAMod(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/eltwise/eltwise-mult-mod.hpp b/hexl_ser/include/hexl/eltwise/eltwise-mult-mod.hpp new file mode 100644 index 00000000..e4d2dbd7 --- /dev/null +++ b/hexl_ser/include/hexl/eltwise/eltwise-mult-mod.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. 
+/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +void EltwiseMultMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/eltwise/eltwise-reduce-mod.hpp b/hexl_ser/include/hexl/eltwise/eltwise-reduce-mod.hpp new file mode 100644 index 00000000..c23abde2 --- /dev/null +++ b/hexl_ser/include/hexl/eltwise/eltwise-reduce-mod.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Performs elementwise modular reduction +/// @param[out] result Stores the result +/// @param[in] operand Data on which to compute the elementwise modular +/// reduction +/// @param[in] n Number of elements in operand +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be modulus, 1, 2 or 4. input_mod_factor=modulus +/// means, input range is [0, p * p]. Barrett reduction will be used in this +/// case. input_mod_factor > output_mod_factor +/// @param[in] output_mod_factor output elements will be in [0, +/// output_mod_factor * modulus) Must be 1 or 2. For input_mod_factor=0, +/// output_mod_factor will be set to 1. +void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, + uint64_t modulus, uint64_t input_mod_factor, + uint64_t output_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/eltwise/eltwise-sub-mod.hpp b/hexl_ser/include/hexl/eltwise/eltwise-sub-mod.hpp new file mode 100644 index 00000000..bd286e47 --- /dev/null +++ b/hexl_ser/include/hexl/eltwise/eltwise-sub-mod.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Subtracts two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Vector of elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus); + +/// @brief Subtracts a scalar from a vector elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. 
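As a point of reference, here is a minimal usage sketch of the element-wise modular API declared in these headers; the vector length, modulus, and input values are arbitrary examples chosen for illustration, and `EltwiseUsageSketch` is not a library function:

```cpp
#include <cstdint>
#include <vector>

#include "hexl/hexl.hpp"

// Exercises the element-wise modular routines on small vectors.
void EltwiseUsageSketch() {
  const uint64_t n = 8;
  const uint64_t modulus = 769;  // any modulus in the documented range works
  std::vector<uint64_t> a{1, 2, 3, 4, 5, 6, 7, 768};      // elements < modulus
  std::vector<uint64_t> b{10, 20, 30, 40, 50, 60, 70, 1};
  std::vector<uint64_t> out(n, 0);

  // out[i] = (a[i] + b[i]) mod 769, e.g. out[7] = (768 + 1) mod 769 = 0
  intel::hexl::EltwiseAddMod(out.data(), a.data(), b.data(), n, modulus);

  // out[i] = (a[i] - b[i]) mod 769
  intel::hexl::EltwiseSubMod(out.data(), a.data(), b.data(), n, modulus);

  // out[i] = (a[i] * b[i]) mod 769; inputs are in [0, modulus),
  // so input_mod_factor = 1
  intel::hexl::EltwiseMultMod(out.data(), a.data(), b.data(), n, modulus, 1);

  // out[i] = (a[i] * 3 + b[i]) mod 769, broadcasting the scalar 3
  intel::hexl::EltwiseFMAMod(out.data(), a.data(), 3, b.data(), n, modulus, 1);
}
```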
+void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp b/hexl_ser/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp new file mode 100644 index 00000000..28a2dddf --- /dev/null +++ b/hexl_ser/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp @@ -0,0 +1,402 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +// ************************************ T1 ************************************ + +// ComplexLoadFwdInterleavedT1: +// Assumes ComplexLoadFwdInterleavedT2 was used before. +// Given input: 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0 +// Returns +// *out1 = (14, 12, 10, 8, 6, 4, 2, 0); +// *out2 = (15, 13, 11, 9, 7, 5, 3, 1); +// +// Given output: 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0 +inline void ComplexLoadFwdInterleavedT1(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512i vperm_idx = _mm512_set_epi64(6, 7, 4, 5, 2, 3, 0, 1); + + const __m512d* arg_512 = reinterpret_cast(arg); + + // 13 12 9 8 5 4 1 0 + __m512d v_7to0 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15 14 11 10 7 6 3 2 + __m512d v_15to8 = _mm512_loadu_pd(arg_512); + + // 12, 13, 8, 9, 4, 5, 0, 1 + __m512d perm_1 = _mm512_permutexvar_pd(vperm_idx, v_7to0); + // 14, 15, 10, 11, 6, 7, 2, 3 + __m512d perm_2 = _mm512_permutexvar_pd(vperm_idx, v_15to8); + + // 14, 12, 10, 8, 6, 4, 2, 0 + *out1 = _mm512_mask_blend_pd(0xaa, v_7to0, perm_2); + // 15, 13, 11, 9, 7, 5, 3, 1 + *out2 = _mm512_mask_blend_pd(0x55, v_15to8, perm_1); +} + +// ComplexWriteFwdInterleavedT1: +// Assumes ComplexLoadFwdInterleavedT1 was used before. 
+// Given inputs: +// 15i, 13i, 11i, 9i, 7i, 5i, 3i, 1i, 15r, 13r, 11r, 9r, 7r, 5r, 3r, 1r, +// 14i, 12i, 10i, 8i, 6i, 4i, 2i, 0i, 14r, 12r, 10r, 8r, 6r, 4r, 2r, 0r +// As seen with internal indexes: +// @param arg_yr = (15r, 14r, 13r, 12r, 11r, 10r, 9r, 8r); +// @param arg_xr = ( 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r); +// @param arg_yi = (15i, 14i, 13i, 12i, 11i, 10i, 9i, 8i); +// @param arg_xi = ( 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i); +// Writes out = +// {15i, 15r, 7i, 7r, 14i, 14r, 6i, 6r, 13i, 13r, 5i, 5r, 12i, 12r, 4i, 4r, +// 11i, 11r, 3i, 3r, 10i, 10r, 2i, 2r, 9i, 9r, 1i, 1r, 8i, 8r, 0i, 0r} +// +// Given output: +// 15i, 15r, 14i, 14r, 13i, 13r, 12i, 12r, 11i, 11r, 10i, 10r, 9i, 9r, 8i, 8r, +// 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +inline void ComplexWriteFwdInterleavedT1(__m512d arg_xr, __m512d arg_yr, + __m512d arg_xi, __m512d arg_yi, + __m512d* out) { + const __m512i vperm_4hi_4lo_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i v_X_out_idx = _mm512_set_epi64(3, 1, 7, 5, 2, 0, 6, 4); + const __m512i v_Y_out_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + // Real part + // in: 14r, 12r, 10r, 8r, 6r, 4r, 2r, 0r + // -> 6r, 4r, 2r, 0r, 14r, 12r, 10r, 8r + arg_xr = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, arg_xr); + + // arg_yr: 15r, 13r, 11r, 9r, 7r, 5r, 3r, 1r + // -> 6r, 4r, 2r, 0r, 7r, 5r, 3r, 1r + __m512d perm_1 = _mm512_mask_blend_pd(0x0f, arg_xr, arg_yr); + // -> 15r, 13r, 11r, 9r, 14r, 12r, 10r, 8r + __m512d perm_2 = _mm512_mask_blend_pd(0xf0, arg_xr, arg_yr); + + // 7r, 3r, 6r, 2r, 5r, 1r, 4r, 0r + arg_xr = _mm512_permutexvar_pd(v_X_out_idx, perm_1); + // 15r, 11r, 14r, 10r, 13r, 9r, 12r, 8r + arg_yr = _mm512_permutexvar_pd(v_Y_out_idx, perm_2); + + // Imaginary part + // in: 14i, 12i, 10i, 8i, 6i, 4i, 2i, 0i + // -> 6i, 4i, 2i, 0i, 14i, 12i, 10i, 8i + arg_xi = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, arg_xi); + + // arg_yr: 15i, 13i, 11i, 9i, 7i, 5i, 3i, 1i + // -> 6i, 4i, 2i, 0i, 7i, 5i, 3i, 1i + perm_1 = _mm512_mask_blend_pd(0x0f, arg_xi, arg_yi); + // -> 15i, 13i, 11i, 9i, 14i, 12i, 10i, 8i + perm_2 = _mm512_mask_blend_pd(0xf0, arg_xi, arg_yi); + + // 7i, 3i, 6i, 2i, 5i, 1i, 4i, 0i + arg_xi = _mm512_permutexvar_pd(v_X_out_idx, perm_1); + // 15i, 11i, 14i, 10i, 13i, 9i, 12i, 8i + arg_yi = _mm512_permutexvar_pd(v_Y_out_idx, perm_2); + + // Merge + // 00000000 > 3i 3r 2i 2r 1i 1r 0i 0r + __m512d out1 = _mm512_shuffle_pd(arg_xr, arg_xi, 0x00); + // 11111111 > 7i 7r 6i 6r 5i 5r 4i 4r + __m512d out2 = _mm512_shuffle_pd(arg_xr, arg_xi, 0xff); + + // 00000000 > 11i 11r 10i 10r 9i 9r 8i 8r + __m512d out3 = _mm512_shuffle_pd(arg_yr, arg_yi, 0x00); + // 11111111 > 15i 15r 14i 14r 13i 13r 12i 12r + __m512d out4 = _mm512_shuffle_pd(arg_yr, arg_yi, 0xff); + + _mm512_storeu_pd(out++, out1); + _mm512_storeu_pd(out++, out2); + _mm512_storeu_pd(out++, out3); + _mm512_storeu_pd(out++, out4); +} + +// ComplexLoadInvInterleavedT1: +// Given input: 15i 15r 14i 14r 13i 13r 12i 12r 11i 11r 10i 10r 9i 9r 8i 8r +// 7i 7r 6i 6r 5i 5r 4i 4r 3i 3r 2i 2r 1i 1r 0i 0r +// Returns +// *out1_r = (14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r); +// *out1_i = (14i, 10i, 6i, 2i, 12i, 8i, 4i, 0i); +// *out2_r = (15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r); +// *out2_i = (15i, 11i, 7i, 3i, 13i, 9i, 5i, 1i); +// +// Given output: +// 15i, 11i, 7i, 3i, 13i, 9i, 5i, 1i, 15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r, +// 14i, 10i, 6i, 2i, 12i, 8i, 4i, 0i, 14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r +inline void ComplexLoadInvInterleavedT1(const double* arg, __m512d* out1_r, + __m512d* out1_i, __m512d* out2_r, + __m512d* 
out2_i) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 3i 3r 2i 2r 1i 1r 0i 0r + __m512d v_3to0 = _mm512_loadu_pd(arg_512++); + // 7i 7r 6i 6r 5i 5r 4i 4r + __m512d v_7to4 = _mm512_loadu_pd(arg_512++); + // 11i 11r 10i 10r 9i 9r 8i 8r + __m512d v_11to8 = _mm512_loadu_pd(arg_512++); + // 15i 15r 14i 14r 13i 13r 12i 12r + __m512d v_15to12 = _mm512_loadu_pd(arg_512++); + + // 00000000 > 7r 3r 6r 2r 5r 1r 4r 0r + __m512d v_7to0_r = _mm512_shuffle_pd(v_3to0, v_7to4, 0x00); + // 11111111 > 7i 3i 6i 2i 5i 1i 4i 0i + __m512d v_7to0_i = _mm512_shuffle_pd(v_3to0, v_7to4, 0xff); + // 00000000 > 15r 11r 14r 10r 13r 9r 12r 8r + __m512d v_15to8_r = _mm512_shuffle_pd(v_11to8, v_15to12, 0x00); + // 11111111 > 15i 11i 14i 10i 13i 9i 12i 8i + __m512d v_15to8_i = _mm512_shuffle_pd(v_11to8, v_15to12, 0xff); + + // real + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + // 6 2 7 3 4 0 5 1 + __m512d v1r = _mm512_permutexvar_pd(v1_perm_idx, v_7to0_r); + // 14 10 15 11 12 8 13 9 + __m512d v2r = _mm512_permutexvar_pd(v1_perm_idx, v_15to8_r); + // 11001100 > 14 10 6 2 12 8 4 0 + *out1_r = _mm512_mask_blend_pd(0xcc, v_7to0_r, v2r); + // 11001100 > 15 11 7 3 13 9 5 1 + *out2_r = _mm512_mask_blend_pd(0xcc, v1r, v_15to8_r); + + // imag + // 6 2 7 3 4 0 5 1 + __m512d v1i = _mm512_permutexvar_pd(v1_perm_idx, v_7to0_i); + // 14 10 15 11 12 8 13 9 + __m512d v2i = _mm512_permutexvar_pd(v1_perm_idx, v_15to8_i); + // 11001100 > 14 10 6 2 12 8 4 0 + *out1_i = _mm512_mask_blend_pd(0xcc, v_7to0_i, v2i); + // 11001100 > 15 11 7 3 13 9 5 1 + *out2_i = _mm512_mask_blend_pd(0xcc, v1i, v_15to8_i); +} + +// ************************************ T2 ************************************ + +// ComplexLoadFwdInterleavedT2: +// Assumes ComplexLoadFwdInterleavedT4 was used before. +// Given input: 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0 +// Returns +// *out1 = (13, 12, 9, 8, 5, 4, 1, 0) +// *out2 = (15, 14, 11, 10, 7, 6, 3, 2) +// +// Given output: 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0 +inline void ComplexLoadFwdInterleavedT2(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // Values were swapped in T4 + // 11, 10, 9, 8, 3, 2, 1, 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15, 14, 13, 12, 7, 6, 5, 4 + __m512d v2 = _mm512_loadu_pd(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + + __m512d v1_perm = _mm512_permutexvar_pd(v1_perm_idx, v1); + __m512d v2_perm = _mm512_permutexvar_pd(v1_perm_idx, v2); + + *out1 = _mm512_mask_blend_pd(0xcc, v1, v2_perm); + *out2 = _mm512_mask_blend_pd(0xcc, v1_perm, v2); +} + +// ComplexLoadInvInterleavedT2: +// Assumes ComplexLoadInvInterleavedT1 was used before. 
+// Given input: 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0 +// Returns +// *out1 = (13, 9, 5, 1, 12, 8, 4, 0) +// *out2 = (15, 11, 7, 3, 14, 10, 6, 2) +// +// Given output: 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0 +inline void ComplexLoadInvInterleavedT2(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 14 10 6 2 12 8 4 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15 11 7 3 13 9 5 1 + __m512d v2 = _mm512_loadu_pd(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + + // 12 8 4 0 14 10 6 2 + __m512d v1_perm = _mm512_permutexvar_pd(v1_perm_idx, v1); + // 13 9 5 1 15 11 7 3 + __m512d v2_perm = _mm512_permutexvar_pd(v1_perm_idx, v2); + + // 11110000 > 13 9 5 1 12 8 4 0 + *out1 = _mm512_mask_blend_pd(0xf0, v1, v2_perm); + // 11110000 > 15 11 7 3 14 10 6 2 + *out2 = _mm512_mask_blend_pd(0xf0, v1_perm, v2); +} + +// ************************************ T4 ************************************ + +// Complex LoadFwdInterleavedT4: +// Given input: 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +// Returns +// *out1 = (11, 10, 9, 8, 3, 2, 1, 0) +// *out2 = (15, 14, 13, 12, 7, 6, 5, 4) +// +// Given output: 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0 +inline void ComplexLoadFwdInterleavedT4(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + __m512d v_7to0 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + __m512d v_15to8 = _mm512_loadu_pd(arg_512); + __m512d perm_hi = _mm512_permutexvar_pd(vperm2_idx, v_15to8); + *out1 = _mm512_mask_blend_pd(0x0f, perm_hi, v_7to0); + *out2 = _mm512_mask_blend_pd(0xf0, perm_hi, v_7to0); + *out2 = _mm512_permutexvar_pd(vperm2_idx, *out2); +} + +// ComplexLoadInvInterleavedT4: +// Assumes ComplexLoadInvInterleavedT2 was used before. +// Given input: 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0 +// Returns +// *out1 = (11, 9, 3, 1, 10, 8, 2, 0) +// *out2 = (15, 13, 7, 5, 14, 12, 6, 4) +// +// Given output: 15, 13, 7, 5, 14, 12, 6, 4, 11, 9, 3, 1, 10, 8, 2, 0 + +inline void ComplexLoadInvInterleavedT4(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 13, 9, 5, 1, 12, 8, 4, 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15, 11, 7, 3, 14, 10, 6, 2 + __m512d v2 = _mm512_loadu_pd(arg_512); + + // 00000000 > 11 9 3 1 10 8 2 0 + *out1 = _mm512_shuffle_pd(v1, v2, 0x00); + // 11111111 > 15 13 7 5 14 12 6 4 + *out2 = _mm512_shuffle_pd(v1, v2, 0xff); +} + +// ComplexWriteInvInterleavedT4: +// Assuming ComplexLoadInvInterleavedT4 was used before. 
+// Given inputs: 15, 13, 7, 5, 14, 12, 6, 4, 11, 9, 3, 1, 10, 8, 2, 0 +// Seen Internally: +// @param arg1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// @param arg2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 11, 14, 10, 7, 3, 6, 2, +// 13, 9, 12, 8, 5, 1, 4, 0} +// +// Given output: 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +inline void ComplexWriteInvInterleavedT4(__m512d arg1, __m512d arg2, + __m512d* out) { + const __m512i vperm_4hi_4lo_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i vperm1 = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + const __m512i vperm2 = _mm512_set_epi64(5, 1, 4, 0, 7, 3, 6, 2); + + // in: 11 9 3 1 10 8 2 0 + // -> 11 10 9 8 3 2 1 0 + arg1 = _mm512_permutexvar_pd(vperm1, arg1); + // in: 15 13 7 5 14 12 6 4 + // -> 7 6 5 4 15 14 13 12 + arg2 = _mm512_permutexvar_pd(vperm2, arg2); + + // 7 6 5 4 3 2 1 0 + __m512d out1 = _mm512_mask_blend_pd(0xf0, arg1, arg2); + // 11 10 9 8 15 14 13 12 + __m512d out2 = _mm512_mask_blend_pd(0x0f, arg1, arg2); + // 15 14 13 12 11 10 9 8 + out2 = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, out2); + + _mm512_storeu_pd(out, out1); + out += 2; + _mm512_storeu_pd(out, out2); +} + +// ************************************ T8 ************************************ + +// ComplexLoadFwdInterleavedT8: +// Given inputs: 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +// Seen Internally: +// v_X1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// v_X2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 13, 11, 9, 7, 5, 3, 1, +// 14, 12, 10, 8, 6, 4, 2, 0} +// +// Given output: 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i, 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r +inline void ComplexLoadFwdInterleavedT8(const __m512d* arg_x, + const __m512d* arg_y, __m512d* out1_r, + __m512d* out1_i, __m512d* out2_r, + __m512d* out2_i) { + const __m512i v_perm_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + // 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r + __m512d v_X1 = _mm512_loadu_pd(arg_x++); + // 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r + __m512d v_X2 = _mm512_loadu_pd(arg_x); + // 7r, 3r, 6r, 2r, 5r, 1r, 4r, 0r + *out1_r = _mm512_shuffle_pd(v_X1, v_X2, 0x00); + // 7i, 3i, 6i, 2i, 5i, 1i, 4i, 0i + *out1_i = _mm512_shuffle_pd(v_X1, v_X2, 0xff); + // 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r + *out1_r = _mm512_permutexvar_pd(v_perm_idx, *out1_r); + // 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i + *out1_i = _mm512_permutexvar_pd(v_perm_idx, *out1_i); + + __m512d v_Y1 = _mm512_loadu_pd(arg_y++); + __m512d v_Y2 = _mm512_loadu_pd(arg_y); + *out2_r = _mm512_shuffle_pd(v_Y1, v_Y2, 0x00); + *out2_i = _mm512_shuffle_pd(v_Y1, v_Y2, 0xff); + *out2_r = _mm512_permutexvar_pd(v_perm_idx, *out2_r); + *out2_i = _mm512_permutexvar_pd(v_perm_idx, *out2_i); +} + +// ComplexWriteInvInterleavedT8: +// Assuming ComplexLoadInvInterleavedT4 was used before. 
+// Given inputs: 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i, 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r +// Seen Internally: +// @param arg1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// @param arg2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 7, 14, 6, 13, 5, 12, 4, +// 11, 3, 10, 2, 9, 1, 8, 0} +// +// Given output: 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +inline void ComplexWriteInvInterleavedT8(__m512d* v_X_real, __m512d* v_X_imag, + __m512d* v_Y_real, __m512d* v_Y_imag, + __m512d* v_X_pt, __m512d* v_Y_pt) { + const __m512i vperm = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + // in: 7r 6r 5r 4r 3r 2r 1r 0r + // -> 7r 3r 6r 2r 5r 1r 4r 0r + *v_X_real = _mm512_permutexvar_pd(vperm, *v_X_real); + // in: 7i 6i 5i 4i 3i 2i 1i 0i + // -> 7i 3i 6i 2i 5i 1i 4i 0i + *v_X_imag = _mm512_permutexvar_pd(vperm, *v_X_imag); + // in: 15r 14r 13r 12r 11r 10r 9r 8r + // -> 15r 11r 14r 10r 13r 9r 12r 8r + *v_Y_real = _mm512_permutexvar_pd(vperm, *v_Y_real); + // in: 15i 14i 13i 12i 11i 10i 9i 8i + // -> 15i 11i 14i 10i 13i 9i 12i 8i + *v_Y_imag = _mm512_permutexvar_pd(vperm, *v_Y_imag); + + // 00000000 > 3i 3r 2i 2r 1i 1r 0i 0r + __m512d v_X1 = _mm512_shuffle_pd(*v_X_real, *v_X_imag, 0x00); + // 11111111 > 7i 7r 6i 6r 5i 5r 4i 4r + __m512d v_X2 = _mm512_shuffle_pd(*v_X_real, *v_X_imag, 0xff); + // 00000000 > 11i 11r 10i 10r 9i 9r 8i 8r + __m512d v_Y1 = _mm512_shuffle_pd(*v_Y_real, *v_Y_imag, 0x00); + // 11111111 > 15i 15r 14i 14r 13i 13r 12i 12r + __m512d v_Y2 = _mm512_shuffle_pd(*v_Y_real, *v_Y_imag, 0xff); + + _mm512_storeu_pd(v_X_pt++, v_X1); + _mm512_storeu_pd(v_X_pt, v_X2); + _mm512_storeu_pd(v_Y_pt++, v_Y1); + _mm512_storeu_pd(v_Y_pt, v_Y2); +} +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/experimental/fft-like/fft-like-native.hpp b/hexl_ser/include/hexl/experimental/fft-like/fft-like-native.hpp new file mode 100644 index 00000000..7e02492d --- /dev/null +++ b/hexl_ser/include/hexl/experimental/fft-like/fft-like-native.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Radix-2 native C++ FFT like implementation of the forward FFT like +/// @param[out] result Output data. Overwritten with FFT like output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity. In +/// bit-reversed order +/// @param[in] scale Scale applied to output data +void Forward_FFTLike_ToBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* root_of_unity_powers, const uint64_t n, + const double* scale = nullptr); + +/// @brief Radix-2 native C++ FFT like implementation of the inverse FFT like +/// @param[out] result Output data. Overwritten with FFT like output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity. +/// In bit-reversed order. 
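Several of these transforms take and return data indexed in bit-reversed order. The helper below is an illustrative sketch of that indexing, not part of the library; it only assumes the transform size is a power of two:

```cpp
#include <cstdint>

// Reverses the lowest log2(n) bits of index i, for a power-of-two n.
// Example: for n = 8, indices {0,1,2,3,4,5,6,7} map to {0,4,2,6,1,5,3,7},
// which illustrates the bit-reversed ordering referred to above.
uint64_t BitReverseIndex(uint64_t i, uint64_t n) {
  uint64_t bits = 0;
  while ((1ULL << bits) < n) {
    ++bits;
  }
  uint64_t reversed = 0;
  for (uint64_t b = 0; b < bits; ++b) {
    reversed |= ((i >> b) & 1ULL) << (bits - 1 - b);
  }
  return reversed;
}
```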
+/// @param[in] scale Scale applied to output data +void Inverse_FFTLike_FromBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* inv_root_of_unity_powers, const uint64_t n, + const double* scale = nullptr); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/experimental/fft-like/fft-like.hpp b/hexl_ser/include/hexl/experimental/fft-like/fft-like.hpp new file mode 100644 index 00000000..334de246 --- /dev/null +++ b/hexl_ser/include/hexl/experimental/fft-like/fft-like.hpp @@ -0,0 +1,147 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/experimental/fft-like/fwd-fft-like-avx512.hpp" +#include "hexl/experimental/fft-like/inv-fft-like-avx512.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/allocator.hpp" + +namespace intel { +namespace hexl { + +/// @brief Performs linear forward and inverse FFT like transform +/// for CKKS encoding and decoding. +class FFTLike { + public: + /// @brief Helper class for custom memory allocation + template + struct AllocatorAdapter + : public AllocatorInterface> { + explicit AllocatorAdapter(Adaptee&& _a, Args&&... args); + AllocatorAdapter(const Adaptee& _a, Args&... args); + + // interface implementation + void* allocate_impl(size_t bytes_count); + void deallocate_impl(void* p, size_t n); + + private: + Adaptee alloc; + }; + + /// @brief Initializes an empty CKKS_FTT object + FFTLike() = default; + + /// @brief Destructs the CKKS_FTT object + ~FFTLike() = default; + + /// @brief Initializes an FFTLike object with degree \p degree and scalar + /// \p in_scalar. + /// @param[in] degree also known as N. Size of the FFT like transform. Must be + /// a power of 2 + /// @param[in] in_scalar Scalar value to calculate scale and inv scale + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @details Performs pre-computation necessary for forward and inverse + /// transforms + FFTLike(uint64_t degree, double* in_scalar, + std::shared_ptr alloc_ptr = {}); + + template + FFTLike(uint64_t degree, double* in_scalar, Allocator&& a, + AllocatorArgs&&... args) + : FFTLike( + degree, in_scalar, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Compute forward FFT like. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the FFT like + /// @param[in] in_scale Scale applied to output values + void ComputeForwardFFTLike(std::complex* result, + const std::complex* operand, + const double* in_scale = nullptr); + + /// @brief Compute inverse FFT like. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the FFT like + /// @param[in] in_scale Scale applied to output values + void ComputeInverseFFTLike(std::complex* result, + const std::complex* operand, + const double* in_scale = nullptr); + + /// @brief Construct floating-point values from CRT-composed polynomial with + /// integer coefficients. 
+ /// @param[out] res Stores the result + /// @param[in] plain Plaintext + /// @param[in] threshold Upper half threshold with respect to the total + /// coefficient modulus + /// @param[in] decryption_modulus Product of all primes in the coefficient + /// modulus + /// @param[in] inv_scale Scale applied to output values + /// @param[in] mod_size Size of coefficient modulus parameter + /// @param[in] coeff_count Degree of the polynomial modulus parameter + void BuildFloatingPoints(std::complex* res, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, size_t mod_size, + size_t coeff_count); + + /// @brief Returns the root of unity power at bit-reversed index i. + /// @param[in] i Index + std::complex GetComplexRootOfUnity(size_t i) { + return GetComplexRootsOfUnity()[i]; + } + + /// @brief Returns the root of unity in bit-reversed order + const AlignedVector64>& GetComplexRootsOfUnity() const { + return m_complex_roots_of_unity; + } + + /// @brief Returns the root of unity power at bit-reversed index i. + /// @param[in] i Index + std::complex GetInvComplexRootOfUnity(size_t i) { + return GetInvComplexRootsOfUnity()[i]; + } + + /// @brief Returns the inverse root of unity in bit-reversed order + const AlignedVector64>& GetInvComplexRootsOfUnity() + const { + return m_inv_complex_roots_of_unity; + } + + /// @brief Returns the degree N + uint64_t GetDegree() const { return m_degree; } + + private: + // Computes 1~(n-1)-th powers and inv powers of the primitive 2n-th root + void ComputeComplexRootsOfUnity(); + + uint64_t m_degree; // N: size of FFT like transform, should be power of 2 + + double* scalar; // Pointer to scalar used for scale/inv_scale calculation + + double scale; // Scale value use for encoding (inv fft-like) + + double inv_scale; // Scale value use in decoding (fwd fft-like) + + std::shared_ptr m_alloc; + + AlignedAllocator m_aligned_alloc; + + uint64_t m_degree_bits; // log_2(m_degree) + + // Contains 0~(n-1)-th powers of the 2n-th primitive root. + AlignedVector64> m_complex_roots_of_unity; + + // Contains 0~(n-1)-th inv powers of the 2n-th primitive inv root. + AlignedVector64> m_inv_complex_roots_of_unity; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp b/hexl_ser/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp new file mode 100644 index 00000000..aba4ca4d --- /dev/null +++ b/hexl_ser/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp @@ -0,0 +1,59 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/fft-like/fft-like.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the forward FFT like +/// @param[out] result_cmplx_intrlvd Output data. Overwritten with FFT like +/// output. Result is a vector of double with interleaved real and imaginary +/// numbers. +/// @param[in] operand_cmplx_intrlvd Input data. A vector of double with +/// interleaved real and imaginary numbers. +/// @param[in] roots_of_unity_cmplx_intrlvd Powers of 2n'th root of unity. In +/// bit-reversed order. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. 
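The AVX512 FFT-like entry points operate on `double` buffers with interleaved real and imaginary parts. Below is a small sketch of producing that layout from `std::complex<double>` data, assuming the natural `{re, im}` pair ordering; `ToInterleaved` is an illustrative helper, not a library function:

```cpp
#include <complex>
#include <cstddef>
#include <vector>

// Packs n complex values into the interleaved layout
// {re[0], im[0], re[1], im[1], ...} expected by the *_cmplx_intrlvd arguments.
std::vector<double> ToInterleaved(const std::vector<std::complex<double>>& in) {
  std::vector<double> out(2 * in.size());
  for (std::size_t i = 0; i < in.size(); ++i) {
    out[2 * i] = in[i].real();
    out[2 * i + 1] = in[i].imag();
  }
  return out;
}
```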
+/// @param[in] scale Scale applied to output values +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// FFT like, where all the butterflies in a given stage are processed before +/// any butterflies in the next stage. The base case is small enough to fit in +/// the smallest cache. Larger FFTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. This reduces the number of cache misses, improving +/// performance on larger transform sizes. +void Forward_FFTLike_ToBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* roots_of_unity_cmplx_intrlvd, const uint64_t n, + const double* scale = nullptr, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +/// @brief Construct floating-point values from CRT-composed polynomial with +/// integer coefficients in AVX512. +/// @param[out] res_cmplx_intrlvd Stores the result +/// @param[in] plain Plaintext +/// @param[in] threshold Upper half threshold with respect to the total +/// coefficient modulus +/// @param[in] decryption_modulus Product of all primes in the coefficient +/// modulus +/// @param[in] inv_scale Scale applied to output values +/// @param[in] mod_size Size of coefficient modulus parameter +/// @param[in] coeff_count Degree of the polynomial modulus parameter +void BuildFloatingPointsAVX512(double* res_cmplx_intrlvd, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, const size_t mod_size, + const size_t coeff_count); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp b/hexl_ser/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp new file mode 100644 index 00000000..487e2828 --- /dev/null +++ b/hexl_ser/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/fft-like/fft-like.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the inverse FFT like +/// @param[out] result_cmplx_intrlvd Output data. Overwritten with FFT like +/// output. Result is a vector of double with interleaved real and imaginary +/// numbers. +/// @param[in] operand_cmplx_intrlvd Input data. A vector of double with +/// interleaved real and imaginary numbers. +/// @param[in] inv_roots_of_unity_cmplx_intrlvd Powers of 2n'th root of unity. +/// In bit-reversed order. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] scale Scale applied to output values +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// FFT like, where all the butterflies in a given stage are processed before +/// any butterflies in the next stage. The base case is small enough to fit in +/// the smallest cache. Larger FFTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. 
This reduces the number of cache misses, improving +/// performance on larger transform sizes. +void Inverse_FFTLike_FromBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* inv_root_of_unity_cmplxintrlvd, const uint64_t n, + const double* scale = nullptr, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/experimental/misc/lr-mat-vec-mult.hpp b/hexl_ser/include/hexl/experimental/misc/lr-mat-vec-mult.hpp new file mode 100644 index 00000000..df03df92 --- /dev/null +++ b/hexl_ser/include/hexl/experimental/misc/lr-mat-vec-mult.hpp @@ -0,0 +1,31 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes transposed linear regression +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (3 * n * num_moduli) elements +/// @param[in] operand1 Vector of ciphertext representing a matrix that encodes +/// a transposed logistic regression model. Has (num_weights * 2 * n * +/// num_moduli) elements. +/// @param[in] operand2 Vector of ciphertext representing a matrix that encodes +/// at most n/2 input samples with feature size num_weights. Has (num_weights * +/// 2 * n * num_moduli) elements. +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +/// @param[in] num_weights Feature size of the linear/logistic regression model +void LinRegMatrixVectorMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli, + uint64_t num_weights); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/experimental/seal/dyadic-multiply-internal.hpp b/hexl_ser/include/hexl/experimental/seal/dyadic-multiply-internal.hpp new file mode 100644 index 00000000..310a46b0 --- /dev/null +++ b/hexl_ser/include/hexl/experimental/seal/dyadic-multiply-internal.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { +namespace internal { + +/// @brief Computes dyadic multiplication +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (2 * n * num_moduli) elements +/// @param[in] operand1 First ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] operand2 Second ciphertext argument. Has (2 * n * num_moduli) +/// elements. 
+/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli); + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/experimental/seal/dyadic-multiply.hpp b/hexl_ser/include/hexl/experimental/seal/dyadic-multiply.hpp new file mode 100644 index 00000000..f7eacfdf --- /dev/null +++ b/hexl_ser/include/hexl/experimental/seal/dyadic-multiply.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes dyadic multiplication +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (2 * n * num_moduli) elements +/// @param[in] operand1 First ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] operand2 Second ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/experimental/seal/key-switch-internal.hpp b/hexl_ser/include/hexl/experimental/seal/key-switch-internal.hpp new file mode 100644 index 00000000..8fc9d53e --- /dev/null +++ b/hexl_ser/include/hexl/experimental/seal/key-switch-internal.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { +namespace internal { + +/// @brief Computes key switching in-place +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (n * decomp_modulus_size * key_component_count) elements +/// @param[in] t_target_iter_ptr Pointer to the last component of the input +/// ciphertext +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] decomp_modulus_size Number of moduli in the ciphertext at its +/// current level, excluding one auxiliary prime. +/// @param[in] key_modulus_size Number of moduli in the ciphertext at its top +/// level, including one auxiliary prime. +/// @param[in] rns_modulus_size Number of moduli in the ciphertext at its +/// current level, including one auxiliary prime. rns_modulus_size == +/// decomp_modulus_size + 1 +/// @param[in] key_component_count Number of components in the resulting +/// ciphertext, e.g. key_component_count == 2. +/// @param[in] moduli Array of word-sized coefficient moduli. There must be +/// key_modulus_size moduli in the array +/// @param[in] k_switch_keys Array of evaluation key data. 
Has +/// decomp_modulus_size entries, each with +/// coeff_count * ((key_modulus_size - 1)+ (key_component_count - 1) * +/// (key_modulus_size) + 1) entries +/// @param[in] modswitch_factors Array of modulus switch factors +/// @param[in] root_of_unity_powers_ptr Array of root of unity powers +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr = nullptr); + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/experimental/seal/key-switch.hpp b/hexl_ser/include/hexl/experimental/seal/key-switch.hpp new file mode 100644 index 00000000..9eda159c --- /dev/null +++ b/hexl_ser/include/hexl/experimental/seal/key-switch.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes key switching in-place +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (n * decomp_modulus_size * key_component_count) elements +/// @param[in] t_target_iter_ptr Pointer to the last component of the input +/// ciphertext +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] decomp_modulus_size Number of moduli in the ciphertext at its +/// current level, excluding one auxiliary prime. +/// @param[in] key_modulus_size Number of moduli in the ciphertext at its top +/// level, including one auxiliary prime. +/// @param[in] rns_modulus_size Number of moduli in the ciphertext at its +/// current level, including one auxiliary prime. rns_modulus_size == +/// decomp_modulus_size + 1 +/// @param[in] key_component_count Number of components in the resulting +/// ciphertext, e.g. key_component_count == 2. +/// @param[in] moduli Array of word-sized coefficient moduli. There must be +/// key_modulus_size moduli in the array +/// @param[in] k_switch_keys Array of evaluation key data. 
Has +/// decomp_modulus_size entries, each with +/// coeff_count * ((key_modulus_size - 1)+ (key_component_count - 1) * +/// (key_modulus_size) + 1) entries +/// @param[in] modswitch_factors Array of modulus switch factors +/// @param[in] root_of_unity_powers_ptr Array of root of unity powers +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr = nullptr); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/experimental/seal/locks.hpp b/hexl_ser/include/hexl/experimental/seal/locks.hpp new file mode 100644 index 00000000..4595f4e5 --- /dev/null +++ b/hexl_ser/include/hexl/experimental/seal/locks.hpp @@ -0,0 +1,35 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +namespace intel { +namespace hexl { + +using Lock = std::shared_mutex; +using WriteLock = std::unique_lock; +using ReadLock = std::shared_lock; + +class RWLock { + public: + RWLock() = default; + inline ReadLock AcquireRead() { return ReadLock(rw_mutex); } + inline WriteLock AcquireWrite() { return WriteLock(rw_mutex); } + inline ReadLock TryAcquireRead() noexcept { + return ReadLock(rw_mutex, std::try_to_lock); + } + inline WriteLock TryAcquireWrite() noexcept { + return WriteLock(rw_mutex, std::try_to_lock); + } + + private: + RWLock(const RWLock& copy) = delete; + RWLock& operator=(const RWLock& assign) = delete; + Lock rw_mutex{}; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/experimental/seal/ntt-cache.hpp b/hexl_ser/include/hexl/experimental/seal/ntt-cache.hpp new file mode 100644 index 00000000..8f6c1046 --- /dev/null +++ b/hexl_ser/include/hexl/experimental/seal/ntt-cache.hpp @@ -0,0 +1,56 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/seal/locks.hpp" +#include "ntt/ntt-internal.hpp" + +namespace intel { +namespace hexl { + +struct HashPair { + template + std::size_t operator()(const std::pair& p) const { + auto hash1 = std::hash{}(p.first); + auto hash2 = std::hash{}(p.second); + return hash_combine(hash1, hash2); + } + + // Golden Ratio Hashing with seeds + static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { + lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + return lhs; + } +}; + +NTT& GetNTT(size_t N, uint64_t modulus) { + static std::unordered_map, NTT, HashPair> + ntt_cache; + static RWLock ntt_cache_locker; + + std::pair key{N, modulus}; + + // Enable shared access to NTT already present + { + ReadLock reader_lock(ntt_cache_locker.AcquireRead()); + auto ntt_it = ntt_cache.find(key); + if (ntt_it != ntt_cache.end()) { + return ntt_it->second; + } + } + + // Deal with NTT not yet present + WriteLock write_lock(ntt_cache_locker.AcquireWrite()); + + // Check ntt_cache for value (may be added by another thread) + auto ntt_it = ntt_cache.find(key); + if (ntt_it == ntt_cache.end()) { + NTT ntt(N, modulus); + ntt_it = ntt_cache.emplace(std::move(key), std::move(ntt)).first; + } + return ntt_it->second; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/hexl.hpp b/hexl_ser/include/hexl/hexl.hpp new file mode 100644 index 00000000..6f07ae57 --- /dev/null +++ 
b/hexl_ser/include/hexl/hexl.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/eltwise/eltwise-cmp-add.hpp" +#include "hexl/eltwise/eltwise-cmp-sub-mod.hpp" +#include "hexl/eltwise/eltwise-fma-mod.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/eltwise/eltwise-sub-mod.hpp" +#include "hexl/experimental/fft-like/fft-like.hpp" +#include "hexl/experimental/misc/lr-mat-vec-mult.hpp" +#include "hexl/experimental/seal/dyadic-multiply-internal.hpp" +#include "hexl/experimental/seal/dyadic-multiply.hpp" +#include "hexl/experimental/seal/key-switch-internal.hpp" +#include "hexl/experimental/seal/key-switch.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" +#include "hexl/util/defines.hpp" +#include "hexl/util/types.hpp" +#include "hexl/util/util.hpp" diff --git a/hexl_ser/include/hexl/logging/logging.hpp b/hexl_ser/include/hexl/logging/logging.hpp new file mode 100644 index 00000000..af5bfcd8 --- /dev/null +++ b/hexl_ser/include/hexl/logging/logging.hpp @@ -0,0 +1,43 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "hexl/util/defines.hpp" + +// Wrap HEXL_VLOG with HEXL_DEBUG; this ensures no logging overhead in +// release mode +#ifdef HEXL_DEBUG + +// TODO(fboemer) Enable if needed +// #define ELPP_THREAD_SAFE +#define ELPP_CUSTOM_COUT std::cerr +#define ELPP_STL_LOGGING +#define ELPP_LOG_STD_ARRAY +#define ELPP_LOG_UNORDERED_MAP +#define ELPP_LOG_UNORDERED_SET +#define ELPP_NO_LOG_TO_FILE +#define ELPP_DISABLE_DEFAULT_CRASH_HANDLING +#define ELPP_WINSOCK2 + +#include + +#define HEXL_VLOG(N, rest) \ + do { \ + if (VLOG_IS_ON(N)) { \ + VLOG(N) << rest; \ + } \ + } while (0); + +#else + +#define HEXL_VLOG(N, rest) \ + {} + +#define START_EASYLOGGINGPP(X, Y) \ + {} + +#endif diff --git a/hexl_ser/include/hexl/ntt/ntt.hpp b/hexl_ser/include/hexl/ntt/ntt.hpp new file mode 100644 index 00000000..93ccba72 --- /dev/null +++ b/hexl_ser/include/hexl/ntt/ntt.hpp @@ -0,0 +1,296 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/allocator.hpp" + +namespace intel { +namespace hexl { + +/// @brief Performs negacyclic forward and inverse number-theoretic transform +/// (NTT), commonly used in RLWE cryptography. +/// @details The number-theoretic transform (NTT) specializes the discrete +/// Fourier transform (DFT) to the finite field \f$ \mathbb{Z}_q[X] / (X^N + 1) +/// \f$. +class NTT { + public: + /// @brief Helper class for custom memory allocation + template + struct AllocatorAdapter + : public AllocatorInterface> { + explicit AllocatorAdapter(Adaptee&& _a, Args&&... args); + AllocatorAdapter(const Adaptee& _a, Args&... args); + + // interface implementation + void* allocate_impl(size_t bytes_count); + void deallocate_impl(void* p, size_t n); + + private: + Adaptee alloc; + }; + + /// @brief Initializes an empty NTT object + NTT() = default; + + /// @brief Destructs the NTT object + ~NTT() = default; + + /// @brief Initializes an NTT object with degree \p degree and modulus \p q. + /// @param[in] degree also known as N. Size of the NTT transform. 
Must be a + /// power of + /// 2 + /// @param[in] q Prime modulus. Must satisfy \f$ q == 1 \mod 2N \f$ + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @brief Performs pre-computation necessary for forward and inverse + /// transforms + NTT(uint64_t degree, uint64_t q, + std::shared_ptr alloc_ptr = {}); + + template + NTT(uint64_t degree, uint64_t q, Allocator&& a, AllocatorArgs&&... args) + : NTT(degree, q, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Initializes an NTT object with degree \p degree and modulus + /// \p q. + /// @param[in] degree also known as N. Size of the NTT transform. Must be a + /// power of + /// 2 + /// @param[in] q Prime modulus. Must satisfy \f$ q == 1 \mod 2N \f$ + /// @param[in] root_of_unity 2N'th root of unity in \f$ \mathbb{Z_q} \f$. + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @details Performs pre-computation necessary for forward and inverse + /// transforms + NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, + std::shared_ptr alloc_ptr = {}); + + template + NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, Allocator&& a, + AllocatorArgs&&... args) + : NTT(degree, q, root_of_unity, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Returns true if arguments satisfy constraints for negacyclic NTT + /// @param[in] degree N. Size of the transform, i.e. the polynomial degree. + /// Must be a power of two. + /// @param[in] modulus Prime modulus q. Must satisfy q mod 2N = 1 + static bool CheckArguments(uint64_t degree, uint64_t modulus); + + /// @brief Compute forward NTT. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the NTT + /// @param[in] input_mod_factor Assume input \p operand are in [0, + /// input_mod_factor * q). Must be 1, 2 or 4. + /// @param[in] output_mod_factor Returns output \p result in [0, + /// output_mod_factor * q). Must be 1 or 4. + void ComputeForward(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, uint64_t output_mod_factor); + + /// Compute inverse NTT. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the NTT + /// @param[in] input_mod_factor Assume input \p operand are in [0, + /// input_mod_factor * q). Must be 1 or 2. + /// @param[in] output_mod_factor Returns output \p result in [0, + /// output_mod_factor * q). Must be 1 or 2. + void ComputeInverse(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, uint64_t output_mod_factor); + + /// @brief Returns the minimal 2N'th root of unity + uint64_t GetMinimalRootOfUnity() const { return m_w; } + + /// @brief Returns the degree N + uint64_t GetDegree() const { return m_degree; } + + /// @brief Returns the word-sized prime modulus + uint64_t GetModulus() const { return m_q; } + + /// @brief Returns the root of unity powers in bit-reversed order + const AlignedVector64& GetRootOfUnityPowers() const { + return m_root_of_unity_powers; + } + + /// @brief Returns the root of unity power at bit-reversed index i. 
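For reference, a minimal sketch of driving the forward and inverse transforms through this class; the degree 8, the prime 97 (which satisfies q == 1 mod 2N), and the coefficient values are illustrative choices, and `NttUsageSketch` is not a library function:

```cpp
#include <cstdint>
#include <vector>

#include "hexl/ntt/ntt.hpp"

// Round-trips a small polynomial through the forward and inverse NTT.
void NttUsageSketch() {
  const uint64_t n = 8;   // transform size; must be a power of two
  const uint64_t q = 97;  // prime with q == 1 (mod 2n): 97 = 6 * 16 + 1
  intel::hexl::NTT ntt(n, q);

  std::vector<uint64_t> coeffs{1, 2, 3, 4, 5, 6, 7, 8};  // values in [0, q)
  std::vector<uint64_t> transformed(n, 0);
  std::vector<uint64_t> recovered(n, 0);

  // Inputs are already reduced mod q, so input_mod_factor = 1;
  // request canonical outputs in [0, q) with output_mod_factor = 1.
  ntt.ComputeForward(transformed.data(), coeffs.data(), 1, 1);

  // The inverse transform recovers the original coefficients in [0, q).
  ntt.ComputeInverse(recovered.data(), transformed.data(), 1, 1);
}
```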
+ uint64_t GetRootOfUnityPower(size_t i) { return GetRootOfUnityPowers()[i]; } + + /// @brief Returns 32-bit pre-conditioned root of unity powers in + /// bit-reversed order + const AlignedVector64& GetPrecon32RootOfUnityPowers() const { + return m_precon32_root_of_unity_powers; + } + + /// @brief Returns 64-bit pre-conditioned root of unity powers in + /// bit-reversed order + const AlignedVector64& GetPrecon64RootOfUnityPowers() const { + return m_precon64_root_of_unity_powers; + } + + /// @brief Returns the root of unity powers in bit-reversed order with + /// modifications for use by AVX512 implementation + const AlignedVector64& GetAVX512RootOfUnityPowers() const { + return m_avx512_root_of_unity_powers; + } + + /// @brief Returns 32-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon32RootOfUnityPowers() const { + return m_avx512_precon32_root_of_unity_powers; + } + + /// @brief Returns 52-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon52RootOfUnityPowers() const { + return m_avx512_precon52_root_of_unity_powers; + } + + /// @brief Returns 64-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon64RootOfUnityPowers() const { + return m_avx512_precon64_root_of_unity_powers; + } + + /// @brief Returns the inverse root of unity powers in bit-reversed order + const AlignedVector64& GetInvRootOfUnityPowers() const { + return m_inv_root_of_unity_powers; + } + + /// @brief Returns the inverse root of unity power at bit-reversed index i. + uint64_t GetInvRootOfUnityPower(size_t i) { + return GetInvRootOfUnityPowers()[i]; + } + + /// @brief Returns the vector of 32-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. + const AlignedVector64& GetPrecon32InvRootOfUnityPowers() const { + return m_precon32_inv_root_of_unity_powers; + } + + /// @brief Returns the vector of 52-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. + const AlignedVector64& GetPrecon52InvRootOfUnityPowers() const { + return m_precon52_inv_root_of_unity_powers; + } + + /// @brief Returns the vector of 64-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. 
+ const AlignedVector64& GetPrecon64InvRootOfUnityPowers() const { + return m_precon64_inv_root_of_unity_powers; + } + + /// @brief Maximum power of 2 in degree + static size_t MaxDegreeBits() { return 20; } + + /// @brief Maximum number of bits in modulus; + static size_t MaxModulusBits() { return 62; } + + /// @brief Default bit shift used in Barrett precomputation + static const size_t s_default_shift_bits{64}; + + /// @brief Bit shift used in Barrett precomputation when AVX512-IFMA + /// acceleration is enabled + static const size_t s_ifma_shift_bits{52}; + + /// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the + /// forward transform + static const size_t s_max_fwd_32_modulus{1ULL << (32 - 2)}; + + /// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the + /// inverse transform + static const size_t s_max_inv_32_modulus{1ULL << (32 - 2)}; + + /// @brief Maximum modulus to use AVX512-IFMA acceleration for the forward + /// transform + static const size_t s_max_fwd_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)}; + + /// @brief Maximum modulus to use AVX512-IFMA acceleration for the inverse + /// transform + static const size_t s_max_inv_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)}; + + /// @brief Maximum modulus to use AVX512-DQ acceleration for the inverse + /// transform + static const size_t s_max_inv_dq_modulus{1ULL << (s_default_shift_bits - 2)}; + + static size_t s_max_fwd_modulus(int bit_shift) { + if (bit_shift == 32) { + return s_max_fwd_32_modulus; + } else if (bit_shift == 52) { + return s_max_fwd_ifma_modulus; + } else if (bit_shift == 64) { + return 1ULL << MaxModulusBits(); + } + HEXL_CHECK(false, "Invalid bit_shift " << bit_shift); + return 0; + } + + static size_t s_max_inv_modulus(int bit_shift) { + if (bit_shift == 32) { + return s_max_inv_32_modulus; + } else if (bit_shift == 52) { + return s_max_inv_ifma_modulus; + } else if (bit_shift == 64) { + return 1ULL << MaxModulusBits(); + } + HEXL_CHECK(false, "Invalid bit_shift " << bit_shift); + return 0; + } + + private: + void ComputeRootOfUnityPowers(); + + uint64_t m_degree; // N: size of NTT transform, should be power of 2 + uint64_t m_q; // prime modulus. 
Must satisfy q == 1 mod 2n + + uint64_t m_degree_bits; // log_2(m_degree) + + uint64_t m_w_inv; // Inverse of minimal root of unity + uint64_t m_w; // A 2N'th root of unity + + std::shared_ptr m_alloc; + + AlignedAllocator m_aligned_alloc; + + // powers of the minimal root of unity + AlignedVector64 m_root_of_unity_powers; + // vector of floor(W * 2**32 / m_q), with W the root of unity powers + AlignedVector64 m_precon32_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the root of unity powers + AlignedVector64 m_precon64_root_of_unity_powers; + + // powers of the minimal root of unity adjusted for use in AVX512 + // implementations + AlignedVector64 m_avx512_root_of_unity_powers; + // vector of floor(W * 2**32 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon32_root_of_unity_powers; + // vector of floor(W * 2**52 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon52_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon64_root_of_unity_powers; + + // vector of floor(W * 2**32 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon32_inv_root_of_unity_powers; + // vector of floor(W * 2**52 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon52_inv_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon64_inv_root_of_unity_powers; + + AlignedVector64 m_inv_root_of_unity_powers; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/number-theory/number-theory.hpp b/hexl_ser/include/hexl/number-theory/number-theory.hpp new file mode 100644 index 00000000..da8d1d2a --- /dev/null +++ b/hexl_ser/include/hexl/number-theory/number-theory.hpp @@ -0,0 +1,342 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" + +namespace intel { +namespace hexl { + +/// @brief Pre-computes a Barrett factor with which modular multiplication can +/// be performed more efficiently +class MultiplyFactor { + public: + MultiplyFactor() = default; + + /// @brief Computes and stores the Barrett factor floor((operand << bit_shift) + /// / modulus). This is useful when modular multiplication of the form + /// (x * operand) mod modulus is performed with same modulus and operand + /// several times. Note, passing operand=1 can be used to pre-compute a + /// Barrett factor for multiplications of the form (x * y) mod modulus, where + /// only the modulus is re-used across calls to modular multiplication. + MultiplyFactor(uint64_t operand, uint64_t bit_shift, uint64_t modulus) + : m_operand(operand) { + HEXL_CHECK(operand <= modulus, "operand " << operand + << " must be less than modulus " + << modulus); + HEXL_CHECK(bit_shift == 32 || bit_shift == 52 || bit_shift == 64, + "Unsupported BitShift " << bit_shift); + uint64_t op_hi = operand >> (64 - bit_shift); + uint64_t op_lo = (bit_shift == 64) ? 
0 : (operand << bit_shift); + + m_barrett_factor = DivideUInt128UInt64Lo(op_hi, op_lo, modulus); + } + + /// @brief Returns the pre-computed Barrett factor + inline uint64_t BarrettFactor() const { return m_barrett_factor; } + + /// @brief Returns the operand corresponding to the Barrett factor + inline uint64_t Operand() const { return m_operand; } + + private: + uint64_t m_operand; + uint64_t m_barrett_factor; +}; + +/// @brief Returns whether or not num is a power of two +inline bool IsPowerOfTwo(uint64_t num) { return num && !(num & (num - 1)); } + +/// @brief Returns floor(log2(x)) +inline uint64_t Log2(uint64_t x) { return MSB(x); } + +inline bool IsPowerOfFour(uint64_t num) { + return IsPowerOfTwo(num) && (Log2(num) % 2 == 0); +} + +/// @brief Returns the maximum value that can be represented using \p bits bits +inline uint64_t MaximumValue(uint64_t bits) { + HEXL_CHECK(bits <= 64, "MaximumValue requires bits <= 64; got " << bits); + if (bits == 64) { + return (std::numeric_limits::max)(); + } + return (1ULL << bits) - 1; +} + +/// @brief Reverses the bits +/// @param[in] x Input to reverse +/// @param[in] bit_width Number of bits in the input; must be >= MSB(x) +/// @return The bit-reversed representation of \p x using \p bit_width bits +uint64_t ReverseBits(uint64_t x, uint64_t bit_width); + +/// @brief Returns x^{-1} mod modulus +/// @details Requires x % modulus != 0 +uint64_t InverseMod(uint64_t x, uint64_t modulus); + +/// @brief Returns (x * y) mod modulus +/// @details Assumes x, y < modulus +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns (x * y) mod modulus +/// @param[in] x +/// @param[in] y +/// @param[in] y_precon 64-bit precondition factor floor(2**64 / modulus) +/// @param[in] modulus +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon, + uint64_t modulus); + +/// @brief Returns (x + y) mod modulus +/// @details Assumes x, y < modulus +uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns (x - y) mod modulus +/// @details Assumes x, y < modulus +uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns base^exp mod modulus +uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus); + +/// @brief Returns whether or not root is a degree-th root of unity mod modulus +/// @param[in] root Root of unity to check +/// @param[in] degree Degree of root of unity; must be a power of two +/// @param[in] modulus Modulus of finite field +bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus); + +/// @brief Tries to return a primitive degree-th root of unity +/// @details Returns 0 or throws an error if no root is found +uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus); + +/// @brief Returns whether or not root is a degree-th root of unity +/// @param[in] degree Must be a power of two +/// @param[in] modulus Modulus of finite field +uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus); + +/// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 * +/// modulus] +/// @param[in] x +/// @param[in] y_operand also denoted y +/// @param[in] modulus +/// @param[in] y_barrett_factor Pre-computed Barrett reduction factor floor((y +/// << BitShift) / modulus) +template +inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y_operand, + uint64_t y_barrett_factor, uint64_t modulus) { + HEXL_CHECK(y_operand < modulus, "y_operand " << y_operand + << " must be less than modulus " + << modulus); + HEXL_CHECK( + modulus <= 
MaximumValue(BitShift), + "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift)); + HEXL_CHECK(x <= MaximumValue(BitShift), + "Operand " << x << " exceeds bound " << MaximumValue(BitShift)); + + uint64_t Q = MultiplyUInt64Hi(x, y_barrett_factor); + return y_operand * x - Q * modulus; +} + +/// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 * +/// modulus] +/// @param[in] x +/// @param[in] y +/// @param[in] modulus +template +inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y, uint64_t modulus) { + HEXL_CHECK(BitShift == 64 || BitShift == 52, + "Unsupported BitShift " << BitShift); + HEXL_CHECK(x <= MaximumValue(BitShift), + "Operand " << x << " exceeds bound " << MaximumValue(BitShift)); + HEXL_CHECK(y < modulus, + "y " << y << " must be less than modulus " << modulus); + HEXL_CHECK( + modulus <= MaximumValue(BitShift), + "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift)); + + uint64_t y_barrett = MultiplyFactor(y, BitShift, modulus).BarrettFactor(); + return MultiplyModLazy(x, y, y_barrett, modulus); +} + +/// @brief Adds two unsigned 64-bit integers +/// @param operand1 Number to add +/// @param operand2 Number to add +/// @param result Stores the sum +/// @return The carry bit +inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2, + uint64_t* result) { + *result = operand1 + operand2; + return static_cast(*result < operand1); +} + +/// @brief Returns whether or not the input is prime +bool IsPrime(uint64_t n); + +/// @brief Generates a list of num_primes primes in the range [2^(bit_size), +// 2^(bit_size+1)]. Ensures each prime q satisfies +// q % (2*ntt_size+1)) == 1 +/// @param[in] num_primes Number of primes to generate +/// @param[in] bit_size Bit size of each prime +/// @param[in] prefer_small_primes When true, returns primes starting from +/// 2^(bit_size); when false, returns primes starting from 2^(bit_size+1) +/// @param[in] ntt_size N such that each prime q satisfies q % (2N) == 1. N must +/// be a power of two less than 2^bit_size. +std::vector GeneratePrimes(size_t num_primes, size_t bit_size, + bool prefer_small_primes, + size_t ntt_size = 1); + +/// @brief Returns input mod modulus, computed via 64-bit Barrett reduction +/// @param[in] input +/// @param[in] modulus +/// @param[in] q_barr floor(2^64 / modulus) +template +uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr) { + HEXL_CHECK(modulus != 0, "modulus == 0"); + uint64_t q = MultiplyUInt64Hi<64>(input, q_barr); + uint64_t q_times_input = input - q * modulus; + if (OutputModFactor == 2) { + return q_times_input; + } else { + return (q_times_input >= modulus) ? 
q_times_input - modulus : q_times_input; + } +} + +/// @brief Returns x mod modulus, assuming x < InputModFactor * modulus +/// @param[in] x +/// @param[in] modulus also denoted q +/// @param[in] twice_modulus 2 * q; must not be nullptr if InputModFactor == 4 +/// or 8 +/// @param[in] four_times_modulus 4 * q; must not be nullptr if InputModFactor +/// == 8 +template +uint64_t ReduceMod(uint64_t x, uint64_t modulus, + const uint64_t* twice_modulus = nullptr, + const uint64_t* four_times_modulus = nullptr) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || + InputModFactor == 4 || InputModFactor == 8, + "InputModFactor should be 1, 2, 4, or 8"); + if (InputModFactor == 1) { + return x; + } + if (InputModFactor == 2) { + if (x >= modulus) { + x -= modulus; + } + return x; + } + if (InputModFactor == 4) { + HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr"); + if (x >= *twice_modulus) { + x -= *twice_modulus; + } + if (x >= modulus) { + x -= modulus; + } + return x; + } + if (InputModFactor == 8) { + HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr"); + HEXL_CHECK(four_times_modulus != nullptr, + "four_times_modulus should not be nullptr"); + + if (x >= *four_times_modulus) { + x -= *four_times_modulus; + } + if (x >= *twice_modulus) { + x -= *twice_modulus; + } + if (x >= modulus) { + x -= modulus; + } + return x; + } + HEXL_CHECK(false, "Should be unreachable"); + return x; +} + +/// @brief Returns Montgomery form of ab mod q, computed via the REDC algorithm, +/// also known as Montgomery reduction. +/// @param[in] r +/// @param[in] q with R = 2^r such that gcd(R, q) = 1. R > q. +/// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R. +/// @param[in] mod_R_msk take r last bits to apply mod R. +/// @param[in] T_hi of T = ab in the range [0, Rq − 1]. +/// @param[in] T_lo of T. +/// @return Unsigned long int in the range [0, q − 1] such that S ≡ TR^−1 mod q +template +inline uint64_t MontgomeryReduce(uint64_t T_hi, uint64_t T_lo, uint64_t q, + int r, uint64_t mod_R_msk, uint64_t inv_mod) { + HEXL_CHECK(BitShift == 64 || BitShift == 52, + "Unsupported BitShift " << BitShift); + HEXL_CHECK((1ULL << r) > static_cast(q), + "R value should be greater than q = " << static_cast(q)); + + uint64_t mq_hi; + uint64_t mq_lo; + + uint64_t m = ((T_lo & mod_R_msk) * inv_mod) & mod_R_msk; + MultiplyUInt64(m, q, &mq_hi, &mq_lo); + + if (BitShift == 52) { + mq_hi = (mq_hi << 12) | (mq_lo >> 52); + mq_lo &= (1ULL << 52) - 1; + } + + uint64_t t_hi; + uint64_t t_lo; + + // first 64bit block + t_lo = T_lo + mq_lo; + unsigned int carry = static_cast(t_lo < T_lo); + t_hi = T_hi + mq_hi + carry; + + t_hi = t_hi << (BitShift - r); + t_lo = t_lo >> r; + t_lo = t_hi + t_lo; + + return (t_lo >= q) ? 
(t_lo - q) : t_lo; +} + +/// @brief Hensel's Lemma for 2-adic numbers +/// Find solution for qX + 1 = 0 mod 2^r +/// @param[in] r +/// @param[in] q such that gcd(2, q) = 1 +/// @return Unsigned long int in [0, 2^r − 1] such that q*x ≡ −1 mod 2^r +inline uint64_t HenselLemma2adicRoot(uint32_t r, uint64_t q) { + uint64_t a_prev = 1; + uint64_t c = 2; + uint64_t mod_mask = 3; + + // Root: + // f(x) = qX + 1 and a_(0) = 1 then f(1) ≡ 0 mod 2 + // General Case: + // - a_(n) ≡ a_(n-1) mod 2^(n) + // => a_(n) = a_(n-1) + 2^(n)*t + // - Find 't' such that f(a_(n)) = 0 mod 2^(n+1) + // First case in for: + // - a_(1) ≡ 1 mod 2 or a_(1) = 1 + 2t + // - Find 't' so f(a_(1)) ≡ 0 mod 4 => q(1 + 2t) + 1 ≡ 0 mod 4 + for (uint64_t k = 2; k <= r; k++) { + uint64_t f = 0; + uint64_t t = 0; + uint64_t a = 0; + + do { + a = a_prev + c * t++; + f = q * a + 1ULL; + } while (f & mod_mask); // f(a) ≡ 0 mod 2^(k) + + // Update vars + mod_mask = mod_mask * 2 + 1ULL; + c *= 2; + a_prev = a; + } + + return a_prev; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/util/aligned-allocator.hpp b/hexl_ser/include/hexl/util/aligned-allocator.hpp new file mode 100644 index 00000000..d175c734 --- /dev/null +++ b/hexl_ser/include/hexl/util/aligned-allocator.hpp @@ -0,0 +1,110 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/allocator.hpp" +#include "hexl/util/defines.hpp" + +namespace intel { +namespace hexl { + +/// @brief Allocater implementation using malloc and free +struct MallocStrategy : AllocatorBase { + void* allocate(size_t bytes_count) final { return std::malloc(bytes_count); } + + void deallocate(void* p, size_t n) final { + HEXL_UNUSED(n); + std::free(p); + } +}; + +using AllocatorStrategyPtr = std::shared_ptr; +extern AllocatorStrategyPtr mallocStrategy; + +/// @brief Allocates memory aligned to Alignment-byte sized boundaries +/// @details Alignment must be a power of two +template +class AlignedAllocator { + public: + template + friend class AlignedAllocator; + + using value_type = T; + + explicit AlignedAllocator(AllocatorStrategyPtr strategy = nullptr) noexcept + : m_alloc_impl((strategy != nullptr) ? 
strategy : mallocStrategy) {} + + AlignedAllocator(const AlignedAllocator& src) = default; + AlignedAllocator& operator=(const AlignedAllocator& src) = default; + + template + AlignedAllocator(const AlignedAllocator& src) + : m_alloc_impl(src.m_alloc_impl) {} + + ~AlignedAllocator() {} + + template + struct rebind { + using other = AlignedAllocator; + }; + + bool operator==(const AlignedAllocator&) { return true; } + + bool operator!=(const AlignedAllocator&) { return false; } + + /// @brief Allocates \p n elements aligned to Alignment-byte boundaries + /// @return Pointer to the aligned allocated memory + T* allocate(size_t n) { + if (!IsPowerOfTwo(Alignment)) { + return nullptr; + } + // Allocate enough space to ensure the alignment can be satisfied + size_t buffer_size = sizeof(T) * n + Alignment; + // Additionally, allocate a prefix to store the memory location of the + // unaligned buffer + size_t alloc_size = buffer_size + sizeof(void*); + void* buffer = m_alloc_impl->allocate(alloc_size); + if (!buffer) { + return nullptr; + } + + // Reserve first location for pointer to originally-allocated space + void* aligned_buffer = static_cast(buffer) + sizeof(void*); + std::align(Alignment, sizeof(T) * n, aligned_buffer, buffer_size); + if (!aligned_buffer) { + return nullptr; + } + + // Store allocated buffer address at aligned_buffer - sizeof(void*). + void* store_buffer_addr = + static_cast(aligned_buffer) - sizeof(void*); + *(static_cast(store_buffer_addr)) = buffer; + + return static_cast(aligned_buffer); + } + + void deallocate(T* p, size_t n) { + if (!p) { + return; + } + void* store_buffer_addr = (reinterpret_cast(p) - sizeof(void*)); + void* free_address = *(static_cast(store_buffer_addr)); + m_alloc_impl->deallocate(free_address, n); + } + + private: + AllocatorStrategyPtr m_alloc_impl; +}; + +/// @brief 64-byte aligned memory allocator +template +using AlignedVector64 = std::vector >; + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/util/allocator.hpp b/hexl_ser/include/hexl/util/allocator.hpp new file mode 100644 index 00000000..5f4a7a31 --- /dev/null +++ b/hexl_ser/include/hexl/util/allocator.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Base class for custom memory allocator +struct AllocatorBase { + virtual ~AllocatorBase() noexcept {} + + /// @brief Allocates byte_count bytes of memory + /// @param[in] bytes_count Number of bytes to allocate + /// @return A pointer to the allocated memory + virtual void* allocate(size_t bytes_count) = 0; + + /// @brief Deallocate memory + /// @param[in] p Pointer to memory to deallocate + /// @param[in] n Number of bytes to deallocate + virtual void deallocate(void* p, size_t n) = 0; +}; + +/// @brief Helper memory allocation struct which delegates implementation to +/// AllocatorImpl +template +struct AllocatorInterface : public AllocatorBase { + /// @brief Override interface and delegate implementation to AllocatorImpl + void* allocate(size_t bytes_count) override { + return static_cast(this)->allocate_impl(bytes_count); + } + + /// @brief Override interface and delegate implementation to AllocatorImpl + void deallocate(void* p, size_t n) override { + static_cast(this)->deallocate_impl(p, n); + } + + private: + // in case AllocatorImpl doesn't provide implementations, use default null + // behavior + void* allocate_impl(size_t bytes_count) { + HEXL_UNUSED(bytes_count); + 
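
A side note on the allocator files being copied here: `AlignedVector64` is simply `std::vector` backed by the 64-byte `AlignedAllocator` above, and a custom strategy can be injected through `AllocatorBase`. A small illustrative sketch follows (the function name `example` is made up; this is not part of the patch):

```cpp
// Illustrative only. The vector's storage comes from AlignedAllocator<T, 64>,
// so data() sits on a 64-byte (cache-line) boundary, which suits the AVX512
// kernels; the pointer to the original allocation is stored just before it.
#include <cstdint>

#include "hexl/util/aligned-allocator.hpp"

void example() {
  intel::hexl::AlignedVector64<uint64_t> buf(4096, 0);
  uint64_t* aligned_ptr = buf.data();  // 64-byte aligned
  (void)aligned_ptr;
}
```
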
return nullptr; + } + void deallocate_impl(void* p, size_t n) { + HEXL_UNUSED(p); + HEXL_UNUSED(n); + } +}; +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/util/check.hpp b/hexl_ser/include/hexl/util/check.hpp new file mode 100644 index 00000000..386eba89 --- /dev/null +++ b/hexl_ser/include/hexl/util/check.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/types.hpp" + +// Create logging/debug macros with no run-time overhead unless HEXL_DEBUG is +// enabled +#ifdef HEXL_DEBUG +#include "hexl/logging/logging.hpp" + +/// @brief If input condition is not true, logs the expression and throws an +/// error +/// @param[in] cond A boolean indication the condition +/// @param[in] expr The expression to be logged +#define HEXL_CHECK(cond, expr) \ + if (!(cond)) { \ + LOG(ERROR) << expr << " in function: " << __FUNCTION__ \ + << " in file: " __FILE__ << ":" << __LINE__; \ + throw std::runtime_error("Error. Check log output"); \ + } + +/// @brief If input has an element >= bound, logs the expression and throws an +/// error +/// @param[in] arg Input container which supports the [] operator. +/// @param[in] n Size of input +/// @param[in] bound Upper bound on the input +/// @param[in] expr The expression to be logged +#define HEXL_CHECK_BOUNDS(arg, n, bound, expr) \ + for (size_t hexl_check_idx = 0; hexl_check_idx < n; ++hexl_check_idx) { \ + HEXL_CHECK((arg)[hexl_check_idx] < bound, expr); \ + } + +#else // HEXL_DEBUG=OFF + +#define HEXL_CHECK(cond, expr) \ + {} +#define HEXL_CHECK_BOUNDS(...) \ + {} + +#endif // HEXL_DEBUG diff --git a/hexl_ser/include/hexl/util/clang.hpp b/hexl_ser/include/hexl/util/clang.hpp new file mode 100644 index 00000000..958bea7b --- /dev/null +++ b/hexl_ser/include/hexl/util/clang.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/types.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_USE_CLANG +// Return x * y as 128-bit integer +// Correctness if x * y < 128 bits +inline uint128_t MultiplyUInt64(uint64_t x, uint64_t y) { + return uint128_t(x) * uint128_t(y); +} + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint128_t n = (static_cast(input_hi) << 64) | + (static_cast(input_lo)); + + return n % modulus; + // TODO(fboemer): actually use barrett reduction if performance-critical +} + +// Returns low 64bit of 128b/64b where x1=high 64b, x0=low 64b +inline uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y) { + uint128_t n = + (static_cast(x1) << 64) | (static_cast(x0)); + uint128_t q = n / y; + + return static_cast(q); +} + +// Multiplies x * y as 128-bit integer. 
+// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + uint128_t prod = MultiplyUInt64(x, y); + *prod_hi = static_cast(prod >> 64); + *prod_lo = static_cast(prod); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + uint128_t product = static_cast(x) * y; + return static_cast(product >> BitShift); +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + return static_cast(std::log2l(input)); +} + +#define HEXL_LOOP_UNROLL_4 _Pragma("clang loop unroll_count(4)") +#define HEXL_LOOP_UNROLL_8 _Pragma("clang loop unroll_count(8)") + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/util/compiler.hpp b/hexl_ser/include/hexl/util/compiler.hpp new file mode 100644 index 00000000..7dd077df --- /dev/null +++ b/hexl_ser/include/hexl/util/compiler.hpp @@ -0,0 +1,14 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/util/defines.hpp" + +#ifdef HEXL_USE_MSVC +#include "hexl/util/msvc.hpp" +#elif defined HEXL_USE_GNU +#include "hexl/util/gcc.hpp" +#elif defined HEXL_USE_CLANG +#include "hexl/util/clang.hpp" +#endif diff --git a/hexl_ser/include/hexl/util/defines.hpp b/hexl_ser/include/hexl/util/defines.hpp new file mode 100644 index 00000000..b92dd24e --- /dev/null +++ b/hexl_ser/include/hexl/util/defines.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +/* #undef HEXL_USE_MSVC */ +#define HEXL_USE_GNU +/* #undef HEXL_USE_CLANG */ + +#define HEXL_DEBUG + +// Avoid unused variable warnings +#define HEXL_UNUSED(x) (void)(x) diff --git a/hexl_ser/include/hexl/util/defines.hpp.in b/hexl_ser/include/hexl/util/defines.hpp.in new file mode 100644 index 00000000..0f146c26 --- /dev/null +++ b/hexl_ser/include/hexl/util/defines.hpp.in @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#cmakedefine HEXL_USE_MSVC +#cmakedefine HEXL_USE_GNU +#cmakedefine HEXL_USE_CLANG + +#cmakedefine HEXL_DEBUG + +// Avoid unused variable warnings +#define HEXL_UNUSED(x) (void)(x) diff --git a/hexl_ser/include/hexl/util/gcc.hpp b/hexl_ser/include/hexl/util/gcc.hpp new file mode 100644 index 00000000..828e3836 --- /dev/null +++ b/hexl_ser/include/hexl/util/gcc.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/types.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_USE_GNU +// Return x * y as 128-bit integer +// Correctness if x * y < 128 bits +inline uint128_t MultiplyUInt64(uint64_t x, uint64_t y) { + return uint128_t(x) * uint128_t(y); +} + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint128_t n = (static_cast(input_hi) << 64) | + (static_cast(input_lo)); + + return static_cast(n % modulus); + // TODO(fboemer): actually use barrett reduction if performance-critical +} + +// Returns low 64bit of 128b/64b where x1=high 64b, x0=low 64b +inline uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y) { + uint128_t n = + (static_cast(x1) << 64) | (static_cast(x0)); + uint128_t q = n / y; 
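
These two helpers are the core of the Barrett machinery used throughout the library: `DivideUInt128UInt64Lo(1, 0, q)` yields floor(2^64 / q), exactly what `MultiplyFactor(1, 64, q)` precomputes, and `MultiplyUInt64Hi<64>` consumes it. A hedged scalar sketch of how they combine, mirroring `BarrettReduce64` from number-theory.hpp (`barrett_mod` is a made-up name; not part of the patch):

```cpp
// Illustrative only: plain 64-bit Barrett reduction built from the helpers above.
#include <cstdint>

#include "hexl/number-theory/number-theory.hpp"

uint64_t barrett_mod(uint64_t x, uint64_t q) {
  // Precomputed factor floor(2^64 / q); in practice this is cached per modulus.
  uint64_t q_barr = intel::hexl::DivideUInt128UInt64Lo(1, 0, q);
  // Approximate quotient floor(x * q_barr / 2^64) <= floor(x / q).
  uint64_t approx_q = intel::hexl::MultiplyUInt64Hi<64>(x, q_barr);
  uint64_t r = x - approx_q * q;  // lies in [0, 2q)
  return (r >= q) ? r - q : r;    // one conditional subtraction finishes the job
}
```
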
+ + return static_cast(q); +} + +// Multiplies x * y as 128-bit integer. +// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + uint128_t prod = MultiplyUInt64(x, y); + *prod_hi = static_cast(prod >> 64); + *prod_lo = static_cast(prod); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + uint128_t product = MultiplyUInt64(x, y); + return static_cast(product >> BitShift); +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + return static_cast(std::log2l(input)); +} + +#define HEXL_LOOP_UNROLL_4 _Pragma("GCC unroll 4") +#define HEXL_LOOP_UNROLL_8 _Pragma("GCC unroll 8") + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/util/msvc.hpp b/hexl_ser/include/hexl/util/msvc.hpp new file mode 100644 index 00000000..0ada2d45 --- /dev/null +++ b/hexl_ser/include/hexl/util/msvc.hpp @@ -0,0 +1,289 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#ifdef HEXL_USE_MSVC + +#define NOMINMAX // Avoid errors with std::min/std::max +#undef min +#undef max + +#include +#include +#include + +#include + +#include "hexl/util/check.hpp" + +#pragma intrinsic(_addcarry_u64, _BitScanReverse64, _subborrow_u64, _udiv128, \ + _umul128) + +#undef TRUE +#undef FALSE + +namespace intel { +namespace hexl { + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint64_t remainder; + _udiv128(input_hi, input_lo, modulus, &remainder); + + return remainder; +} + +// Multiplies x * y as 128-bit integer. 
+// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + *prod_lo = _umul128(x, y, prod_hi); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid BitShift " << BitShift << "; expected 52 or 64"); + uint64_t prod_hi; + uint64_t prod_lo = _umul128(x, y, &prod_hi); + uint64_t result_hi; + uint64_t result_lo; + RightShift128(&result_hi, &result_lo, prod_hi, prod_lo, BitShift); + return result_lo; +} + +/// @brief Computes Left Shift op as 128-bit unsigned integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +/// @param[in] op_hi Stores high 64 bits of input +/// @param[in] op_lo Stores low 64 bits of input +inline void LeftShift128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op_hi, const uint64_t op_lo, + const uint64_t shift_value) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + HEXL_CHECK(shift_value <= 128, + "shift_value cannot be greater than 128 " << shift_value); + + if (shift_value == 0) { + *result_hi = op_hi; + *result_lo = op_lo; + } else if (shift_value == 64) { + *result_hi = op_lo; + *result_lo = 0ULL; + } else if (shift_value == 128) { + *result_hi = 0ULL; + *result_lo = 0ULL; + } else if (shift_value >= 1 && shift_value <= 63) { + *result_hi = (op_hi << shift_value) | (op_lo >> (64 - shift_value)); + *result_lo = op_lo << shift_value; + } else if (shift_value >= 65 && shift_value < 128) { + *result_hi = op_lo << (shift_value - 64); + *result_lo = 0ULL; + } +} + +/// @brief Computes Right Shift op as 128-bit unsigned integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +/// @param[in] op_hi Stores high 64 bits of input +/// @param[in] op_lo Stores low 64 bits of input +inline void RightShift128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op_hi, const uint64_t op_lo, + const uint64_t shift_value) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + HEXL_CHECK(shift_value <= 128, + "shift_value cannot be greater than 128 " << shift_value); + + if (shift_value == 0) { + *result_hi = op_hi; + *result_lo = op_lo; + } else if (shift_value == 64) { + *result_hi = 0ULL; + *result_lo = op_hi; + } else if (shift_value == 128) { + *result_hi = 0ULL; + *result_lo = 0ULL; + } else if (shift_value >= 1 && shift_value <= 63) { + *result_hi = op_hi >> shift_value; + *result_lo = (op_hi << (64 - shift_value)) | (op_lo >> shift_value); + } else if (shift_value >= 65 && shift_value < 128) { + *result_hi = 0ULL; + *result_lo = op_hi >> (shift_value - 64); + } +} + +/// @brief Adds op1 + op2 as 128-bit integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +inline void AddWithCarry128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op1_hi, const uint64_t op1_lo, + const uint64_t op2_hi, const uint64_t op2_lo) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + + // first 64bit block + *result_lo = op1_lo + op2_lo; 
+ unsigned char carry = static_cast(*result_lo < op1_lo); + + // second 64bit block + _addcarry_u64(carry, op1_hi, op2_hi, result_hi); +} + +/// @brief Subtracts op1 - op2 as 128-bit integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +inline void SubWithCarry128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op1_hi, const uint64_t op1_lo, + const uint64_t op2_hi, const uint64_t op2_lo) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + + unsigned char borrow; + + // first 64bit block + *result_lo = op1_lo - op2_lo; + borrow = static_cast(op2_lo > op1_lo); + + // second 64bit block + _subborrow_u64(borrow, op1_hi, op2_hi, result_hi); +} + +/// @brief Computes and returns significant bit count +/// @param[in] value Input element at most 128 bits long +inline uint64_t SignificantBitLength(const uint64_t* value) { + HEXL_CHECK(value != nullptr, "Require value != nullptr"); + + unsigned long count = 0; // NOLINT(runtime/int) + + // second 64bit block + _BitScanReverse64(&count, *(value + 1)); + if (count >= 0 && *(value + 1) > 0) { + return static_cast(count) + 1 + 64; + } + + // first 64bit block + _BitScanReverse64(&count, *value); + if (count >= 0 && *(value) > 0) { + return static_cast(count) + 1; + } + return 0; +} + +/// @brief Checks if input is negative number +/// @param[in] input Input element to check for sign +inline bool CheckSign(const uint64_t* input) { + HEXL_CHECK(input != nullptr, "Require input != nullptr"); + + uint64_t input_temp[2]{0, 0}; + RightShift128(&input_temp[1], &input_temp[0], input[1], input[0], 127); + return (input_temp[0] == 1); +} + +/// @brief Divides numerator by denominator +/// @param[out] quotient Stores quotient as two 64-bit blocks after division +/// @param[in] numerator +/// @param[in] denominator +inline void DivideUInt128UInt64(uint64_t* quotient, const uint64_t* numerator, + const uint64_t denominator) { + HEXL_CHECK(quotient != nullptr, "Require quotient != nullptr"); + HEXL_CHECK(numerator != nullptr, "Require numerator != nullptr"); + HEXL_CHECK(denominator != 0, "denominator cannot be 0 " << denominator); + + // get bit count of divisor + uint64_t numerator_bits = SignificantBitLength(numerator); + const uint64_t numerator_bits_const = numerator_bits; + const uint64_t uint_128_bit = 128ULL; + + uint64_t MASK[2]{0x0000000000000001, 0x0000000000000000}; + uint64_t remainder[2]{0, 0}; + uint64_t quotient_temp[2]{0, 0}; + uint64_t denominator_temp[2]{denominator, 0}; + + quotient[0] = numerator[0]; + quotient[1] = numerator[1]; + + // align numerator + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], + (uint_128_bit - numerator_bits_const)); + + while (numerator_bits) { + // if remainder is negative + if (CheckSign(remainder)) { + LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1); + RightShift128("ient_temp[1], "ient_temp[0], quotient[1], + quotient[0], (uint_128_bit - 1)); + remainder[0] = remainder[0] | quotient_temp[0]; + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1); + // remainder=remainder+denominator_temp + AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } else { // if remainder is positive + LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1); + RightShift128("ient_temp[1], "ient_temp[0], quotient[1], + quotient[0], (uint_128_bit 
- 1)); + remainder[0] = remainder[0] | quotient_temp[0]; + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1); + // remainder=remainder-denominator_temp + SubWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } + + // if remainder is positive set MSB of quotient[0]=1 + if (!CheckSign(remainder)) { + MASK[0] = 0x0000000000000001; + MASK[1] = 0x0000000000000000; + LeftShift128(&MASK[1], &MASK[0], MASK[1], MASK[0], + (uint_128_bit - numerator_bits_const)); + quotient[0] = quotient[0] | MASK[0]; + quotient[1] = quotient[1] | MASK[1]; + } + quotient_temp[0] = 0; + quotient_temp[1] = 0; + numerator_bits--; + } + + if (CheckSign(remainder)) { + // remainder=remainder+denominator_temp + AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } + RightShift128("ient[1], "ient[0], quotient[1], quotient[0], + (uint_128_bit - numerator_bits_const)); +} + +/// @brief Returns low of dividing numerator by denominator +/// @param[in] numerator_hi Stores high 64 bit of numerator +/// @param[in] numerator_lo Stores low 64 bit of numerator +/// @param[in] denominator Stores denominator +inline uint64_t DivideUInt128UInt64Lo(const uint64_t numerator_hi, + const uint64_t numerator_lo, + const uint64_t denominator) { + uint64_t numerator[2]{numerator_lo, numerator_hi}; + uint64_t quotient[2]{0, 0}; + + DivideUInt128UInt64(quotient, numerator, denominator); + return quotient[0]; +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + unsigned long index{0}; // NOLINT(runtime/int) + _BitScanReverse64(&index, input); + return index; +} + +#define HEXL_LOOP_UNROLL_4 \ + {} +#define HEXL_LOOP_UNROLL_8 \ + {} + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/include/hexl/util/types.hpp b/hexl_ser/include/hexl/util/types.hpp new file mode 100644 index 00000000..2d2d8551 --- /dev/null +++ b/hexl_ser/include/hexl/util/types.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/defines.hpp" + +#if defined(HEXL_USE_GNU) || defined(HEXL_USE_CLANG) +__extension__ typedef __int128 int128_t; +__extension__ typedef unsigned __int128 uint128_t; +#endif diff --git a/hexl_ser/include/hexl/util/util.hpp b/hexl_ser/include/hexl/util/util.hpp new file mode 100644 index 00000000..bf878a98 --- /dev/null +++ b/hexl_ser/include/hexl/util/util.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +#undef TRUE // MSVC defines TRUE +#undef FALSE // MSVC defines FALSE + +/// @enum CMPINT +/// @brief Represents binary operations between two boolean values +enum class CMPINT { + EQ = 0, ///< Equal + LT = 1, ///< Less than + LE = 2, ///< Less than or equal + FALSE = 3, ///< False + NE = 4, ///< Not equal + NLT = 5, ///< Not less than + NLE = 6, ///< Not less than or equal + TRUE = 7 ///< True +}; + +/// @brief Returns the logical negation of a binary operation +/// @param[in] cmp The binary operation to negate +inline CMPINT Not(CMPINT cmp) { + switch (cmp) { + case CMPINT::EQ: + return CMPINT::NE; + case CMPINT::LT: + return CMPINT::NLT; + case CMPINT::LE: + return CMPINT::NLE; + case CMPINT::FALSE: + return CMPINT::TRUE; + case CMPINT::NE: + return CMPINT::EQ; + case CMPINT::NLT: + return CMPINT::LT; + case CMPINT::NLE: + 
return CMPINT::LE; + case CMPINT::TRUE: + return CMPINT::FALSE; + default: + return CMPINT::FALSE; + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/logging/logging.cpp b/hexl_ser/logging/logging.cpp new file mode 100644 index 00000000..c491b43e --- /dev/null +++ b/hexl_ser/logging/logging.cpp @@ -0,0 +1,8 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/logging/logging.hpp" + +#ifdef HEXL_DEBUG +INITIALIZE_EASYLOGGINGPP +#endif diff --git a/hexl_ser/ntt/fwd-ntt-avx512.cpp b/hexl_ser/ntt/fwd-ntt-avx512.cpp new file mode 100644 index 00000000..8ed0ac7e --- /dev/null +++ b/hexl_ser/ntt/fwd-ntt-avx512.cpp @@ -0,0 +1,409 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "ntt/fwd-ntt-avx512.hpp" + +#include +#include +#include + +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "ntt/ntt-avx512-util.hpp" +#include "ntt/ntt-internal.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512IFMA +template void ForwardTransformToBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half); +#endif + +#ifdef HEXL_HAS_AVX512DQ +template void ForwardTransformToBitReverseAVX512<32>( + uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half); + +template void ForwardTransformToBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half); +#endif + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief The Harvey butterfly: assume \p X, \p Y in [0, 4q), and return X', Y' +/// in [0, 4q) such that X', Y' = X + WY, X - WY (mod q). +/// @param[in,out] X Input representing 8 64-bit signed integers in SIMD form +/// @param[in,out] Y Input representing 8 64-bit signed integers in SIMD form +/// @param[in] W Root of unity represented as 8 64-bit signed integers in +/// SIMD form +/// @param[in] W_precon Preconditioned \p W for BitShift-bit Barrett +/// reduction +/// @param[in] neg_modulus Negative modulus, i.e. (-q) represented as 8 64-bit +/// signed integers in SIMD form +/// @param[in] twice_modulus Twice the modulus, i.e. 2*q represented as 8 64-bit +/// signed integers in SIMD form +/// @param InputLessThanMod If true, assumes \p X, \p Y < \p q. 
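
Before the vectorized stages below, the same butterfly in scalar form may be easier to follow. This is an illustrative sketch built on `MultiplyModLazy` from number-theory.hpp; `fwd_butterfly_scalar` is a made-up name and the code is not part of the patch:

```cpp
// Illustrative scalar analogue of the AVX512 Harvey forward butterfly above.
// Inputs X, Y in [0, 4q); outputs X' = X + W*Y and Y' = X - W*Y (mod q),
// both kept in the lazy range [0, 4q).
#include <cstdint>

#include "hexl/number-theory/number-theory.hpp"

inline void fwd_butterfly_scalar(uint64_t* X, uint64_t* Y, uint64_t W,
                                 uint64_t W_precon, uint64_t q,
                                 uint64_t twice_q) {
  if (*X >= twice_q) *X -= twice_q;  // lazy reduction of X into [0, 2q)
  // T = W * Y (mod q) via Barrett with the precomputed factor, result in [0, 2q)
  uint64_t T = intel::hexl::MultiplyModLazy<64>(*Y, W, W_precon, q);
  *Y = *X + twice_q - T;  // X - W*Y, shifted into [0, 4q)
  *X = *X + T;            // X + W*Y, in [0, 4q)
  (void)q;
}
```
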
Otherwise, +/// assumes \p X, \p Y < 4*\p q +/// @details See Algorithm 4 of https://arxiv.org/pdf/1205.2926.pdf +template +void FwdButterfly(__m512i* X, __m512i* Y, __m512i W, __m512i W_precon, + __m512i neg_modulus, __m512i twice_modulus) { + if (!InputLessThanMod) { + *X = _mm512_hexl_small_mod_epu64(*X, twice_modulus); + } + + __m512i T; + if (BitShift == 32) { + __m512i Q = _mm512_hexl_mullo_epi<64>(W_precon, *Y); + Q = _mm512_srli_epi64(Q, 32); + __m512i W_Y = _mm512_hexl_mullo_epi<64>(W, *Y); + T = _mm512_hexl_mullo_add_lo_epi<64>(W_Y, Q, neg_modulus); + } else if (BitShift == 52) { + __m512i Q = _mm512_hexl_mulhi_epi(W_precon, *Y); + __m512i W_Y = _mm512_hexl_mullo_epi(W, *Y); + T = _mm512_hexl_mullo_add_lo_epi(W_Y, Q, neg_modulus); + } else if (BitShift == 64) { + // Perform approximate computation of Q, as described in page 7 of + // https://arxiv.org/pdf/2003.04510.pdf + __m512i Q = _mm512_hexl_mulhi_approx_epi(W_precon, *Y); + __m512i W_Y = _mm512_hexl_mullo_epi(W, *Y); + // Compute T in range [0, 4q) + T = _mm512_hexl_mullo_add_lo_epi(W_Y, Q, neg_modulus); + // Reduce T to range [0, 2q) + T = _mm512_hexl_small_mod_epu64<2>(T, twice_modulus); + } else { + HEXL_CHECK(false, "Invalid BitShift " << BitShift); + } + + __m512i twice_mod_minus_T = _mm512_sub_epi64(twice_modulus, T); + *Y = _mm512_add_epi64(*X, twice_mod_minus_T); + *X = _mm512_add_epi64(*X, T); +} + +template +void FwdT1(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, + uint64_t m, const uint64_t* W, const uint64_t* W_precon) { + const __m512i* v_W_pt = reinterpret_cast(W); + const __m512i* v_W_precon_pt = reinterpret_cast(W_precon); + size_t j1 = 0; + + // 8 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_8 + for (size_t i = m / 8; i > 0; --i) { + uint64_t* X = operand + j1; + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + + __m512i v_X; + __m512i v_Y; + LoadFwdInterleavedT1(X, &v_X, &v_Y); + __m512i v_W = _mm512_loadu_si512(v_W_pt++); + __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++); + + FwdButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, + v_twice_mod); + WriteFwdInterleavedT1(v_X, v_Y, v_X_pt); + + j1 += 16; + } +} + +template +void FwdT2(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, + uint64_t m, const uint64_t* W, const uint64_t* W_precon) { + const __m512i* v_W_pt = reinterpret_cast(W); + const __m512i* v_W_precon_pt = reinterpret_cast(W_precon); + + size_t j1 = 0; + // 4 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = m / 4; i > 0; --i) { + uint64_t* X = operand + j1; + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + + __m512i v_X; + __m512i v_Y; + LoadFwdInterleavedT2(X, &v_X, &v_Y); + + __m512i v_W = _mm512_loadu_si512(v_W_pt++); + __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++); + + HEXL_CHECK(ExtractValues(v_W)[0] == ExtractValues(v_W)[1], + "bad v_W " << ExtractValues(v_W)); + HEXL_CHECK(ExtractValues(v_W_precon)[0] == ExtractValues(v_W_precon)[1], + "bad v_W_precon " << ExtractValues(v_W_precon)); + FwdButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, + v_twice_mod); + + _mm512_storeu_si512(v_X_pt++, v_X); + _mm512_storeu_si512(v_X_pt, v_Y); + + j1 += 16; + } +} + +template +void FwdT4(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, + uint64_t m, const uint64_t* W, const uint64_t* W_precon) { + size_t j1 = 0; + const __m512i* v_W_pt = reinterpret_cast(W); + const __m512i* v_W_precon_pt = reinterpret_cast(W_precon); + + // 2 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = m / 2; i > 0; --i) { 
+ uint64_t* X = operand + j1; + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + + __m512i v_X; + __m512i v_Y; + LoadFwdInterleavedT4(X, &v_X, &v_Y); + + __m512i v_W = _mm512_loadu_si512(v_W_pt++); + __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++); + FwdButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, + v_twice_mod); + + _mm512_storeu_si512(v_X_pt++, v_X); + _mm512_storeu_si512(v_X_pt, v_Y); + + j1 += 16; + } +} + +// Out-of-place implementation +template +void FwdT8(uint64_t* result, const uint64_t* operand, __m512i v_neg_modulus, + __m512i v_twice_mod, uint64_t t, uint64_t m, const uint64_t* W, + const uint64_t* W_precon) { + size_t j1 = 0; + + HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < m; i++) { + // Referencing operand + const uint64_t* X_op = operand + j1; + const uint64_t* Y_op = X_op + t; + + const __m512i* v_X_op_pt = reinterpret_cast(X_op); + const __m512i* v_Y_op_pt = reinterpret_cast(Y_op); + + // Referencing result + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + + __m512i* v_X_r_pt = reinterpret_cast<__m512i*>(X_r); + __m512i* v_Y_r_pt = reinterpret_cast<__m512i*>(Y_r); + + // Weights and weights' preconditions + __m512i v_W = _mm512_set1_epi64(static_cast(*W++)); + __m512i v_W_precon = _mm512_set1_epi64(static_cast(*W_precon++)); + + // assume 8 | t + for (size_t j = t / 8; j > 0; --j) { + __m512i v_X = _mm512_loadu_si512(v_X_op_pt); + __m512i v_Y = _mm512_loadu_si512(v_Y_op_pt); + + FwdButterfly(&v_X, &v_Y, v_W, v_W_precon, + v_neg_modulus, v_twice_mod); + + _mm512_storeu_si512(v_X_r_pt++, v_X); + _mm512_storeu_si512(v_Y_r_pt++, v_Y); + + // Increase operand pointers as well + v_X_op_pt++; + v_Y_op_pt++; + } + j1 += (t << 1); + } +} + +template +void ForwardTransformToBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(modulus < NTT::s_max_fwd_modulus(BitShift), + "modulus " << modulus << " too large for BitShift " << BitShift + << " => maximum value " + << NTT::s_max_fwd_modulus(BitShift)); + HEXL_CHECK_BOUNDS(precon_root_of_unity_powers, n, MaximumValue(BitShift), + "precon_root_of_unity_powers too large"); + HEXL_CHECK_BOUNDS(operand, n, MaximumValue(BitShift), "operand too large"); + // Skip input bound checking for recursive steps + HEXL_CHECK_BOUNDS(operand, (recursion_depth == 0) ? n : 0, + input_mod_factor * modulus, + "operand larger than input_mod_factor * modulus (" + << input_mod_factor << " * " << modulus << ")"); + HEXL_CHECK(n >= 16, + "Don't support small transforms. 
Need n >= 16, got n = " << n); + HEXL_CHECK( + input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4, + "input_mod_factor must be 1, 2, or 4; got " << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4, + "output_mod_factor must be 1 or 4; got " << output_mod_factor); + + uint64_t twice_mod = modulus << 1; + + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_neg_modulus = _mm512_set1_epi64(-static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(twice_mod)); + + HEXL_VLOG(5, "root_of_unity_powers " << std::vector( + root_of_unity_powers, root_of_unity_powers + n)) + HEXL_VLOG(5, + "precon_root_of_unity_powers " << std::vector( + precon_root_of_unity_powers, precon_root_of_unity_powers + n)); + HEXL_VLOG(5, "operand " << std::vector(operand, operand + n)); + + static const size_t base_ntt_size = 1024; + + if (n <= base_ntt_size) { // Perform breadth-first NTT + size_t t = (n >> 1); + size_t m = 1; + size_t W_idx = (m << recursion_depth) + (recursion_half * m); + + // Copy for out-of-place in case m is <= base_ntt_size from start + if (result != operand) { + std::memcpy(result, operand, n * sizeof(uint64_t)); + } + + // First iteration assumes input in [0,p) + if (m < (n >> 3)) { + const uint64_t* W = &root_of_unity_powers[W_idx]; + const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; + + if ((input_mod_factor <= 2) && (recursion_depth == 0)) { + FwdT8(result, result, v_neg_modulus, v_twice_mod, t, m, + W, W_precon); + } else { + FwdT8(result, result, v_neg_modulus, v_twice_mod, t, m, + W, W_precon); + } + + t >>= 1; + m <<= 1; + W_idx <<= 1; + } + for (; m < (n >> 3); m <<= 1) { + const uint64_t* W = &root_of_unity_powers[W_idx]; + const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; + FwdT8(result, result, v_neg_modulus, v_twice_mod, t, m, + W, W_precon); + t >>= 1; + W_idx <<= 1; + } + + // Do T=4, T=2, T=1 separately + { + // Correction step needed due to extra copies of roots of unity in the + // AVX512 vectors loaded for FwdT2 and FwdT4 + auto compute_new_W_idx = [&](size_t idx) { + // Originally, from root of unity vector index to loop: + // [0, N/8) => FwdT8 + // [N/8, N/4) => FwdT4 + // [N/4, N/2) => FwdT2 + // [N/2, N) => FwdT1 + // The new mapping from AVX512 root of unity vector index to loop: + // [0, N/8) => FwdT8 + // [N/8, 5N/8) => FwdT4 + // [5N/8, 9N/8) => FwdT2 + // [9N/8, 13N/8) => FwdT1 + size_t N = n << recursion_depth; + + // FwdT8 range + if (idx <= N / 8) { + return idx; + } + // FwdT4 range + if (idx <= N / 4) { + return (idx - N / 8) * 4 + (N / 8); + } + // FwdT2 range + if (idx <= N / 2) { + return (idx - N / 4) * 2 + (5 * N / 8); + } + // FwdT1 range + return idx + (5 * N / 8); + }; + + size_t new_W_idx = compute_new_W_idx(W_idx); + const uint64_t* W = &root_of_unity_powers[new_W_idx]; + const uint64_t* W_precon = &precon_root_of_unity_powers[new_W_idx]; + FwdT4(result, v_neg_modulus, v_twice_mod, m, W, W_precon); + + m <<= 1; + W_idx <<= 1; + new_W_idx = compute_new_W_idx(W_idx); + W = &root_of_unity_powers[new_W_idx]; + W_precon = &precon_root_of_unity_powers[new_W_idx]; + FwdT2(result, v_neg_modulus, v_twice_mod, m, W, W_precon); + + m <<= 1; + W_idx <<= 1; + new_W_idx = compute_new_W_idx(W_idx); + W = &root_of_unity_powers[new_W_idx]; + W_precon = &precon_root_of_unity_powers[new_W_idx]; + FwdT1(result, v_neg_modulus, v_twice_mod, m, W, W_precon); + } + + if (output_mod_factor == 1) { + // n power of two at least 8 => n divisible by 8 + 
HEXL_CHECK(n % 8 == 0, "n " << n << " not a power of 2"); + __m512i* v_X_pt = reinterpret_cast<__m512i*>(result); + for (size_t i = 0; i < n; i += 8) { + __m512i v_X = _mm512_loadu_si512(v_X_pt); + + // Reduce from [0, 4q) to [0, q) + v_X = _mm512_hexl_small_mod_epu64(v_X, v_twice_mod); + v_X = _mm512_hexl_small_mod_epu64(v_X, v_modulus); + + HEXL_CHECK_BOUNDS(ExtractValues(v_X).data(), 8, modulus, + "v_X exceeds bound " << modulus); + + _mm512_storeu_si512(v_X_pt, v_X); + + ++v_X_pt; + } + } + } else { + // Perform depth-first NTT via recursive call + size_t t = (n >> 1); + size_t W_idx = (1ULL << recursion_depth) + recursion_half; + const uint64_t* W = &root_of_unity_powers[W_idx]; + const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; + + FwdT8(result, operand, v_neg_modulus, v_twice_mod, t, 1, W, + W_precon); + + ForwardTransformToBitReverseAVX512( + result, result, n / 2, modulus, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor, + recursion_depth + 1, recursion_half * 2); + + ForwardTransformToBitReverseAVX512( + &result[n / 2], &result[n / 2], n / 2, modulus, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor, + recursion_depth + 1, recursion_half * 2 + 1); + } +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/ntt/fwd-ntt-avx512.hpp b/hexl_ser/ntt/fwd-ntt-avx512.hpp new file mode 100644 index 00000000..b3e4cdff --- /dev/null +++ b/hexl_ser/ntt/fwd-ntt-avx512.hpp @@ -0,0 +1,46 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/ntt/ntt.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the forward NTT +/// @param[in, out] operand Input data. Overwritten with NTT output +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity in F_q. In +/// bit-reversed order. +/// @param[in] precon_root_of_unity_powers Pre-conditioned Powers of 2n'th root +/// of unity in F_q. In bit-reversed order. +/// @param[in] input_mod_factor Upper bound for inputs; inputs must be in [0, +/// input_mod_factor * q) +/// @param[in] output_mod_factor Upper bound for result; result must be in [0, +/// output_mod_factor * q) +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// NTT, where all the butterflies in a given stage are processed before any +/// butterflies in the next stage. The base case is small enough to fit in the +/// smallest cache. Larger NTTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. This reduces the number of cache misses, improving +/// performance on larger transform sizes. 
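
The @details paragraph above summarizes the control flow; a compact, illustrative skeleton of that strategy (real kernels elided, names made up, not part of the patch) is:

```cpp
// Illustrative skeleton only: breadth-first base case, depth-first recursion.
#include <cstdint>

void forward_ntt_recursive(uint64_t* data, uint64_t n) {
  const uint64_t base_ntt_size = 1024;  // small enough to sit in the smallest cache
  if (n <= base_ntt_size) {
    // breadth-first: run every FwdT8 stage, then the FwdT4/FwdT2/FwdT1 tail
  } else {
    // peel off one butterfly stage over all n elements (FwdT8), then recurse
    forward_ntt_recursive(data, n / 2);
    forward_ntt_recursive(data + n / 2, n / 2);
  }
}
```
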
+template +void ForwardTransformToBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/ntt/inv-ntt-avx512.cpp b/hexl_ser/ntt/inv-ntt-avx512.cpp new file mode 100644 index 00000000..8d340fe1 --- /dev/null +++ b/hexl_ser/ntt/inv-ntt-avx512.cpp @@ -0,0 +1,438 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "ntt/inv-ntt-avx512.hpp" + +#include + +#include +#include +#include + +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "ntt/ntt-avx512-util.hpp" +#include "ntt/ntt-internal.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512IFMA +template void InverseTransformFromBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t degree, + uint64_t modulus, const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half); +#endif + +#ifdef HEXL_HAS_AVX512DQ +template void InverseTransformFromBitReverseAVX512<32>( + uint64_t* result, const uint64_t* operand, uint64_t degree, + uint64_t modulus, const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half); + +template void InverseTransformFromBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t degree, + uint64_t modulus, const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half); +#endif + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief The Harvey butterfly: assume X, Y in [0, 2q), and return X', Y' in +/// [0, 2q). such that X', Y' = X + Y (mod q), W(X - Y) (mod q). +/// @param[in,out] X Input representing 8 64-bit signed integers in SIMD form +/// @param[in,out] Y Input representing 8 64-bit signed integers in SIMD form +/// @param[in] W Root of unity representing 8 64-bit signed integers in SIMD +/// form +/// @param[in] W_precon Preconditioned \p W for BitShift-bit Barrett +/// reduction +/// @param[in] neg_modulus Negative modulus, i.e. (-q) represented as 8 64-bit +/// signed integers in SIMD form +/// @param[in] twice_modulus Twice the modulus, i.e. 2*q represented as 8 64-bit +/// signed integers in SIMD form +/// @param InputLessThanMod If true, assumes \p X, \p Y < \p q. Otherwise, +/// assumes \p X, \p Y < 2*\p q +/// @details See Algorithm 3 of https://arxiv.org/pdf/1205.2926.pdf +template +inline void InvButterfly(__m512i* X, __m512i* Y, __m512i W, __m512i W_precon, + __m512i neg_modulus, __m512i twice_modulus) { + // Compute T first to allow in-place update of X + __m512i Y_minus_2q = _mm512_sub_epi64(*Y, twice_modulus); + __m512i T = _mm512_sub_epi64(*X, Y_minus_2q); + + if (InputLessThanMod) { + // No need for modulus reduction, since inputs are in [0, q) + *X = _mm512_add_epi64(*X, *Y); + } else { + // Algorithm 3 computes (X >= 2q) ? 
(X - 2q) : X + // We instead compute (X - 2q >= 0) ? (X - 2q) : X + // This allows us to use the faster _mm512_movepi64_mask rather than + // _mm512_cmp_epu64_mask to create the mask. + *X = _mm512_add_epi64(*X, Y_minus_2q); + __mmask8 sign_bits = _mm512_movepi64_mask(*X); + *X = _mm512_mask_add_epi64(*X, sign_bits, *X, twice_modulus); + } + + if (BitShift == 32) { + __m512i Q = _mm512_hexl_mullo_epi<64>(W_precon, T); + Q = _mm512_srli_epi64(Q, 32); + __m512i Q_p = _mm512_hexl_mullo_epi<64>(Q, neg_modulus); + *Y = _mm512_hexl_mullo_add_lo_epi<64>(Q_p, W, T); + } else if (BitShift == 52) { + __m512i Q = _mm512_hexl_mulhi_epi(W_precon, T); + __m512i Q_p = _mm512_hexl_mullo_epi(Q, neg_modulus); + *Y = _mm512_hexl_mullo_add_lo_epi(Q_p, W, T); + } else if (BitShift == 64) { + // Perform approximate computation of Q, as described in page 7 of + // https://arxiv.org/pdf/2003.04510.pdf + __m512i Q = _mm512_hexl_mulhi_approx_epi(W_precon, T); + __m512i Q_p = _mm512_hexl_mullo_epi(Q, neg_modulus); + // Compute Y in range [0, 4q) + *Y = _mm512_hexl_mullo_add_lo_epi(Q_p, W, T); + // Reduce Y to range [0, 2q) + *Y = _mm512_hexl_small_mod_epu64<2>(*Y, twice_modulus); + } else { + HEXL_CHECK(false, "Invalid BitShift " << BitShift); + } +} + +template +void InvT1(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, + uint64_t m, const uint64_t* W, const uint64_t* W_precon) { + const __m512i* v_W_pt = reinterpret_cast(W); + const __m512i* v_W_precon_pt = reinterpret_cast(W_precon); + size_t j1 = 0; + + // 8 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_8 + for (size_t i = m / 8; i > 0; --i) { + uint64_t* X = operand + j1; + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + + __m512i v_X; + __m512i v_Y; + LoadInvInterleavedT1(X, &v_X, &v_Y); + + __m512i v_W = _mm512_loadu_si512(v_W_pt++); + __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++); + + InvButterfly(&v_X, &v_Y, v_W, v_W_precon, + v_neg_modulus, v_twice_mod); + + _mm512_storeu_si512(v_X_pt++, v_X); + _mm512_storeu_si512(v_X_pt, v_Y); + + j1 += 16; + } +} + +template +void InvT2(uint64_t* X, __m512i v_neg_modulus, __m512i v_twice_mod, uint64_t m, + const uint64_t* W, const uint64_t* W_precon) { + // 4 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = m / 4; i > 0; --i) { + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + + __m512i v_X; + __m512i v_Y; + LoadInvInterleavedT2(X, &v_X, &v_Y); + + __m512i v_W = LoadWOpT2(static_cast(W)); + __m512i v_W_precon = LoadWOpT2(static_cast(W_precon)); + + InvButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, + v_twice_mod); + + _mm512_storeu_si512(v_X_pt++, v_X); + _mm512_storeu_si512(v_X_pt, v_Y); + X += 16; + + W += 4; + W_precon += 4; + } +} + +template +void InvT4(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, + uint64_t m, const uint64_t* W, const uint64_t* W_precon) { + uint64_t* X = operand; + + // 2 | m guaranteed by n >= 16 + HEXL_LOOP_UNROLL_4 + for (size_t i = m / 2; i > 0; --i) { + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + + __m512i v_X; + __m512i v_Y; + LoadInvInterleavedT4(X, &v_X, &v_Y); + + __m512i v_W = LoadWOpT4(static_cast(W)); + __m512i v_W_precon = LoadWOpT4(static_cast(W_precon)); + + InvButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, + v_twice_mod); + + WriteInvInterleavedT4(v_X, v_Y, v_X_pt); + X += 16; + + W += 2; + W_precon += 2; + } +} + +template +void InvT8(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, + uint64_t t, uint64_t m, const uint64_t* W, + const uint64_t* W_precon) { + size_t j1 = 0; + + 
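// Editorial note (descriptive comment, not in the original patch): each of the
+ // m outer iterations below processes one pair of length-t blocks, X = operand + j1
+ // and Y = X + t, then advances j1 by 2*t; the inner loop walks both blocks
+ // eight 64-bit lanes at a time, which is why 8 | t is assumed. +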
HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < m; i++) { + uint64_t* X = operand + j1; + uint64_t* Y = X + t; + + __m512i v_W = _mm512_set1_epi64(static_cast(*W++)); + __m512i v_W_precon = _mm512_set1_epi64(static_cast(*W_precon++)); + + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + __m512i* v_Y_pt = reinterpret_cast<__m512i*>(Y); + + // assume 8 | t + for (size_t j = t / 8; j > 0; --j) { + __m512i v_X = _mm512_loadu_si512(v_X_pt); + __m512i v_Y = _mm512_loadu_si512(v_Y_pt); + + InvButterfly(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus, + v_twice_mod); + + _mm512_storeu_si512(v_X_pt++, v_X); + _mm512_storeu_si512(v_Y_pt++, v_Y); + } + j1 += (t << 1); + } +} + +template +void InverseTransformFromBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth, + uint64_t recursion_half) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(n >= 16, + "InverseTransformFromBitReverseAVX512 doesn't support small " + "transforms. Need n >= 16, got n = " + << n); + HEXL_CHECK(modulus < NTT::s_max_inv_modulus(BitShift), + "modulus " << modulus << " too large for BitShift " << BitShift + << " => maximum value " + << NTT::s_max_inv_modulus(BitShift)); + HEXL_CHECK_BOUNDS(precon_inv_root_of_unity_powers, n, MaximumValue(BitShift), + "precon_inv_root_of_unity_powers too large"); + HEXL_CHECK_BOUNDS(operand, n, MaximumValue(BitShift), "operand too large"); + // Skip input bound checking for recursive steps + HEXL_CHECK_BOUNDS(operand, (recursion_depth == 0) ? n : 0, + input_mod_factor * modulus, + "operand larger than input_mod_factor * modulus (" + << input_mod_factor << " * " << modulus << ")"); + HEXL_CHECK(input_mod_factor == 1 || input_mod_factor == 2, + "input_mod_factor must be 1 or 2; got " << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2; got " << output_mod_factor); + + uint64_t twice_mod = modulus << 1; + __m512i v_modulus = _mm512_set1_epi64(static_cast(modulus)); + __m512i v_neg_modulus = _mm512_set1_epi64(-static_cast(modulus)); + __m512i v_twice_mod = _mm512_set1_epi64(static_cast(twice_mod)); + + size_t t = 1; + size_t m = (n >> 1); + size_t W_idx = 1 + m * recursion_half; + + static const size_t base_ntt_size = 1024; + + if (n <= base_ntt_size) { // Perform breadth-first InvNTT + if (operand != result) { + std::memcpy(result, operand, n * sizeof(uint64_t)); + } + + // Extract t=1, t=2, t=4 loops separately + { + // t = 1 + const uint64_t* W = &inv_root_of_unity_powers[W_idx]; + const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx]; + if ((input_mod_factor == 1) && (recursion_depth == 0)) { + InvT1(result, v_neg_modulus, v_twice_mod, m, W, + W_precon); + } else { + InvT1(result, v_neg_modulus, v_twice_mod, m, W, + W_precon); + } + + t <<= 1; + m >>= 1; + uint64_t W_idx_delta = + m * ((1ULL << (recursion_depth + 1)) - recursion_half); + W_idx += W_idx_delta; + + // t = 2 + W = &inv_root_of_unity_powers[W_idx]; + W_precon = &precon_inv_root_of_unity_powers[W_idx]; + InvT2(result, v_neg_modulus, v_twice_mod, m, W, W_precon); + + t <<= 1; + m >>= 1; + W_idx_delta >>= 1; + W_idx += W_idx_delta; + + // t = 4 + W = &inv_root_of_unity_powers[W_idx]; + W_precon = &precon_inv_root_of_unity_powers[W_idx]; + InvT4(result, v_neg_modulus, v_twice_mod, m, W, W_precon); + t <<= 1; + m >>= 1; + W_idx_delta 
>>= 1; + W_idx += W_idx_delta; + + // t >= 8 + for (; m > 1;) { + W = &inv_root_of_unity_powers[W_idx]; + W_precon = &precon_inv_root_of_unity_powers[W_idx]; + InvT8(result, v_neg_modulus, v_twice_mod, t, m, W, W_precon); + t <<= 1; + m >>= 1; + W_idx_delta >>= 1; + W_idx += W_idx_delta; + } + } + } else { + InverseTransformFromBitReverseAVX512( + result, operand, n / 2, modulus, inv_root_of_unity_powers, + precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor, + recursion_depth + 1, 2 * recursion_half); + InverseTransformFromBitReverseAVX512( + &result[n / 2], &operand[n / 2], n / 2, modulus, + inv_root_of_unity_powers, precon_inv_root_of_unity_powers, + input_mod_factor, output_mod_factor, recursion_depth + 1, + 2 * recursion_half + 1); + + uint64_t W_idx_delta = + m * ((1ULL << (recursion_depth + 1)) - recursion_half); + for (; m > 2; m >>= 1) { + t <<= 1; + W_idx_delta >>= 1; + W_idx += W_idx_delta; + } + if (m == 2) { + const uint64_t* W = &inv_root_of_unity_powers[W_idx]; + const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx]; + InvT8(result, v_neg_modulus, v_twice_mod, t, m, W, W_precon); + t <<= 1; + m >>= 1; + W_idx_delta >>= 1; + W_idx += W_idx_delta; + } + } + + // Final loop through data + if (recursion_depth == 0) { + HEXL_VLOG(4, "AVX512 intermediate result " + << std::vector(result, result + n)); + + const uint64_t W = inv_root_of_unity_powers[W_idx]; + MultiplyFactor mf_inv_n(InverseMod(n, modulus), BitShift, modulus); + const uint64_t inv_n = mf_inv_n.Operand(); + const uint64_t inv_n_prime = mf_inv_n.BarrettFactor(); + + MultiplyFactor mf_inv_n_w(MultiplyMod(inv_n, W, modulus), BitShift, + modulus); + const uint64_t inv_n_w = mf_inv_n_w.Operand(); + const uint64_t inv_n_w_prime = mf_inv_n_w.BarrettFactor(); + + HEXL_VLOG(4, "inv_n_w " << inv_n_w); + + uint64_t* X = result; + uint64_t* Y = X + (n >> 1); + + __m512i v_inv_n = _mm512_set1_epi64(static_cast(inv_n)); + __m512i v_inv_n_prime = + _mm512_set1_epi64(static_cast(inv_n_prime)); + __m512i v_inv_n_w = _mm512_set1_epi64(static_cast(inv_n_w)); + __m512i v_inv_n_w_prime = + _mm512_set1_epi64(static_cast(inv_n_w_prime)); + + __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); + __m512i* v_Y_pt = reinterpret_cast<__m512i*>(Y); + + // Merge final InvNTT loop with modulus reduction baked-in + HEXL_LOOP_UNROLL_4 + for (size_t j = n / 16; j > 0; --j) { + __m512i v_X = _mm512_loadu_si512(v_X_pt); + __m512i v_Y = _mm512_loadu_si512(v_Y_pt); + + // Slightly different from regular InvButterfly because different W is + // used for X and Y + __m512i Y_minus_2q = _mm512_sub_epi64(v_Y, v_twice_mod); + __m512i X_plus_Y_mod2q = + _mm512_hexl_small_add_mod_epi64(v_X, v_Y, v_twice_mod); + // T = *X + twice_mod - *Y + __m512i T = _mm512_sub_epi64(v_X, Y_minus_2q); + + if (BitShift == 32) { + __m512i Q1 = _mm512_hexl_mullo_epi<64>(v_inv_n_prime, X_plus_Y_mod2q); + Q1 = _mm512_srli_epi64(Q1, 32); + // X = inv_N * X_plus_Y_mod2q - Q1 * modulus; + __m512i inv_N_tx = _mm512_hexl_mullo_epi<64>(v_inv_n, X_plus_Y_mod2q); + v_X = _mm512_hexl_mullo_add_lo_epi<64>(inv_N_tx, Q1, v_neg_modulus); + + __m512i Q2 = _mm512_hexl_mullo_epi<64>(v_inv_n_w_prime, T); + Q2 = _mm512_srli_epi64(Q2, 32); + + // Y = inv_N_W * T - Q2 * modulus; + __m512i inv_N_W_T = _mm512_hexl_mullo_epi<64>(v_inv_n_w, T); + v_Y = _mm512_hexl_mullo_add_lo_epi<64>(inv_N_W_T, Q2, v_neg_modulus); + } else { + __m512i Q1 = + _mm512_hexl_mulhi_epi(v_inv_n_prime, X_plus_Y_mod2q); + // X = inv_N * X_plus_Y_mod2q - Q1 * modulus; + __m512i inv_N_tx = + 
_mm512_hexl_mullo_epi(v_inv_n, X_plus_Y_mod2q); + v_X = + _mm512_hexl_mullo_add_lo_epi(inv_N_tx, Q1, v_neg_modulus); + + __m512i Q2 = _mm512_hexl_mulhi_epi(v_inv_n_w_prime, T); + // Y = inv_N_W * T - Q2 * modulus; + __m512i inv_N_W_T = _mm512_hexl_mullo_epi(v_inv_n_w, T); + v_Y = _mm512_hexl_mullo_add_lo_epi(inv_N_W_T, Q2, + v_neg_modulus); + } + + if (output_mod_factor == 1) { + // Modulus reduction from [0, 2q), to [0, q) + v_X = _mm512_hexl_small_mod_epu64(v_X, v_modulus); + v_Y = _mm512_hexl_small_mod_epu64(v_Y, v_modulus); + } + + _mm512_storeu_si512(v_X_pt++, v_X); + _mm512_storeu_si512(v_Y_pt++, v_Y); + } + + HEXL_VLOG(5, "AVX512 returning result " + << std::vector(result, result + n)); + } +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/ntt/inv-ntt-avx512.hpp b/hexl_ser/ntt/inv-ntt-avx512.hpp new file mode 100644 index 00000000..143b9476 --- /dev/null +++ b/hexl_ser/ntt/inv-ntt-avx512.hpp @@ -0,0 +1,46 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ntt/ntt-internal.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the inverse NTT +/// @param[in, out] operand Input data. Overwritten with NTT output +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n +/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity in +/// F_q. In bit-reversed order. +/// @param[in] precon_root_of_unity_powers Pre-conditioned powers of inverse +/// 2n'th root of unity in F_q. In bit-reversed order. +/// @param[in] input_mod_factor Upper bound for inputs; inputs must be in [0, +/// input_mod_factor * q) +/// @param[in] output_mod_factor Upper bound for result; result must be in [0, +/// output_mod_factor * q) +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// NTT, where all the butterflies in a given stage are processed before any +/// butterflies in the next stage. The base case is small enough to fit in the +/// smallest cache. Larger NTTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. This reduces the number of cache misses, improving +/// performance on larger transform sizes. 
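+// Editorial note (added for illustration, not part of the original patch):
+// NTT::ComputeInverse in ntt-internal.cpp, later in this patch, selects the
+// instantiation at run time, roughly:
+//
+//   if (has_avx512ifma && q < s_max_inv_ifma_modulus && degree >= 16) -> BitShift = 52
+//   else if (has_avx512dq && degree >= 16 && q < s_max_inv_32_modulus) -> BitShift = 32
+//   else if (has_avx512dq && degree >= 16) -> BitShift = 64
+//   else -> InverseTransformFromBitReverseRadix2 (native C++ fallback)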
+template +void InverseTransformFromBitReverseAVX512( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/ntt/ntt-avx512-util.hpp b/hexl_ser/ntt/ntt-avx512-util.hpp new file mode 100644 index 00000000..c2342c01 --- /dev/null +++ b/hexl_ser/ntt/ntt-avx512-util.hpp @@ -0,0 +1,218 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +// Given input: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +// Returns +// *out1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); +// *out2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); +inline void LoadFwdInterleavedT1(const uint64_t* arg, __m512i* out1, + __m512i* out2) { + const __m512i* arg_512 = reinterpret_cast(arg); + + // 0, 1, 2, 3, 4, 5, 6, 7 + __m512i v1 = _mm512_loadu_si512(arg_512++); + // 8, 9, 10, 11, 12, 13, 14, 15 + __m512i v2 = _mm512_loadu_si512(arg_512); + + const __m512i perm_idx = _mm512_set_epi64(6, 7, 4, 5, 2, 3, 0, 1); + + // 1, 0, 3, 2, 5, 4, 7, 6 + __m512i v1_perm = _mm512_permutexvar_epi64(perm_idx, v1); + // 9, 8, 11, 10, 13, 12, 15, 14 + __m512i v2_perm = _mm512_permutexvar_epi64(perm_idx, v2); + + *out1 = _mm512_mask_blend_epi64(0xaa, v1, v2_perm); + *out2 = _mm512_mask_blend_epi64(0xaa, v1_perm, v2); +} + +// Given input: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +// Returns +// *out1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); +// *out2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); +inline void LoadInvInterleavedT1(const uint64_t* arg, __m512i* out1, + __m512i* out2) { + const __m512i vperm_hi_idx = _mm512_set_epi64(6, 4, 2, 0, 7, 5, 3, 1); + const __m512i vperm_lo_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + + const __m512i* arg_512 = reinterpret_cast(arg); + + // 7, 6, 5, 4, 3, 2, 1, 0 + __m512i v_7to0 = _mm512_loadu_si512(arg_512++); + // 15, 14, 13, 12, 11, 10, 9, 8 + __m512i v_15to8 = _mm512_loadu_si512(arg_512); + // 7, 5, 3, 1, 6, 4, 2, 0 + __m512i perm_lo = _mm512_permutexvar_epi64(vperm_lo_idx, v_7to0); + // 14, 12, 10, 8, 15, 13, 11, 9 + __m512i perm_hi = _mm512_permutexvar_epi64(vperm_hi_idx, v_15to8); + + *out1 = _mm512_mask_blend_epi64(0x0f, perm_hi, perm_lo); + *out2 = _mm512_mask_blend_epi64(0xf0, perm_hi, perm_lo); + *out2 = _mm512_permutexvar_epi64(vperm2_idx, *out2); +} + +// Given input: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +// Returns +// *out1 = _mm512_set_epi64(13, 12, 9, 8, 5, 4, 1, 0); +// *out2 = _mm512_set_epi64(15, 14, 11, 10, 7, 6, 3, 2) +inline void LoadFwdInterleavedT2(const uint64_t* arg, __m512i* out1, + __m512i* out2) { + const __m512i* arg_512 = reinterpret_cast(arg); + + // 11, 10, 9, 8, 3, 2, 1, 0 + __m512i v1 = _mm512_loadu_si512(arg_512++); + // 15, 14, 13, 12, 7, 6, 5, 4 + __m512i v2 = _mm512_loadu_si512(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + + __m512i v1_perm = _mm512_permutexvar_epi64(v1_perm_idx, v1); + __m512i v2_perm = 
_mm512_permutexvar_epi64(v1_perm_idx, v2); + + *out1 = _mm512_mask_blend_epi64(0xcc, v1, v2_perm); + *out2 = _mm512_mask_blend_epi64(0xcc, v1_perm, v2); +} + +// Given input: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 +// Returns +// *out1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); +// *out2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); +inline void LoadInvInterleavedT2(const uint64_t* arg, __m512i* out1, + __m512i* out2) { + const __m512i* arg_512 = reinterpret_cast(arg); + + __m512i v1 = _mm512_loadu_si512(arg_512++); + __m512i v2 = _mm512_loadu_si512(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(6, 7, 4, 5, 2, 3, 0, 1); + + __m512i v1_perm = _mm512_permutexvar_epi64(v1_perm_idx, v1); + __m512i v2_perm = _mm512_permutexvar_epi64(v1_perm_idx, v2); + + *out1 = _mm512_mask_blend_epi64(0xaa, v1, v2_perm); + *out2 = _mm512_mask_blend_epi64(0xaa, v1_perm, v2); +} + +// Returns +// *out1 = _mm512_set_epi64(arg[11], arg[10], arg[9], arg[8], +// arg[3], arg[2], arg[1], arg[0]); +// *out2 = _mm512_set_epi64(arg[15], arg[14], arg[13], arg[12], +// arg[7], arg[6], arg[5], arg[4]); +inline void LoadFwdInterleavedT4(const uint64_t* arg, __m512i* out1, + __m512i* out2) { + const __m512i* arg_512 = reinterpret_cast(arg); + + const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + __m512i v_7to0 = _mm512_loadu_si512(arg_512++); + __m512i v_15to8 = _mm512_loadu_si512(arg_512); + __m512i perm_hi = _mm512_permutexvar_epi64(vperm2_idx, v_15to8); + *out1 = _mm512_mask_blend_epi64(0x0f, perm_hi, v_7to0); + *out2 = _mm512_mask_blend_epi64(0xf0, perm_hi, v_7to0); + *out2 = _mm512_permutexvar_epi64(vperm2_idx, *out2); +} + +inline void LoadInvInterleavedT4(const uint64_t* arg, __m512i* out1, + __m512i* out2) { + const __m512i* arg_512 = reinterpret_cast(arg); + + // 0, 1, 2, 3, 4, 5, 6, 7 + __m512i v1 = _mm512_loadu_si512(arg_512++); + // 8, 9, 10, 11, 12, 13, 14, 15 + __m512i v2 = _mm512_loadu_si512(arg_512); + const __m512i perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + + // 1, 0, 3, 2, 5, 4, 7, 6 + __m512i v1_perm = _mm512_permutexvar_epi64(perm_idx, v1); + // 9, 8, 11, 10, 13, 12, 15, 14 + __m512i v2_perm = _mm512_permutexvar_epi64(perm_idx, v2); + + *out1 = _mm512_mask_blend_epi64(0xcc, v1, v2_perm); + *out2 = _mm512_mask_blend_epi64(0xcc, v1_perm, v2); +} + +// Given inputs +// @param arg1 = _mm512_set_epi64(15, 14, 13, 12, 11, 10, 9, 8); +// @param arg2 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0); +// Writes out = {8, 0, 9, 1, 10, 2, 11, 3, +// 12, 4, 13, 5, 14, 6, 15, 7} +inline void WriteFwdInterleavedT1(__m512i arg1, __m512i arg2, __m512i* out) { + const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i v_X_out_idx = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + const __m512i v_Y_out_idx = _mm512_set_epi64(3, 7, 2, 6, 1, 5, 0, 4); + + // v_Y => (4, 5, 6, 7, 0, 1, 2, 3) + arg2 = _mm512_permutexvar_epi64(vperm2_idx, arg2); + // 4, 5, 6, 7, 12, 13, 14, 15 + __m512i perm_lo = _mm512_mask_blend_epi64(0x0f, arg1, arg2); + + // 8, 9, 10, 11, 0, 1, 2, 3 + __m512i perm_hi = _mm512_mask_blend_epi64(0xf0, arg1, arg2); + + arg1 = _mm512_permutexvar_epi64(v_X_out_idx, perm_hi); + arg2 = _mm512_permutexvar_epi64(v_Y_out_idx, perm_lo); + + _mm512_storeu_si512(out++, arg1); + _mm512_storeu_si512(out, arg2); +} + +// Given inputs +// @param arg1 = _mm512_set_epi64(15, 14, 13, 12, 11, 10, 9, 8); +// @param arg2 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0); +// Writes out = {8, 9, 10, 11, 0, 1, 2, 3, +// 12, 13, 14, 15, 4, 5, 6, 7} +inline void 
WriteInvInterleavedT4(__m512i arg1, __m512i arg2, __m512i* out) { + __m256i x0 = _mm512_extracti64x4_epi64(arg1, 0); + __m256i x1 = _mm512_extracti64x4_epi64(arg1, 1); + __m256i y0 = _mm512_extracti64x4_epi64(arg2, 0); + __m256i y1 = _mm512_extracti64x4_epi64(arg2, 1); + __m256i* out_256 = reinterpret_cast<__m256i*>(out); + _mm256_storeu_si256(out_256++, x0); + _mm256_storeu_si256(out_256++, y0); + _mm256_storeu_si256(out_256++, x1); + _mm256_storeu_si256(out_256++, y1); +} + +// Returns _mm512_set_epi64(arg[3], arg[3], arg[2], arg[2], +// arg[1], arg[1], arg[0], arg[0]); +inline __m512i LoadWOpT2(const void* arg) { + const __m512i vperm_w_idx = _mm512_set_epi64(3, 3, 2, 2, 1, 1, 0, 0); + + __m256i v_W_256 = _mm256_loadu_si256(reinterpret_cast(arg)); + __m512i v_W = _mm512_broadcast_i64x4(v_W_256); + v_W = _mm512_permutexvar_epi64(vperm_w_idx, v_W); + + return v_W; +} + +// Returns _mm512_set_epi64(arg[1], arg[1], arg[1], arg[1], +// arg[0], arg[0], arg[0], arg[0]); +inline __m512i LoadWOpT4(const void* arg) { + const __m512i vperm_w_idx = _mm512_set_epi64(1, 1, 1, 1, 0, 0, 0, 0); + + __m128i v_W_128 = _mm_loadu_si128(reinterpret_cast(arg)); + __m512i v_W = _mm512_broadcast_i64x2(v_W_128); + v_W = _mm512_permutexvar_epi64(vperm_w_idx, v_W); + + return v_W; +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/ntt/ntt-default.hpp b/hexl_ser/ntt/ntt-default.hpp new file mode 100644 index 00000000..7fe79db4 --- /dev/null +++ b/hexl_ser/ntt/ntt-default.hpp @@ -0,0 +1,159 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" + +namespace intel { +namespace hexl { + +/// @brief Out of place Harvey butterfly: assume \p X_op, \p Y_op in [0, 4q), +/// and return X_r, Y_r in [0, 4q) such that X_r = X_op + WY_op, Y_r = X_op - +/// WY_op (mod q). +/// @param[out] X_r Butterfly data +/// @param[out] Y_r Butterfly data +/// @param[in] X_op Butterfly data +/// @param[in] Y_op Butterfly data +/// @param[in] W Root of unity +/// @param[in] W_precon Preconditioned \p W for BitShift-bit Barrett +/// reduction +/// @param[in] modulus Negative modulus, i.e. (-q) represented as 8 64-bit +/// signed integers in SIMD form +/// @param[in] twice_modulus Twice the modulus, i.e. 
2*q represented as 8 64-bit +/// signed integers in SIMD form +/// @details See Algorithm 4 of https://arxiv.org/pdf/1205.2926.pdf +inline void FwdButterflyRadix2(uint64_t* X_r, uint64_t* Y_r, + const uint64_t* X_op, const uint64_t* Y_op, + uint64_t W, uint64_t W_precon, uint64_t modulus, + uint64_t twice_modulus) { + HEXL_VLOG(5, "FwdButterflyRadix2"); + HEXL_VLOG(5, "Inputs: X_op " << *X_op << ", Y_op " << *Y_op << ", W " << W + << ", modulus " << modulus); + uint64_t tx = ReduceMod<2>(*X_op, twice_modulus); + uint64_t T = MultiplyModLazy<64>(*Y_op, W, W_precon, modulus); + HEXL_VLOG(5, "T " << T); + *X_r = tx + T; + *Y_r = tx + twice_modulus - T; + + HEXL_VLOG(5, "Output X " << *X_r << ", Y " << *Y_r); +} + +// Assume X, Y in [0, n*q) and return X_r, Y_r in [0, (n+2)*q) +// such that X_r = X_op + WY_op mod q and Y_r = X_op - WY_op mod q +inline void FwdButterflyRadix4Lazy(uint64_t* X_r, uint64_t* Y_r, + const uint64_t X_op, const uint64_t Y_op, + uint64_t W, uint64_t W_precon, + uint64_t modulus, uint64_t twice_modulus) { + HEXL_VLOG(3, "FwdButterflyRadix4Lazy"); + HEXL_VLOG(3, "Inputs: X_op " << X_op << ", Y_op " << Y_op << ", W " << W + << ", modulus " << modulus); + + uint64_t T = MultiplyModLazy<64>(Y_op, W, W_precon, modulus); + HEXL_VLOG(3, "T " << T); + *X_r = X_op + T; + *Y_r = X_op + twice_modulus - T; + + HEXL_VLOG(3, "Outputs: X_r " << *X_r << ", Y_r " << *Y_r); +} + +// Assume X0, X1, X2, X3 in [0, 4q) and return X0, X1, X2, X3 in [0, 4q) +inline void FwdButterflyRadix4( + uint64_t* X_r0, uint64_t* X_r1, uint64_t* X_r2, uint64_t* X_r3, + const uint64_t* X_op0, const uint64_t* X_op1, const uint64_t* X_op2, + const uint64_t* X_op3, uint64_t W1, uint64_t W1_precon, uint64_t W2, + uint64_t W2_precon, uint64_t W3, uint64_t W3_precon, uint64_t modulus, + uint64_t twice_modulus, uint64_t four_times_modulus) { + HEXL_VLOG(3, "FwdButterflyRadix4"); + HEXL_UNUSED(four_times_modulus); + + FwdButterflyRadix2(X_r0, X_r2, X_op0, X_op2, W1, W1_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r1, X_r3, X_op1, X_op3, W1, W1_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r0, X_r1, X_r0, X_r1, W2, W2_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r2, X_r3, X_r2, X_r3, W3, W3_precon, modulus, + twice_modulus); + + // Alternate implementation + // // Returns Xs in [0, 6q) + // FwdButterflyRadix4Lazy(X0, X2, W1, W1_precon, modulus, twice_modulus); + // FwdButterflyRadix4Lazy(X1, X3, W1, W1_precon, modulus, twice_modulus); + + // // Returns Xs in [0, 8q) + // FwdButterflyRadix4Lazy(X0, X1, W2, W2_precon, modulus, twice_modulus); + // FwdButterflyRadix4Lazy(X2, X3, W3, W3_precon, modulus, twice_modulus); + + // // Reduce Xs to [0, 4q) + // *X0 = ReduceMod<2>(*X0, four_times_modulus); + // *X1 = ReduceMod<2>(*X1, four_times_modulus); + // *X2 = ReduceMod<2>(*X2, four_times_modulus); + // *X3 = ReduceMod<2>(*X3, four_times_modulus); +} + +/// @brief Out-of-place Harvey butterfly: assume X_op, Y_op in [0, 2q), and +/// return X_r, Y_r in [0, 2q) such that X_r = X_op + Y_op (mod q), +/// Y_r = W(X_op - Y_op) (mod q). +/// @param[out] X_r Butterfly data +/// @param[out] Y_r Butterfly data +/// @param[in] X_op Butterfly data +/// @param[in] Y_op Butterfly data +/// @param[in] W Root of unity +/// @param[in] W_precon Preconditioned root of unity for 64-bit Barrett +/// reduction +/// @param[in] modulus Modulus, i.e. (q) represented as 8 64-bit +/// signed integers in SIMD form +/// @param[in] twice_modulus Twice the modulus, i.e. 
2*q represented as 8 64-bit +/// signed integers in SIMD form +/// @details See Algorithm 3 of https://arxiv.org/pdf/1205.2926.pdf +inline void InvButterflyRadix2(uint64_t* X_r, uint64_t* Y_r, + const uint64_t* X_op, const uint64_t* Y_op, + uint64_t W, uint64_t W_precon, uint64_t modulus, + uint64_t twice_modulus) { + HEXL_VLOG(4, "InvButterflyRadix2 X_op " + << *X_op << ", Y_op " << *Y_op << " W " << W << " W_precon " + << W_precon << " modulus " << modulus); + uint64_t tx = *X_op + *Y_op; + *Y_r = *X_op + twice_modulus - *Y_op; + *X_r = ReduceMod<2>(tx, twice_modulus); + *Y_r = MultiplyModLazy<64>(*Y_r, W, W_precon, modulus); + + HEXL_VLOG(4, "InvButterflyRadix2 returning X_r " << *X_r << ", Y_r " << *Y_r); +} + +// Assume X0, X1, X2, X3 in [0, 2q) and return X0, X1, X2, X3 in [0, 2q) +inline void InvButterflyRadix4(uint64_t* X_r0, uint64_t* X_r1, uint64_t* X_r2, + uint64_t* X_r3, const uint64_t* X_op0, + const uint64_t* X_op1, const uint64_t* X_op2, + const uint64_t* X_op3, uint64_t W1, + uint64_t W1_precon, uint64_t W2, + uint64_t W2_precon, uint64_t W3, + uint64_t W3_precon, uint64_t modulus, + uint64_t twice_modulus) { + HEXL_VLOG(4, "InvButterflyRadix4 " // + << "X_op0 " << *X_op0 << ", X_op1 " << *X_op1 // + << ", X_op2 " << *X_op2 << " X_op3 " << *X_op3 // + << " W1 " << W1 << " W1_precon " << W1_precon // + << " W2 " << W2 << " W2_precon " << W2_precon // + << " W3 " << W3 << " W3_precon " << W3_precon // + << " modulus " << modulus); + + InvButterflyRadix2(X_r0, X_r1, X_op0, X_op1, W1, W1_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r2, X_r3, X_op2, X_op3, W2, W2_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r0, X_r2, X_r0, X_r2, W3, W3_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r1, X_r3, X_r1, X_r3, W3, W3_precon, modulus, + twice_modulus); + + HEXL_VLOG(4, "InvButterflyRadix4 returning X0 " << *X_r0 << ", X_r1 " << *X_r1 + << ", X_r2 " << *X_r2 // + << " X_r3 " << *X_r3); +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/ntt/ntt-internal.cpp b/hexl_ser/ntt/ntt-internal.cpp new file mode 100644 index 00000000..9ace8741 --- /dev/null +++ b/hexl_ser/ntt/ntt-internal.cpp @@ -0,0 +1,313 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "ntt/ntt-internal.hpp" + +#include +#include + +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/defines.hpp" +#include "ntt/fwd-ntt-avx512.hpp" +#include "ntt/inv-ntt-avx512.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +AllocatorStrategyPtr mallocStrategy = AllocatorStrategyPtr(new MallocStrategy); + +NTT::NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, + std::shared_ptr alloc_ptr) + : m_degree(degree), + m_q(q), + m_w(root_of_unity), + m_alloc(alloc_ptr), + m_aligned_alloc(AlignedAllocator(m_alloc)), + m_root_of_unity_powers(m_aligned_alloc), + m_precon32_root_of_unity_powers(m_aligned_alloc), + m_precon64_root_of_unity_powers(m_aligned_alloc), + m_avx512_root_of_unity_powers(m_aligned_alloc), + m_avx512_precon32_root_of_unity_powers(m_aligned_alloc), + m_avx512_precon52_root_of_unity_powers(m_aligned_alloc), + m_avx512_precon64_root_of_unity_powers(m_aligned_alloc), + m_precon32_inv_root_of_unity_powers(m_aligned_alloc), + m_precon52_inv_root_of_unity_powers(m_aligned_alloc), + m_precon64_inv_root_of_unity_powers(m_aligned_alloc), + 
m_inv_root_of_unity_powers(m_aligned_alloc) { + HEXL_CHECK(CheckArguments(degree, q), ""); + HEXL_CHECK(IsPrimitiveRoot(m_w, 2 * degree, q), + m_w << " is not a primitive 2*" << degree << "'th root of unity"); + + m_degree_bits = Log2(m_degree); + m_w_inv = InverseMod(m_w, m_q); + ComputeRootOfUnityPowers(); +} + +NTT::NTT(uint64_t degree, uint64_t q, std::shared_ptr alloc_ptr) + : NTT(degree, q, MinimalPrimitiveRoot(2 * degree, q), alloc_ptr) {} + +void NTT::ComputeRootOfUnityPowers() { + AlignedVector64 root_of_unity_powers(m_degree, 0, m_aligned_alloc); + AlignedVector64 inv_root_of_unity_powers(m_degree, 0, + m_aligned_alloc); + + // 64-bit preconditioned inverse and root of unity powers + root_of_unity_powers[0] = 1; + inv_root_of_unity_powers[0] = InverseMod(1, m_q); + uint64_t idx = 0; + uint64_t prev_idx = idx; + + for (size_t i = 1; i < m_degree; i++) { + idx = ReverseBits(i, m_degree_bits); + root_of_unity_powers[idx] = + MultiplyMod(root_of_unity_powers[prev_idx], m_w, m_q); + inv_root_of_unity_powers[idx] = InverseMod(root_of_unity_powers[idx], m_q); + + prev_idx = idx; + } + + m_root_of_unity_powers = root_of_unity_powers; + m_avx512_root_of_unity_powers = m_root_of_unity_powers; + + // Duplicate each root of unity at indices [N/4, N/2]. + // These are the roots of unity used in the FwdNTT FwdT2 function + // By creating these duplicates, we avoid extra permutations while loading the + // roots of unity + AlignedVector64 W2_roots; + W2_roots.reserve(m_degree / 2); + for (size_t i = m_degree / 4; i < m_degree / 2; ++i) { + W2_roots.push_back(m_root_of_unity_powers[i]); + W2_roots.push_back(m_root_of_unity_powers[i]); + } + m_avx512_root_of_unity_powers.erase( + m_avx512_root_of_unity_powers.begin() + m_degree / 4, + m_avx512_root_of_unity_powers.begin() + m_degree / 2); + m_avx512_root_of_unity_powers.insert( + m_avx512_root_of_unity_powers.begin() + m_degree / 4, W2_roots.begin(), + W2_roots.end()); + + // Duplicate each root of unity at indices [N/8, N/4]. 
+ // These are the roots of unity used in the FwdNTT FwdT4 function + // By creating these duplicates, we avoid extra permutations while loading the + // roots of unity + AlignedVector64 W4_roots; + W4_roots.reserve(m_degree / 2); + for (size_t i = m_degree / 8; i < m_degree / 4; ++i) { + W4_roots.push_back(m_root_of_unity_powers[i]); + W4_roots.push_back(m_root_of_unity_powers[i]); + W4_roots.push_back(m_root_of_unity_powers[i]); + W4_roots.push_back(m_root_of_unity_powers[i]); + } + m_avx512_root_of_unity_powers.erase( + m_avx512_root_of_unity_powers.begin() + m_degree / 8, + m_avx512_root_of_unity_powers.begin() + m_degree / 4); + m_avx512_root_of_unity_powers.insert( + m_avx512_root_of_unity_powers.begin() + m_degree / 8, W4_roots.begin(), + W4_roots.end()); + + auto compute_barrett_vector = [&](const AlignedVector64& values, + uint64_t bit_shift) { + AlignedVector64 barrett_vector(m_aligned_alloc); + for (uint64_t value : values) { + MultiplyFactor mf(value, bit_shift, m_q); + barrett_vector.push_back(mf.BarrettFactor()); + } + return barrett_vector; + }; + + m_precon32_root_of_unity_powers = + compute_barrett_vector(root_of_unity_powers, 32); + m_precon64_root_of_unity_powers = + compute_barrett_vector(root_of_unity_powers, 64); + + // 52-bit preconditioned root of unity powers + if (has_avx512ifma) { + m_avx512_precon52_root_of_unity_powers = + compute_barrett_vector(m_avx512_root_of_unity_powers, 52); + } + + if (has_avx512dq) { + m_avx512_precon32_root_of_unity_powers = + compute_barrett_vector(m_avx512_root_of_unity_powers, 32); + m_avx512_precon64_root_of_unity_powers = + compute_barrett_vector(m_avx512_root_of_unity_powers, 64); + } + + // Inverse root of unity powers + + // Reordering inv_root_of_powers + AlignedVector64 temp(m_degree, 0, m_aligned_alloc); + temp[0] = inv_root_of_unity_powers[0]; + idx = 1; + + for (size_t m = (m_degree >> 1); m > 0; m >>= 1) { + for (size_t i = 0; i < m; i++) { + temp[idx] = inv_root_of_unity_powers[m + i]; + idx++; + } + } + m_inv_root_of_unity_powers = std::move(temp); + + // 32-bit preconditioned inverse root of unity powers + m_precon32_inv_root_of_unity_powers = + compute_barrett_vector(m_inv_root_of_unity_powers, 32); + + // 52-bit preconditioned inverse root of unity powers + if (has_avx512ifma) { + m_precon52_inv_root_of_unity_powers = + compute_barrett_vector(m_inv_root_of_unity_powers, 52); + } + + // 64-bit preconditioned inverse root of unity powers + m_precon64_inv_root_of_unity_powers = + compute_barrett_vector(m_inv_root_of_unity_powers, 64); +} + +bool NTT::CheckArguments(uint64_t degree, uint64_t modulus) { + HEXL_UNUSED(degree); + HEXL_UNUSED(modulus); + HEXL_CHECK(IsPowerOfTwo(degree), + "degree " << degree << " is not a power of 2"); + HEXL_CHECK(degree <= (1ULL << NTT::MaxDegreeBits()), + "degree should be less than 2^" << NTT::MaxDegreeBits() << " got " + << degree); + HEXL_CHECK(modulus <= (1ULL << NTT::MaxModulusBits()), + "modulus should be less than 2^" << NTT::MaxModulusBits() + << " got " << modulus); + HEXL_CHECK(modulus % (2 * degree) == 1, "modulus mod 2n != 1"); + HEXL_CHECK(IsPrime(modulus), "modulus is not prime"); + + return true; +} + +void NTT::ComputeForward(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(result != nullptr, "result == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + HEXL_CHECK( + input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4, + "input_mod_factor must be 1, 2 or 4; got " << 
input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4, + "output_mod_factor must be 1 or 4; got " << output_mod_factor); + HEXL_CHECK_BOUNDS( + operand, m_degree, m_q * input_mod_factor, + "value in operand exceeds bound " << m_q * input_mod_factor); + +#ifdef HEXL_HAS_AVX512IFMA + if (has_avx512ifma && (m_q < s_max_fwd_ifma_modulus && (m_degree >= 16))) { + const uint64_t* root_of_unity_powers = GetAVX512RootOfUnityPowers().data(); + const uint64_t* precon_root_of_unity_powers = + GetAVX512Precon52RootOfUnityPowers().data(); + + HEXL_VLOG(3, "Calling 52-bit AVX512-IFMA FwdNTT"); + ForwardTransformToBitReverseAVX512( + result, operand, m_degree, m_q, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor); + return; + } +#endif + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq && m_degree >= 16) { + if (m_q < s_max_fwd_32_modulus) { + HEXL_VLOG(3, "Calling 32-bit AVX512-DQ FwdNTT"); + const uint64_t* root_of_unity_powers = + GetAVX512RootOfUnityPowers().data(); + const uint64_t* precon_root_of_unity_powers = + GetAVX512Precon32RootOfUnityPowers().data(); + ForwardTransformToBitReverseAVX512<32>( + result, operand, m_degree, m_q, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor); + } else { + HEXL_VLOG(3, "Calling 64-bit AVX512-DQ FwdNTT"); + const uint64_t* root_of_unity_powers = + GetAVX512RootOfUnityPowers().data(); + const uint64_t* precon_root_of_unity_powers = + GetAVX512Precon64RootOfUnityPowers().data(); + + ForwardTransformToBitReverseAVX512( + result, operand, m_degree, m_q, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor); + } + return; + } +#endif + + HEXL_VLOG(3, "Calling ForwardTransformToBitReverseRadix2"); + const uint64_t* root_of_unity_powers = GetRootOfUnityPowers().data(); + const uint64_t* precon_root_of_unity_powers = + GetPrecon64RootOfUnityPowers().data(); + + ForwardTransformToBitReverseRadix2( + result, operand, m_degree, m_q, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor); +} + +void NTT::ComputeInverse(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(result != nullptr, "result == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + HEXL_CHECK(input_mod_factor == 1 || input_mod_factor == 2, + "input_mod_factor must be 1 or 2; got " << input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2; got " << output_mod_factor); + HEXL_CHECK_BOUNDS(operand, m_degree, m_q * input_mod_factor, + "operand exceeds bound " << m_q * input_mod_factor); + +#ifdef HEXL_HAS_AVX512IFMA + if (has_avx512ifma && (m_q < s_max_inv_ifma_modulus) && (m_degree >= 16)) { + HEXL_VLOG(3, "Calling 52-bit AVX512-IFMA InvNTT"); + const uint64_t* inv_root_of_unity_powers = GetInvRootOfUnityPowers().data(); + const uint64_t* precon_inv_root_of_unity_powers = + GetPrecon52InvRootOfUnityPowers().data(); + InverseTransformFromBitReverseAVX512( + result, operand, m_degree, m_q, inv_root_of_unity_powers, + precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor); + return; + } +#endif + +#ifdef HEXL_HAS_AVX512DQ + if (has_avx512dq && m_degree >= 16) { + if (m_q < s_max_inv_32_modulus) { + HEXL_VLOG(3, "Calling 32-bit AVX512-DQ InvNTT"); + const uint64_t* inv_root_of_unity_powers = + GetInvRootOfUnityPowers().data(); + const uint64_t* precon_inv_root_of_unity_powers = + 
GetPrecon32InvRootOfUnityPowers().data(); + InverseTransformFromBitReverseAVX512<32>( + result, operand, m_degree, m_q, inv_root_of_unity_powers, + precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor); + } else { + HEXL_VLOG(3, "Calling 64-bit AVX512 InvNTT"); + const uint64_t* inv_root_of_unity_powers = + GetInvRootOfUnityPowers().data(); + const uint64_t* precon_inv_root_of_unity_powers = + GetPrecon64InvRootOfUnityPowers().data(); + + InverseTransformFromBitReverseAVX512( + result, operand, m_degree, m_q, inv_root_of_unity_powers, + precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor); + } + return; + } +#endif + + HEXL_VLOG(3, "Calling 64-bit default InvNTT"); + const uint64_t* inv_root_of_unity_powers = GetInvRootOfUnityPowers().data(); + const uint64_t* precon_inv_root_of_unity_powers = + GetPrecon64InvRootOfUnityPowers().data(); + InverseTransformFromBitReverseRadix2( + result, operand, m_degree, m_q, inv_root_of_unity_powers, + precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor); +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/ntt/ntt-internal.hpp b/hexl_ser/ntt/ntt-internal.hpp new file mode 100644 index 00000000..420dca10 --- /dev/null +++ b/hexl_ser/ntt/ntt-internal.hpp @@ -0,0 +1,125 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include + +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/util.hpp" +#include "util/util-internal.hpp" + +namespace intel { +namespace hexl { + +/// @brief Radix-2 native C++ NTT implementation of the forward NTT +/// @param[out] result Output data. Overwritten with NTT output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity in F_q. In +/// bit-reversed order +/// @param[in] precon_root_of_unity_powers Pre-conditioned Powers of 2n'th root +/// of unity in F_q. In bit-reversed order. +/// @param[in] input_mod_factor Upper bound for inputs; inputs must be in [0, +/// input_mod_factor * q) +/// @param[in] output_mod_factor Upper bound for result; result must be in [0, +/// output_mod_factor * q) +void ForwardTransformToBitReverseRadix2( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor = 1, + uint64_t output_mod_factor = 1); + +/// @brief Radix-4 native C++ NTT implementation of the forward NTT +/// @param[out] result Output data. Overwritten with NTT output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity in F_q. In +/// bit-reversed order +/// @param[in] precon_root_of_unity_powers Pre-conditioned Powers of 2n'th root +/// of unity in F_q. In bit-reversed order. 
+/// @param[in] input_mod_factor Upper bound for inputs; inputs must be in [0, +/// input_mod_factor * q) +/// @param[in] output_mod_factor Upper bound for result; result must be in [0, +/// output_mod_factor * q) +void ForwardTransformToBitReverseRadix4( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor = 1, + uint64_t output_mod_factor = 1); + +/// @brief Reference forward NTT which is written for clarity rather than +/// performance +/// @param[in, out] operand Input data. Overwritten with NTT output +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus. Must satisfy q == 1 mod 2n +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity in F_q. In +/// bit-reversed order +void ReferenceForwardTransformToBitReverse( + uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers); + +/// @brief Reference inverse NTT which is written for clarity rather than +/// performance +/// @param[in, out] operand Input data. Overwritten with NTT output +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus. Must satisfy q == 1 mod 2n +/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity in +/// F_q. In bit-reversed order. +void ReferenceInverseTransformFromBitReverse( + uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers); + +/// @brief Radix-2 native C++ NTT implementation of the inverse NTT +/// @param[out] result Output data. Overwritten with NTT output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n +/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity in +/// F_q. In bit-reversed order. +/// @param[in] precon_root_of_unity_powers Pre-conditioned powers of inverse +/// 2n'th root of unity in F_q. In bit-reversed order. +/// @param[in] input_mod_factor Upper bound for inputs; inputs must be in [0, +/// input_mod_factor * q) +/// @param[in] output_mod_factor Upper bound for result; result must be in [0, +/// output_mod_factor * q) +void InverseTransformFromBitReverseRadix2( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, + uint64_t input_mod_factor = 1, uint64_t output_mod_factor = 1); + +/// @brief Radix-4 native C++ NTT implementation of the inverse NTT +/// @param[out] result Output data. Overwritten with NTT output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity in F_q. In +/// bit-reversed order +/// @param[in] precon_root_of_unity_powers Pre-conditioned Powers of 2n'th root +/// of unity in F_q. In bit-reversed order. 
+/// @param[in] input_mod_factor Upper bound for inputs; inputs must be in [0, +/// input_mod_factor * q) +/// @param[in] output_mod_factor Upper bound for result; result must be in [0, +/// output_mod_factor * q) +void InverseTransformFromBitReverseRadix4( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor = 1, + uint64_t output_mod_factor = 1); + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/ntt/ntt-radix-2.cpp b/hexl_ser/ntt/ntt-radix-2.cpp new file mode 100644 index 00000000..d0555d47 --- /dev/null +++ b/hexl_ser/ntt/ntt-radix-2.cpp @@ -0,0 +1,522 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "ntt/ntt-default.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void ForwardTransformToBitReverseRadix2( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK_BOUNDS(operand, n, modulus * input_mod_factor, + "operand exceeds bound " << modulus * input_mod_factor); + HEXL_CHECK(root_of_unity_powers != nullptr, + "root_of_unity_powers == nullptr"); + HEXL_CHECK(precon_root_of_unity_powers != nullptr, + "precon_root_of_unity_powers == nullptr"); + HEXL_CHECK( + input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4, + "input_mod_factor must be 1, 2, or 4; got " << input_mod_factor); + HEXL_UNUSED(input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4, + "output_mod_factor must be 1 or 4; got " << output_mod_factor); + + uint64_t twice_modulus = modulus << 1; + size_t t = (n >> 1); + + // In case of out-of-place operation do first pass and convert to in-place + { + const uint64_t W = root_of_unity_powers[1]; + const uint64_t W_precon = precon_root_of_unity_powers[1]; + + uint64_t* X_r = result; + uint64_t* Y_r = X_r + t; + + const uint64_t* X_op = operand; + const uint64_t* Y_op = X_op + t; + + // First pass for out-of-order case + switch (t) { + case 8: { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + break; + } + case 4: { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, 
X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + break; + } + case 2: { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + break; + } + case 1: { + FwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W, W_precon, modulus, + twice_modulus); + break; + } + default: { + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < t; j += 8) { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + } + } + t >>= 1; + } + + // Continue with in-place operation + for (size_t m = 2; m < n; m <<= 1) { + size_t offset = 0; + switch (t) { + case 8: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = root_of_unity_powers[m + i]; + const uint64_t W_precon = precon_root_of_unity_powers[m + i]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + break; + } + case 4: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = root_of_unity_powers[m + i]; + const uint64_t W_precon = precon_root_of_unity_powers[m + i]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + break; + } + case 2: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = root_of_unity_powers[m + i]; + const uint64_t W_precon = precon_root_of_unity_powers[m + i]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + 
FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + break; + } + case 1: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = root_of_unity_powers[m + i]; + const uint64_t W_precon = precon_root_of_unity_powers[m + i]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + FwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W, W_precon, modulus, + twice_modulus); + } + break; + } + default: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = root_of_unity_powers[m + i]; + const uint64_t W_precon = precon_root_of_unity_powers[m + i]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < t; j += 8) { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + } + } + } + } + t >>= 1; + } + if (output_mod_factor == 1) { + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<4>(result[i], modulus, &twice_modulus); + HEXL_CHECK(result[i] < modulus, "Incorrect modulus reduction in NTT " + << result[i] << " >= " << modulus); + } + } +} + +void ReferenceForwardTransformToBitReverse( + uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(root_of_unity_powers != nullptr, + "root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + + size_t t = (n >> 1); + for (size_t m = 1; m < n; m <<= 1) { + size_t offset = 0; + for (size_t i = 0; i < m; i++) { + size_t offset2 = offset + t; + const uint64_t W = root_of_unity_powers[m + i]; + + uint64_t* X = operand + offset; + uint64_t* Y = X + t; + for (size_t j = offset; j < offset2; j++) { + // X', Y' = X + WY, X - WY (mod q). 
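+        // Worked example (editorial addition): with q = 13, X = 5, Y = 3 and
+        // W = 2, W*Y mod q = 6, so X' = (5 + 6) mod 13 = 11 and
+        // Y' = (5 - 6) mod 13 = 12.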
+ uint64_t tx = *X; + uint64_t W_x_Y = MultiplyMod(*Y, W, modulus); + *X++ = AddUIntMod(tx, W_x_Y, modulus); + *Y++ = SubUIntMod(tx, W_x_Y, modulus); + } + offset += (t << 1); + } + t >>= 1; + } +} + +void ReferenceInverseTransformFromBitReverse( + uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(inv_root_of_unity_powers != nullptr, + "inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + + size_t t = 1; + size_t root_index = 1; + for (size_t m = (n >> 1); m >= 1; m >>= 1) { + size_t offset = 0; + for (size_t i = 0; i < m; i++, root_index++) { + const uint64_t W = inv_root_of_unity_powers[root_index]; + uint64_t* X_r = operand + offset; + uint64_t* Y_r = X_r + t; + for (size_t j = 0; j < t; j++) { + uint64_t X_op = *X_r; + uint64_t Y_op = *Y_r; + // Butterfly X' = (X + Y) mod q, Y' = W(X-Y) mod q + *X_r = AddUIntMod(X_op, Y_op, modulus); + *Y_r = MultiplyMod(W, SubUIntMod(X_op, Y_op, modulus), modulus); + X_r++; + Y_r++; + } + offset += (t << 1); + } + t <<= 1; + } + + // Final multiplication by N^{-1} + const uint64_t inv_n = InverseMod(n, modulus); + for (size_t i = 0; i < n; ++i) { + operand[i] = MultiplyMod(operand[i], inv_n, modulus); + } +} + +void InverseTransformFromBitReverseRadix2( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(inv_root_of_unity_powers != nullptr, + "inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(precon_inv_root_of_unity_powers != nullptr, + "precon_inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + HEXL_CHECK(input_mod_factor == 1 || input_mod_factor == 2, + "input_mod_factor must be 1 or 2; got " << input_mod_factor); + HEXL_UNUSED(input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2; got " << output_mod_factor); + + uint64_t twice_modulus = modulus << 1; + uint64_t n_div_2 = (n >> 1); + size_t t = 1; + size_t root_index = 1; + + for (size_t m = n_div_2; m > 1; m >>= 1) { + size_t offset = 0; + + switch (t) { + case 1: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = inv_root_of_unity_powers[root_index]; + const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = operand + offset; + const uint64_t* Y_op = X_op + t; + InvButterflyRadix2(X_r, Y_r, X_op, Y_op, W, W_precon, modulus, + twice_modulus); + } + break; + } + case 2: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = inv_root_of_unity_powers[root_index]; + const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + break; + } + case 4: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = 
inv_root_of_unity_powers[root_index]; + const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + break; + } + case 8: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = inv_root_of_unity_powers[root_index]; + const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + break; + } + default: { + for (size_t i = 0; i < m; i++, root_index++) { + if (i != 0) { + offset += (t << 1); + } + const uint64_t W = inv_root_of_unity_powers[root_index]; + const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; + + uint64_t* X_r = result + offset; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < t; j += 8) { + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + } + } + } + } + t <<= 1; + } + + // When M is too short it only needs the final stage butterfly. Copying here + // in the case of out-of-place. 
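+  // For n == 2 the stage loop above never executes (m starts at n/2 == 1),
+  // so nothing has been written to result yet; copy the operand before the
+  // final fold stage below.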
+ if (result != operand && n == 2) { + std::memcpy(result, operand, n * sizeof(uint64_t)); + } + + // Fold multiplication by N^{-1} to final stage butterfly + const uint64_t W = inv_root_of_unity_powers[n - 1]; + + const uint64_t inv_n = InverseMod(n, modulus); + uint64_t inv_n_precon = MultiplyFactor(inv_n, 64, modulus).BarrettFactor(); + const uint64_t inv_n_w = MultiplyMod(inv_n, W, modulus); + uint64_t inv_n_w_precon = + MultiplyFactor(inv_n_w, 64, modulus).BarrettFactor(); + + uint64_t* X = result; + uint64_t* Y = X + n_div_2; + for (size_t j = 0; j < n_div_2; ++j) { + // Assume X, Y in [0, 2q) and compute + // X' = N^{-1} (X + Y) (mod q) + // Y' = N^{-1} * W * (X - Y) (mod q) + uint64_t tx = AddUIntMod(X[j], Y[j], twice_modulus); + uint64_t ty = X[j] + twice_modulus - Y[j]; + X[j] = MultiplyModLazy<64>(tx, inv_n, inv_n_precon, modulus); + Y[j] = MultiplyModLazy<64>(ty, inv_n_w, inv_n_w_precon, modulus); + } + + if (output_mod_factor == 1) { + // Reduce from [0, 2q) to [0,q) + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<2>(result[i], modulus); + HEXL_CHECK(result[i] < modulus, "Incorrect modulus reduction in InvNTT" + << result[i] << " >= " << modulus); + } + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/ntt/ntt-radix-4.cpp b/hexl_ser/ntt/ntt-radix-4.cpp new file mode 100644 index 00000000..4f8e9b0d --- /dev/null +++ b/hexl_ser/ntt/ntt-radix-4.cpp @@ -0,0 +1,622 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/check.hpp" +#include "ntt/ntt-default.hpp" +#include "util/cpu-features.hpp" + +namespace intel { +namespace hexl { + +void ForwardTransformToBitReverseRadix4( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* root_of_unity_powers, + const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK_BOUNDS(operand, n, modulus * input_mod_factor, + "operand exceeds bound " << modulus * input_mod_factor); + HEXL_CHECK(root_of_unity_powers != nullptr, + "root_of_unity_powers == nullptr"); + HEXL_CHECK(precon_root_of_unity_powers != nullptr, + "precon_root_of_unity_powers == nullptr"); + HEXL_CHECK( + input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4, + "input_mod_factor must be 1, 2, or 4; got " << input_mod_factor); + HEXL_UNUSED(input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4, + "output_mod_factor must be 1 or 4; got " << output_mod_factor); + + HEXL_VLOG(3, "modulus " << modulus); + HEXL_VLOG(3, "n " << n); + + HEXL_VLOG(3, "operand " << std::vector(operand, operand + n)); + + HEXL_VLOG(3, "root_of_unity_powers " << std::vector( + root_of_unity_powers, root_of_unity_powers + n)); + + bool is_power_of_4 = IsPowerOfFour(n); + + uint64_t twice_modulus = modulus << 1; + uint64_t four_times_modulus = modulus << 2; + + // Radix-2 step for non-powers of 4 + if (!is_power_of_4) { + HEXL_VLOG(3, "Radix 2 step"); + + size_t t = (n >> 1); + + const uint64_t W = root_of_unity_powers[1]; + const uint64_t W_precon = precon_root_of_unity_powers[1]; + + uint64_t* X_r = result; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = operand; + const uint64_t* Y_op = X_op + t; + + //HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < t; j++) { + 
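+      // One radix-2 stage across the two halves of the input; afterwards the
+      // remaining transform length is a power of four and the data lies in
+      // [0, 4q), so the radix-4 stages below can take over.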
FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + // Data in [0, 4q) + HEXL_VLOG(3, "after radix 2 outputs " + << std::vector(result, result + n)); + } + + uint64_t m_start = 2; + size_t t = n >> 3; + if (is_power_of_4) { + t = n >> 2; + + uint64_t* X_r0 = result; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = operand; + const uint64_t* X_op1 = operand + t; + const uint64_t* X_op2 = operand + 2 * t; + const uint64_t* X_op3 = operand + 3 * t; + + uint64_t W1_ind = 1; + uint64_t W2_ind = 2 * W1_ind; + uint64_t W3_ind = 2 * W1_ind + 1; + + const uint64_t W1 = root_of_unity_powers[W1_ind]; + const uint64_t W2 = root_of_unity_powers[W2_ind]; + const uint64_t W3 = root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_root_of_unity_powers[W3_ind]; + + switch (t) { + case 4: { + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + break; + } + case 1: { + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + break; + } + default: { + for (size_t j = 0; j < t; j += 16) { + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + 
FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + } + } + } + t >>= 2; + m_start = 4; + } + + // uint64_t m_start = is_power_of_4 ? 1 : 2; + // size_t t = (n >> m_start) >> 1; + + for (size_t m = m_start; m < n; m <<= 2) { + HEXL_VLOG(3, "m " << m); + + size_t X0_offset = 0; + + switch (t) { + case 4: { + //HEXL_LOOP_UNROLL_8 + for (size_t i = 0; i < m; i++) { + if (i != 0) { + X0_offset += 4 * t; + } + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; + + uint64_t W1_ind = m + i; + uint64_t W2_ind = 2 * W1_ind; + uint64_t W3_ind = 2 * W1_ind + 1; + + const uint64_t W1 = root_of_unity_powers[W1_ind]; + const uint64_t W2 = root_of_unity_powers[W2_ind]; + const uint64_t W3 = root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_root_of_unity_powers[W3_ind]; + + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + } + break; + } + case 1: { + //HEXL_LOOP_UNROLL_8 + for (size_t i = 0; i < m; i++) { + if (i != 0) { + X0_offset += 4 * t; + } + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * 
t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; + + uint64_t W1_ind = m + i; + uint64_t W2_ind = 2 * W1_ind; + uint64_t W3_ind = 2 * W1_ind + 1; + + const uint64_t W1 = root_of_unity_powers[W1_ind]; + const uint64_t W2 = root_of_unity_powers[W2_ind]; + const uint64_t W3 = root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_root_of_unity_powers[W3_ind]; + + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + } + break; + } + default: { + for (size_t i = 0; i < m; i++) { + if (i != 0) { + X0_offset += 4 * t; + } + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; + + uint64_t W1_ind = m + i; + uint64_t W2_ind = 2 * W1_ind; + uint64_t W3_ind = 2 * W1_ind + 1; + + const uint64_t W1 = root_of_unity_powers[W1_ind]; + const uint64_t W2 = root_of_unity_powers[W2_ind]; + const uint64_t W3 = root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_root_of_unity_powers[W3_ind]; + + for (size_t j = 0; j < t; j += 16) { + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, 
X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, + four_times_modulus); + } + } + } + } + t >>= 2; + } + + if (output_mod_factor == 1) { + for (size_t i = 0; i < n; ++i) { + if (result[i] >= twice_modulus) { + result[i] -= twice_modulus; + } + if (result[i] >= modulus) { + result[i] -= modulus; + } + HEXL_CHECK(result[i] < modulus, "Incorrect modulus reduction in NTT " + << result[i] << " >= " << modulus); + } + } + + HEXL_VLOG(3, "outputs " << std::vector(result, result + n)); +} + +void InverseTransformFromBitReverseRadix4( + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, + const uint64_t* inv_root_of_unity_powers, + const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, + uint64_t output_mod_factor) { + HEXL_CHECK(NTT::CheckArguments(n, modulus), ""); + HEXL_CHECK(inv_root_of_unity_powers != nullptr, + "inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(precon_inv_root_of_unity_powers != nullptr, + "precon_inv_root_of_unity_powers == nullptr"); + HEXL_CHECK(operand != nullptr, "operand == nullptr"); + HEXL_CHECK(input_mod_factor == 1 || input_mod_factor == 2, + "input_mod_factor must be 1 or 2; got " << input_mod_factor); + HEXL_UNUSED(input_mod_factor); + HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2, + "output_mod_factor must be 1 or 2; got " << output_mod_factor); + + uint64_t twice_modulus = modulus << 1; + uint64_t n_div_2 = (n >> 1); + + bool is_power_of_4 = IsPowerOfFour(n); + // Radix-2 step for powers of 4 + if (is_power_of_4) { + uint64_t* X_r = result; + uint64_t* Y_r = X_r + 1; + const uint64_t* X_op = operand; + const uint64_t* Y_op = X_op + 1; + const uint64_t* W = inv_root_of_unity_powers + 1; + const uint64_t* W_precon = precon_inv_root_of_unity_powers + 1; + + //HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < n / 2; j++) { + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, *W++, *W_precon++, + modulus, twice_modulus); + X_r++; + Y_r++; + X_op++; + Y_op++; + } + // Data in [0, 2q) + } + + uint64_t m_start = n >> (is_power_of_4 ? 3 : 2); + size_t t = is_power_of_4 ? 2 : 1; + + size_t w1_root_index = 1 + (is_power_of_4 ? n_div_2 : 0); + size_t w3_root_index = n_div_2 + 1 + (is_power_of_4 ? 
(n / 4) : 0); + + HEXL_VLOG(4, "m_start " << m_start); + + for (size_t m = m_start; m > 0; m >>= 2) { + HEXL_VLOG(4, "m " << m); + HEXL_VLOG(4, "t " << t); + + size_t X0_offset = 0; + + switch (t) { + case 1: { + for (size_t i = 0; i < m; i++) { + HEXL_VLOG(4, "i " << i); + if (i != 0) { + X0_offset += 4 * t; + } + + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = operand + X0_offset; + const uint64_t* X_op1 = X_op0 + t; + const uint64_t* X_op2 = X_op0 + 2 * t; + const uint64_t* X_op3 = X_op0 + 3 * t; + + uint64_t W1_ind = w1_root_index++; + uint64_t W2_ind = w1_root_index++; + uint64_t W3_ind = w3_root_index++; + + const uint64_t W1 = inv_root_of_unity_powers[W1_ind]; + const uint64_t W2 = inv_root_of_unity_powers[W2_ind]; + const uint64_t W3 = inv_root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_inv_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_inv_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_inv_root_of_unity_powers[W3_ind]; + + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + } + break; + } + case 4: { + for (size_t i = 0; i < m; i++) { + HEXL_VLOG(4, "i " << i); + if (i != 0) { + X0_offset += 4 * t; + } + + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; + + uint64_t W1_ind = w1_root_index++; + uint64_t W2_ind = w1_root_index++; + uint64_t W3_ind = w3_root_index++; + + const uint64_t W1 = inv_root_of_unity_powers[W1_ind]; + const uint64_t W2 = inv_root_of_unity_powers[W2_ind]; + const uint64_t W3 = inv_root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_inv_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_inv_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_inv_root_of_unity_powers[W3_ind]; + + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + } + break; + } + default: { + //HEXL_LOOP_UNROLL_4 + for (size_t i = 0; i < m; i++) { + HEXL_VLOG(4, "i " << i); + if (i != 0) { + X0_offset += 4 * t; + } + + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; + + uint64_t W1_ind = w1_root_index++; + uint64_t W2_ind = w1_root_index++; + uint64_t W3_ind = w3_root_index++; + + const uint64_t W1 = inv_root_of_unity_powers[W1_ind]; + const uint64_t W2 = inv_root_of_unity_powers[W2_ind]; + const uint64_t W3 = inv_root_of_unity_powers[W3_ind]; + + const uint64_t 
W1_precon = precon_inv_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_inv_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_inv_root_of_unity_powers[W3_ind]; + + for (size_t j = 0; j < t; j++) { + HEXL_VLOG(4, "j " << j); + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus); + } + } + } + } + t <<= 2; + w1_root_index += m; + w3_root_index += m / 2; + } + + // When M is too short it only needs the final stage butterfly. Copying here + // in the case of out-of-place. + if (result != operand && n == 2) { + std::memcpy(result, operand, n * sizeof(uint64_t)); + } + + HEXL_VLOG(4, "Starting final invNTT stage"); + HEXL_VLOG(4, "operand " << std::vector(result, result + n)); + + // Fold multiplication by N^{-1} to final stage butterfly + const uint64_t W = inv_root_of_unity_powers[n - 1]; + HEXL_VLOG(4, "final W " << W); + + const uint64_t inv_n = InverseMod(n, modulus); + uint64_t inv_n_precon = MultiplyFactor(inv_n, 64, modulus).BarrettFactor(); + const uint64_t inv_n_w = MultiplyMod(inv_n, W, modulus); + uint64_t inv_n_w_precon = + MultiplyFactor(inv_n_w, 64, modulus).BarrettFactor(); + + uint64_t* X = result; + uint64_t* Y = X + n_div_2; + for (size_t j = 0; j < n_div_2; ++j) { + // Assume X, Y in [0, 2q) and compute + // X' = N^{-1} (X + Y) (mod q) + // Y' = N^{-1} * W * (X - Y) (mod q) + // with X', Y' in [0, 2q) + uint64_t tx = AddUIntMod(X[j], Y[j], twice_modulus); + uint64_t ty = X[j] + twice_modulus - Y[j]; + X[j] = MultiplyModLazy<64>(tx, inv_n, inv_n_precon, modulus); + Y[j] = MultiplyModLazy<64>(ty, inv_n_w, inv_n_w_precon, modulus); + } + + if (output_mod_factor == 1) { + // Reduce from [0, 2q) to [0,q) + for (size_t i = 0; i < n; ++i) { + result[i] = ReduceMod<2>(result[i], modulus); + HEXL_CHECK(result[i] < modulus, "Incorrect modulus reduction in InvNTT" + << result[i] << " >= " << modulus); + } + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/number-theory/number-theory.cpp b/hexl_ser/number-theory/number-theory.cpp new file mode 100644 index 00000000..d8d6f079 --- /dev/null +++ b/hexl_ser/number-theory/number-theory.cpp @@ -0,0 +1,264 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include "hexl/number-theory/number-theory.hpp" + +#include "hexl/logging/logging.hpp" +#include "hexl/util/check.hpp" +#include "util/util-internal.hpp" + +namespace intel { +namespace hexl { + +uint64_t InverseMod(uint64_t input, uint64_t modulus) { + uint64_t a = input % modulus; + HEXL_CHECK(a != 0, input << " does not have a InverseMod"); + + if (modulus == 1) { + return 0; + } + + int64_t m0 = static_cast(modulus); + int64_t y = 0; + int64_t x = 1; + while (a > 1) { + // q is quotient + int64_t q = static_cast(a / modulus); + + int64_t t = static_cast(modulus); + modulus = a % modulus; + a = static_cast(t); + + // Update y and x + t = y; + y = x - q * y; + x = t; + } + + // Make x positive + if (x < 0) x += m0; + + return uint64_t(x); +} + +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0"); + HEXL_CHECK(x < modulus, "x " << x << " >= modulus " << modulus); + HEXL_CHECK(y < modulus, "y " << y << " >= modulus " << modulus); + uint64_t prod_hi, prod_lo; + MultiplyUInt64(x, y, &prod_hi, &prod_lo); + + return BarrettReduce128(prod_hi, prod_lo, modulus); +} + +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon, + uint64_t 
modulus) { + uint64_t q = MultiplyUInt64Hi<64>(x, y_precon); + q = x * y - q * modulus; + return q >= modulus ? q - modulus : q; +} + +uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus) { + HEXL_CHECK(x < modulus, "x " << x << " >= modulus " << modulus); + HEXL_CHECK(y < modulus, "y " << y << " >= modulus " << modulus); + uint64_t sum = x + y; + return (sum >= modulus) ? (sum - modulus) : sum; +} + +uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus) { + HEXL_CHECK(x < modulus, "x " << x << " >= modulus " << modulus); + HEXL_CHECK(y < modulus, "y " << y << " >= modulus " << modulus); + uint64_t diff = (x + modulus) - y; + return (diff >= modulus) ? (diff - modulus) : diff; +} + +// Returns base^exp mod modulus +uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus) { + base %= modulus; + uint64_t result = 1; + while (exp > 0) { + if (exp & 1) { + result = MultiplyMod(result, base, modulus); + } + base = MultiplyMod(base, base, modulus); + exp >>= 1; + } + return result; +} + +// Returns true whether root is a degree-th root of unity +// degree must be a power of two. +bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus) { + if (root == 0) { + return false; + } + HEXL_CHECK(IsPowerOfTwo(degree), degree << " not a power of 2"); + + HEXL_VLOG(4, "IsPrimitiveRoot root " << root << ", degree " << degree + << ", modulus " << modulus); + + // Check if root^(degree/2) == -1 mod modulus + return PowMod(root, degree / 2, modulus) == (modulus - 1); +} + +// Tries to return a primitive degree-th root of unity +// throw error if no root is found +uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus) { + // We need to divide modulus-1 by degree to get the size of the quotient group + uint64_t size_entire_group = modulus - 1; + + // Compute size of quotient group + uint64_t size_quotient_group = size_entire_group / degree; + + for (int trial = 0; trial < 200; ++trial) { + uint64_t root = GenerateInsecureUniformIntRandomValue(0, modulus); + root = PowMod(root, size_quotient_group, modulus); + + if (IsPrimitiveRoot(root, degree, modulus)) { + return root; + } + } + HEXL_CHECK(false, "no primitive root found for degree " + << degree << " modulus " << modulus); + return 0; +} + +// Returns true whether root is a degree-th root of unity +// degree must be a power of two. +uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus) { + HEXL_CHECK(IsPowerOfTwo(degree), + "Degere " << degree << " is not a power of 2"); + + uint64_t root = GeneratePrimitiveRoot(degree, modulus); + + uint64_t generator_sq = MultiplyMod(root, root, modulus); + uint64_t current_generator = root; + + uint64_t min_root = root; + + // Check if root^(degree/2) == -1 mod modulus + for (size_t i = 0; i < degree; ++i) { + if (current_generator < min_root) { + min_root = current_generator; + } + current_generator = MultiplyMod(current_generator, generator_sq, modulus); + } + + return min_root; +} + +uint64_t ReverseBits(uint64_t x, uint64_t bit_width) { + HEXL_CHECK(x == 0 || MSB(x) <= bit_width, "MSB(" << x << ") = " << MSB(x) + << " must be >= bit_width " + << bit_width) + if (bit_width == 0) { + return 0; + } + uint64_t rev = 0; + for (uint64_t i = bit_width; i > 0; i--) { + rev |= ((x & 1) << (i - 1)); + x >>= 1; + } + return rev; +} + +// Miller-Rabin primality test +bool IsPrime(uint64_t n) { + // n < 2^64, so it is enough to test a=2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, + // and 37. 
See + // https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test#Testing_against_small_sets_of_bases + static const std::vector as{2, 3, 5, 7, 11, 13, + 17, 19, 23, 29, 31, 37}; + + for (const uint64_t a : as) { + if (n == a) return true; + if (n % a == 0) return false; + } + + // Write n == 2**r * d + 1 with d odd. + uint64_t r = 63; + while (r > 0) { + uint64_t two_pow_r = (1ULL << r); + if ((n - 1) % two_pow_r == 0) { + break; + } + --r; + } + HEXL_CHECK(r != 0, "Error factoring n " << n); + uint64_t d = (n - 1) / (1ULL << r); + + HEXL_CHECK(n == (1ULL << r) * d + 1, "Error factoring n " << n); + HEXL_CHECK(d % 2 == 1, "d is even"); + + for (const uint64_t a : as) { + uint64_t x = PowMod(a, d, n); + if ((x == 1) || (x == n - 1)) { + continue; + } + + bool prime = false; + for (uint64_t i = 1; i < r; ++i) { + x = PowMod(x, 2, n); + if (x == n - 1) { + prime = true; + break; + } + } + if (!prime) { + return false; + } + } + return true; +} + +std::vector GeneratePrimes(size_t num_primes, size_t bit_size, + bool prefer_small_primes, + size_t ntt_size) { + HEXL_CHECK(num_primes > 0, "num_primes == 0"); + HEXL_CHECK(IsPowerOfTwo(ntt_size), + "ntt_size " << ntt_size << " is not a power of two"); + HEXL_CHECK(Log2(ntt_size) < bit_size, + "log2(ntt_size) " << Log2(ntt_size) + << " should be less than bit_size " << bit_size); + + int64_t prime_lower_bound = (1LL << bit_size) + 1LL; + int64_t prime_upper_bound = (1LL << (bit_size + 1LL)) - 1LL; + + // Keep signed to enable negative step + int64_t prime_candidate = + prefer_small_primes + ? prime_lower_bound + : prime_upper_bound - (prime_upper_bound % (2 * ntt_size)) + 1; + HEXL_CHECK(prime_candidate % (2 * ntt_size) == 1, "bad prime candidate"); + + // Ensure prime % 2 * ntt_size == 1 + int64_t prime_candidate_step = + (prefer_small_primes ? 
1 : -1) * 2 * static_cast(ntt_size); + + auto continue_condition = [&](int64_t local_candidate_prime) { + if (prefer_small_primes) { + return local_candidate_prime < prime_upper_bound; + } else { + return local_candidate_prime > prime_lower_bound; + } + }; + + std::vector ret; + + while (continue_condition(prime_candidate)) { + if (IsPrime(prime_candidate)) { + HEXL_CHECK(prime_candidate % (2 * ntt_size) == 1, "bad prime candidate"); + ret.emplace_back(static_cast(prime_candidate)); + if (ret.size() == num_primes) { + return ret; + } + } + prime_candidate += prime_candidate_step; + } + + HEXL_CHECK(false, "Failed to find enough primes"); + return ret; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/util/avx512-util.hpp b/hexl_ser/util/avx512-util.hpp new file mode 100644 index 00000000..c5e206a9 --- /dev/null +++ b/hexl_ser/util/avx512-util.hpp @@ -0,0 +1,523 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include + +#include "hexl/logging/logging.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/defines.hpp" +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief Returns the unsigned 64-bit integer values in x as a vector +inline std::vector ExtractValues(__m512i x) { + __m256i x0 = _mm512_extracti64x4_epi64(x, 0); + __m256i x1 = _mm512_extracti64x4_epi64(x, 1); + + std::vector xs{static_cast(_mm256_extract_epi64(x0, 0)), + static_cast(_mm256_extract_epi64(x0, 1)), + static_cast(_mm256_extract_epi64(x0, 2)), + static_cast(_mm256_extract_epi64(x0, 3)), + static_cast(_mm256_extract_epi64(x1, 0)), + static_cast(_mm256_extract_epi64(x1, 1)), + static_cast(_mm256_extract_epi64(x1, 2)), + static_cast(_mm256_extract_epi64(x1, 3))}; + + return xs; +} + +/// @brief Returns the signed 64-bit integer values in x as a vector +inline std::vector ExtractIntValues(__m512i x) { + __m256i x0 = _mm512_extracti64x4_epi64(x, 0); + __m256i x1 = _mm512_extracti64x4_epi64(x, 1); + + std::vector xs{static_cast(_mm256_extract_epi64(x0, 0)), + static_cast(_mm256_extract_epi64(x0, 1)), + static_cast(_mm256_extract_epi64(x0, 2)), + static_cast(_mm256_extract_epi64(x0, 3)), + static_cast(_mm256_extract_epi64(x1, 0)), + static_cast(_mm256_extract_epi64(x1, 1)), + static_cast(_mm256_extract_epi64(x1, 2)), + static_cast(_mm256_extract_epi64(x1, 3))}; + + return xs; +} + +// Returns the 64-bit floating-point values in x as a vector +inline std::vector ExtractValues(__m512d x) { + const double* x_ptr = reinterpret_cast(&x); + return std::vector{x_ptr, x_ptr + 8}; +} + +// Returns lower NumBits bits from a 64-bit value +template +inline __m512i ClearTopBits64(__m512i x) { + const __m512i low52b_mask = _mm512_set1_epi64((1ULL << NumBits) - 1); + return _mm512_and_epi64(x, low52b_mask); +} + +// Multiply packed unsigned BitShift-bit integers in each 64-bit element of x +// and y to form a 2*BitShift-bit intermediate result. 
+// Returns the high BitShift-bit unsigned integer from the intermediate result +template +inline __m512i _mm512_hexl_mulhi_epi(__m512i x, __m512i y); + +// Dummy implementation to avoid template substitution errors +template <> +inline __m512i _mm512_hexl_mulhi_epi<32>(__m512i x, __m512i y) { + HEXL_CHECK(false, "Unimplemented"); + HEXL_UNUSED(x); + HEXL_UNUSED(y); + return x; +} + +template <> +inline __m512i _mm512_hexl_mulhi_epi<64>(__m512i x, __m512i y) { + // https://stackoverflow.com/questions/28807341/simd-signed-with-unsigned-multiplication-for-64-bit-64-bit-to-128-bit + __m512i lo_mask = _mm512_set1_epi64(0x00000000ffffffff); + // Shuffle high bits with low bits in each 64-bit integer => + // x0_lo, x0_hi, x1_lo, x1_hi, x2_lo, x2_hi, ... + __m512i x_hi = _mm512_shuffle_epi32(x, (_MM_PERM_ENUM)0xB1); + // y0_lo, y0_hi, y1_lo, y1_hi, y2_lo, y2_hi, ... + __m512i y_hi = _mm512_shuffle_epi32(y, (_MM_PERM_ENUM)0xB1); + __m512i z_lo_lo = _mm512_mul_epu32(x, y); // x_lo * y_lo + __m512i z_lo_hi = _mm512_mul_epu32(x, y_hi); // x_lo * y_hi + __m512i z_hi_lo = _mm512_mul_epu32(x_hi, y); // x_hi * y_lo + __m512i z_hi_hi = _mm512_mul_epu32(x_hi, y_hi); // x_hi * y_hi + + // x_hi | x_lo + // x y_hi | y_lo + // ------------------------------ + // [x_lo * y_lo] // z_lo_lo + // + [z_lo * y_hi] // z_lo_hi + // + [x_hi * y_lo] // z_hi_lo + // + [x_hi * y_hi] // z_hi_hi + // ^-----------^ <-- only bits needed + // sum_| hi | mid | lo | + + // Low bits of z_lo_lo are not needed + __m512i z_lo_lo_shift = _mm512_srli_epi64(z_lo_lo, 32); + + // [x_lo * y_lo] // z_lo_lo + // + [z_lo * y_hi] // z_lo_hi + // ------------------------ + // | sum_tmp | + // |sum_mid|sum_lo| + __m512i sum_tmp = _mm512_add_epi64(z_lo_hi, z_lo_lo_shift); + __m512i sum_lo = _mm512_and_si512(sum_tmp, lo_mask); + __m512i sum_mid = _mm512_srli_epi64(sum_tmp, 32); + // | |sum_lo| + // + [x_hi * y_lo] // z_hi_lo + // ------------------ + // [ sum_mid2 ] + __m512i sum_mid2 = _mm512_add_epi64(z_hi_lo, sum_lo); + __m512i sum_mid2_hi = _mm512_srli_epi64(sum_mid2, 32); + __m512i sum_hi = _mm512_add_epi64(z_hi_hi, sum_mid); + return _mm512_add_epi64(sum_hi, sum_mid2_hi); +} + +#ifdef HEXL_HAS_AVX512IFMA +template <> +inline __m512i _mm512_hexl_mulhi_epi<52>(__m512i x, __m512i y) { + __m512i zero = _mm512_set1_epi64(0); + return _mm512_madd52hi_epu64(zero, x, y); +} +#endif + +// Multiply packed unsigned BitShift-bit integers in each 64-bit element of x +// and y to form a 2*BitShift-bit intermediate result. +// Returns the high BitShift-bit unsigned integer from the intermediate result, +// with approximation error at most 1 +template +inline __m512i _mm512_hexl_mulhi_approx_epi(__m512i x, __m512i y); + +// Dummy implementation to avoid template substitution errors +template <> +inline __m512i _mm512_hexl_mulhi_approx_epi<32>(__m512i x, __m512i y) { + HEXL_CHECK(false, "Unimplemented"); + HEXL_UNUSED(x); + HEXL_UNUSED(y); + return x; +} + +template <> +inline __m512i _mm512_hexl_mulhi_approx_epi<64>(__m512i x, __m512i y) { + // https://stackoverflow.com/questions/28807341/simd-signed-with-unsigned-multiplication-for-64-bit-64-bit-to-128-bit + __m512i lo_mask = _mm512_set1_epi64(0x00000000ffffffff); + // Shuffle high bits with low bits in each 64-bit integer => + // x0_lo, x0_hi, x1_lo, x1_hi, x2_lo, x2_hi, ... + __m512i x_hi = _mm512_shuffle_epi32(x, (_MM_PERM_ENUM)0xB1); + // y0_lo, y0_hi, y1_lo, y1_hi, y2_lo, y2_hi, ... 
+ __m512i y_hi = _mm512_shuffle_epi32(y, (_MM_PERM_ENUM)0xB1); + __m512i z_lo_hi = _mm512_mul_epu32(x, y_hi); // x_lo * y_hi + __m512i z_hi_lo = _mm512_mul_epu32(x_hi, y); // x_hi * y_lo + __m512i z_hi_hi = _mm512_mul_epu32(x_hi, y_hi); // x_hi * y_hi + + // x_hi | x_lo + // x y_hi | y_lo + // ------------------------------ + // [x_lo * y_lo] // unused, resulting in approximation + // + [z_lo * y_hi] // z_lo_hi + // + [x_hi * y_lo] // z_hi_lo + // + [x_hi * y_hi] // z_hi_hi + // ^-----------^ <-- only bits needed + // sum_| hi | mid | lo | + + __m512i sum_lo = _mm512_and_si512(z_lo_hi, lo_mask); + __m512i sum_mid = _mm512_srli_epi64(z_lo_hi, 32); + // | |sum_lo| + // + [x_hi * y_lo] // z_hi_lo + // ------------------ + // [ sum_mid2 ] + __m512i sum_mid2 = _mm512_add_epi64(z_hi_lo, sum_lo); + __m512i sum_mid2_hi = _mm512_srli_epi64(sum_mid2, 32); + __m512i sum_hi = _mm512_add_epi64(z_hi_hi, sum_mid); + return _mm512_add_epi64(sum_hi, sum_mid2_hi); +} + +#ifdef HEXL_HAS_AVX512IFMA +template <> +inline __m512i _mm512_hexl_mulhi_approx_epi<52>(__m512i x, __m512i y) { + __m512i zero = _mm512_set1_epi64(0); + return _mm512_madd52hi_epu64(zero, x, y); +} +#endif + +// Multiply packed unsigned BitShift-bit integers in each 64-bit element of x +// and y to form a 2*BitShift-bit intermediate result. +// Returns the low BitShift-bit unsigned integer from the intermediate result +template +inline __m512i _mm512_hexl_mullo_epi(__m512i x, __m512i y); + +// Dummy implementation to avoid template substitution errors +template <> +inline __m512i _mm512_hexl_mullo_epi<32>(__m512i x, __m512i y) { + HEXL_CHECK(false, "Unimplemented"); + HEXL_UNUSED(x); + HEXL_UNUSED(y); + return x; +} + +template <> +inline __m512i _mm512_hexl_mullo_epi<64>(__m512i x, __m512i y) { + return _mm512_mullo_epi64(x, y); +} + +#ifdef HEXL_HAS_AVX512IFMA +template <> +inline __m512i _mm512_hexl_mullo_epi<52>(__m512i x, __m512i y) { + __m512i zero = _mm512_set1_epi64(0); + return _mm512_madd52lo_epu64(zero, x, y); +} +#endif + +// Multiply packed unsigned BitShift-bit integers in each 64-bit element of y +// and z to form a 2*BitShift-bit intermediate result. The low BitShift bits of +// the result are added to x, then the low BitShift bits of the result are +// returned. 
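+// Equivalently, for the implemented bit widths each 64-bit lane computes
+// (x + y * z) mod 2^BitShift.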
+template +inline __m512i _mm512_hexl_mullo_add_lo_epi(__m512i x, __m512i y, __m512i z); + +#ifdef HEXL_HAS_AVX512IFMA +template <> +inline __m512i _mm512_hexl_mullo_add_lo_epi<52>(__m512i x, __m512i y, + __m512i z) { + __m512i result = _mm512_madd52lo_epu64(x, y, z); + + // Clear high 12 bits from result + result = ClearTopBits64<52>(result); + return result; +} +#endif + +// Dummy implementation to avoid template substitution errors +template <> +inline __m512i _mm512_hexl_mullo_add_lo_epi<32>(__m512i x, __m512i y, + __m512i z) { + HEXL_CHECK(false, "Unimplemented"); + HEXL_UNUSED(x); + HEXL_UNUSED(y); + HEXL_UNUSED(z); + return x; +} + +template <> +inline __m512i _mm512_hexl_mullo_add_lo_epi<64>(__m512i x, __m512i y, + __m512i z) { + __m512i prod = _mm512_mullo_epi64(y, z); + return _mm512_add_epi64(x, prod); +} + +// Returns x mod q across each 64-bit integer SIMD lanes +// Assumes x < InputModFactor * q in all lanes +template +inline __m512i _mm512_hexl_small_mod_epu64(__m512i x, __m512i q, + __m512i* q_times_2 = nullptr, + __m512i* q_times_4 = nullptr) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || + InputModFactor == 4 || InputModFactor == 8, + "InputModFactor must be 1, 2, 4, or 8"); + if (InputModFactor == 1) { + return x; + } + if (InputModFactor == 2) { + return _mm512_min_epu64(x, _mm512_sub_epi64(x, q)); + } + if (InputModFactor == 4) { + HEXL_CHECK(q_times_2 != nullptr, "q_times_2 must not be nullptr"); + x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_2)); + return _mm512_min_epu64(x, _mm512_sub_epi64(x, q)); + } + if (InputModFactor == 8) { + HEXL_CHECK(q_times_2 != nullptr, "q_times_2 must not be nullptr"); + HEXL_CHECK(q_times_4 != nullptr, "q_times_4 must not be nullptr"); + x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_4)); + x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_2)); + return _mm512_min_epu64(x, _mm512_sub_epi64(x, q)); + } + HEXL_CHECK(false, "Invalid InputModFactor"); + return x; // Return dummy value +} + +// Returns (x + y) mod q; assumes 0 < x, y < q +inline __m512i _mm512_hexl_small_add_mod_epi64(__m512i x, __m512i y, + __m512i q) { + HEXL_CHECK_BOUNDS(ExtractValues(x).data(), 8, ExtractValues(q)[0], + "x exceeds bound " << ExtractValues(q)[0]); + HEXL_CHECK_BOUNDS(ExtractValues(y).data(), 8, ExtractValues(q)[0], + "y exceeds bound " << ExtractValues(q)[0]); + return _mm512_hexl_small_mod_epu64(_mm512_add_epi64(x, y), q); + + // Alternate implementation: + // x += y - q; + // if (x < 0) x+= q + // return x + // __m512i v_diff = _mm512_sub_epi64(y, q); + // x = _mm512_add_epi64(x, v_diff); + // __mmask8 sign_bits = _mm512_movepi64_mask(x); + // return _mm512_mask_add_epi64(x, sign_bits, x, q); +} + +// Returns (x - y) mod q; assumes 0 < x, y < q + +inline __m512i _mm512_hexl_small_sub_mod_epi64(__m512i x, __m512i y, + __m512i q) { + HEXL_CHECK_BOUNDS(ExtractValues(x).data(), 8, ExtractValues(q)[0], + "x exceeds bound " << ExtractValues(q)[0]); + HEXL_CHECK_BOUNDS(ExtractValues(y).data(), 8, ExtractValues(q)[0], + "y exceeds bound " << ExtractValues(q)[0]); + + // diff = x - y; + // return (diff < 0) ? 
(diff + q) : diff + __m512i v_diff = _mm512_sub_epi64(x, y); + __mmask8 sign_bits = _mm512_movepi64_mask(v_diff); + return _mm512_mask_add_epi64(v_diff, sign_bits, v_diff, q); +} + +inline __mmask8 _mm512_hexl_cmp_epu64_mask(__m512i a, __m512i b, CMPINT cmp) { + switch (cmp) { + case CMPINT::EQ: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::EQ)); + case CMPINT::LT: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::LT)); + case CMPINT::LE: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::LE)); + case CMPINT::FALSE: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::FALSE)); + case CMPINT::NE: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::NE)); + case CMPINT::NLT: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::NLT)); + case CMPINT::NLE: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::NLE)); + case CMPINT::TRUE: + return _mm512_cmp_epu64_mask(a, b, static_cast(CMPINT::TRUE)); + } + __mmask8 dummy = 0; // Avoid end of non-void function warning + return dummy; +} + +// Returns c[i] = a[i] CMP b[i] ? match_value : 0 +inline __m512i _mm512_hexl_cmp_epi64(__m512i a, __m512i b, CMPINT cmp, + uint64_t match_value) { + __mmask8 mask = _mm512_hexl_cmp_epu64_mask(a, b, cmp); + return _mm512_maskz_broadcastq_epi64( + mask, _mm_set1_epi64x(static_cast(match_value))); +} + +// Returns c[i] = a[i] >= b[i] ? match_value : 0 +inline __m512i _mm512_hexl_cmpge_epu64(__m512i a, __m512i b, + uint64_t match_value) { + return _mm512_hexl_cmp_epi64(a, b, CMPINT::NLT, match_value); +} + +// Returns c[i] = a[i] < b[i] ? match_value : 0 +inline __m512i _mm512_hexl_cmplt_epu64(__m512i a, __m512i b, + uint64_t match_value) { + return _mm512_hexl_cmp_epi64(a, b, CMPINT::LT, match_value); +} + +// Returns c[i] = a[i] <= b[i] ? match_value : 0 +inline __m512i _mm512_hexl_cmple_epu64(__m512i a, __m512i b, + uint64_t match_value) { + return _mm512_hexl_cmp_epi64(a, b, CMPINT::LE, match_value); +} + +// Returns Montgomery form of ab mod q, computed via the REDC algorithm, +// also known as Montgomery reduction. +// Template: r with R = 2^r +// Inputs: q such that gcd(R, q) = 1. R > q. +// v_inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R, +// T = ab in the range [0, Rq − 1]. +// T_hi and T_lo for BitShift = 64 should be given in 63 bits. +// Output: Integer S in the range [0, q − 1] such that S ≡ TR^−1 mod q +template +inline __m512i _mm512_hexl_montgomery_reduce(__m512i T_hi, __m512i T_lo, + __m512i q, __m512i v_inv_mod, + __m512i v_rs_or_msk) { + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid bitshift " << BitShift << "; need 52 or 64"); + +#ifdef HEXL_HAS_AVX512IFMA + if (BitShift == 52) { + // Operation: + // m ← ((T mod R)N′) mod R | m ← ((T & mod_R_mask)*v_inv_mod) & mod_R_mask + __m512i m = _mm512_hexl_mullo_epi(T_lo, v_inv_mod); + m = ClearTopBits64(m); + + // Operation: t ← (T + mN) / R = (T + m*q) >> r + // Hi part + __m512i t_hi = _mm512_madd52hi_epu64(T_hi, m, q); + // Low part + __m512i t = _mm512_madd52lo_epu64(T_lo, m, q); + t = _mm512_srli_epi64(t, r); + // Join parts + t = _mm512_madd52lo_epu64(t, t_hi, v_rs_or_msk); + + // If this function exists for 52 bits we could save 1 cycle + // t = _mm512_shrdi_epi64 (t_hi, t, r) + + // Operation: t ≥ q? 
return (t - q) : return t + return _mm512_hexl_small_mod_epu64<2>(t, q); + } +#endif + + HEXL_CHECK(BitShift == 64, "Invalid bitshift " << BitShift << "; need 64"); + + // Operation: + // m ← ((T mod R)N′) mod R | m ← ((T & mod_R_mask)*v_inv_mod) & mod_R_mask + __m512i m = ClearTopBits64(T_lo); + m = _mm512_hexl_mullo_epi(m, v_inv_mod); + m = ClearTopBits64(m); + + __m512i mq_hi = _mm512_hexl_mulhi_epi(m, q); + __m512i mq_lo = _mm512_hexl_mullo_epi(m, q); + + // to 63 bits + mq_hi = _mm512_slli_epi64(mq_hi, 1); + __m512i tmp = _mm512_srli_epi64(mq_lo, 63); + mq_hi = _mm512_add_epi64(mq_hi, tmp); + mq_lo = _mm512_and_epi64(mq_lo, v_rs_or_msk); + + __m512i t_hi = _mm512_add_epi64(T_hi, mq_hi); + t_hi = _mm512_slli_epi64(t_hi, 63 - r); + __m512i t = _mm512_add_epi64(T_lo, mq_lo); + t = _mm512_srli_epi64(t, r); + + // Join parts + t = _mm512_add_epi64(t_hi, t); + + return _mm512_hexl_small_mod_epu64<2>(t, q); +} + +// Returns x mod q, computed via Barrett reduction +// @param q_barr floor(2^BitShift / q) +template +inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q, + __m512i q_barr_64, + __m512i q_barr_52, + uint64_t prod_right_shift, + __m512i v_neg_mod) { + HEXL_UNUSED(q_barr_52); + HEXL_UNUSED(prod_right_shift); + HEXL_UNUSED(v_neg_mod); + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid bitshift " << BitShift << "; need 52 or 64"); + +#ifdef HEXL_HAS_AVX512IFMA + if (BitShift == 52) { + __m512i two_pow_fiftytwo = _mm512_set1_epi64(2251799813685248); + __mmask8 mask = + _mm512_hexl_cmp_epu64_mask(x, two_pow_fiftytwo, CMPINT::NLT); + if (mask != 0) { + // values above 2^52 + __m512i x_hi = _mm512_srli_epi64(x, static_cast(52ULL)); + __m512i x_lo = ClearTopBits64<52>(x); + + // c1 = floor(U / 2^{n + beta}) + __m512i c1_lo = + _mm512_srli_epi64(x_lo, static_cast(prod_right_shift)); + __m512i c1_hi = _mm512_slli_epi64( + x_hi, static_cast(52ULL - (prod_right_shift))); + __m512i c1 = _mm512_or_epi64(c1_lo, c1_hi); + + // alpha - beta == 52, so we only need high 52 bits + __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, q_barr_64); + // Z = prod_lo - (p * q_hat)_lo + x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod); + } else { + __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr_52); + __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q); + x = _mm512_sub_epi64(x, tmp1_times_mod); + } + } +#endif + if (BitShift == 64) { + __m512i rnd1_hi = _mm512_hexl_mulhi_epi<64>(x, q_barr_64); + __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<64>(rnd1_hi, q); + x = _mm512_sub_epi64(x, tmp1_times_mod); + } + + // Correction + if (OutputModFactor == 1) { + x = _mm512_hexl_small_mod_epu64<2>(x, q); + } + return x; +} + +// Concatenate packed 64-bit integers in x and y, producing an intermediate +// 128-bit result. Shift the result right by bit_shift bits, and return the +// lower 64 bits. The bit_shift is a run-time argument, rather than a +// compile-time template parameter, so we can't use _mm512_shrdi_epi64 +inline __m512i _mm512_hexl_shrdi_epi64(__m512i x, __m512i y, + unsigned int bit_shift) { + __m512i c_lo = _mm512_srli_epi64(x, bit_shift); + __m512i c_hi = _mm512_slli_epi64(y, 64 - bit_shift); + return _mm512_add_epi64(c_lo, c_hi); +} + +// Concatenate packed 64-bit integers in x and y, producing an intermediate +// 128-bit result. Shift the result right by BitShift bits, and return the lower +// 64 bits. 
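+// When AVX512-VBMI2 is available this lowers to a single funnel-shift
+// instruction (_mm512_shrdi_epi64); otherwise it falls back to the
+// run-time-shift helper defined above.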
+template +inline __m512i _mm512_hexl_shrdi_epi64(__m512i x, __m512i y) { +#ifdef HEXL_HAS_AVX512VBMI2 + return _mm512_shrdi_epi64(x, y, BitShift); +#else + return _mm512_hexl_shrdi_epi64(x, y, BitShift); +#endif +} + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/util/cpu-features.hpp b/hexl_ser/util/cpu-features.hpp new file mode 100644 index 00000000..ba408394 --- /dev/null +++ b/hexl_ser/util/cpu-features.hpp @@ -0,0 +1,35 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include + +#include "cpuinfo_x86.h" // NOLINT(build/include_subdir) + +namespace intel { +namespace hexl { + +// Use to disable avx512 dispatching at runtime +static const bool disable_avx512dq = + (std::getenv("HEXL_DISABLE_AVX512DQ") != nullptr); +static const bool disable_avx512ifma = + disable_avx512dq || (std::getenv("HEXL_DISABLE_AVX512IFMA") != nullptr); +static const bool disable_avx512vbmi2 = + disable_avx512dq || (std::getenv("HEXL_DISABLE_AVX512VBMI2") != nullptr); + +static const cpu_features::X86Features features = + cpu_features::GetX86Info().features; + +static const bool has_avx512dq = features.avx512f && features.avx512dq && + features.avx512vl && !disable_avx512dq; + +static const bool has_avx512ifma = features.avx512ifma && !disable_avx512ifma; + +static const bool has_avx512vbmi2 = + features.avx512vbmi2 && !disable_avx512vbmi2; + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser/util/util-internal.hpp b/hexl_ser/util/util-internal.hpp new file mode 100644 index 00000000..20e94da9 --- /dev/null +++ b/hexl_ser/util/util-internal.hpp @@ -0,0 +1,102 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +inline bool Compare(CMPINT cmp, uint64_t lhs, uint64_t rhs) { + switch (cmp) { + case CMPINT::EQ: + return lhs == rhs; + case CMPINT::LT: + return lhs < rhs; + break; + case CMPINT::LE: + return lhs <= rhs; + break; + case CMPINT::FALSE: + return false; + break; + case CMPINT::NE: + return lhs != rhs; + break; + case CMPINT::NLT: + return lhs >= rhs; + break; + case CMPINT::NLE: + return lhs > rhs; + case CMPINT::TRUE: + return true; + default: + return true; + } +} + +/// Generates a vector of size random values drawn uniformly from [min_value, +/// max_value) +/// NOTE: this function is not a cryptographically secure random number +/// generator and should be used for testing/benchmarking only +inline double GenerateInsecureUniformRealRandomValue(double min_value, + double max_value) { + HEXL_CHECK(min_value < max_value, "min_value must be > max_value"); + + static std::random_device rd; + static std::mt19937 mersenne_engine(rd()); + std::uniform_real_distribution distrib(min_value, max_value); + double res = distrib(mersenne_engine); + return (res == max_value) ? 
min_value : res; +} + +/// Generates a vector of size random values drawn uniformly from [min_value, +/// max_value) +/// NOTE: this function is not a cryptographically secure random number +/// generator and should be used for testing/benchmarking only +inline uint64_t GenerateInsecureUniformIntRandomValue(uint64_t min_value, + uint64_t max_value) { + HEXL_CHECK(min_value < max_value, "min_value must be > max_value"); + + static std::random_device rd; + static std::mt19937 mersenne_engine(rd()); + std::uniform_int_distribution distrib(min_value, max_value - 1); + return distrib(mersenne_engine); +} + +/// Generates a vector of size random values drawn uniformly from [min_value, +/// max_value) +/// NOTE: this function is not a cryptographically secure random +/// number generator and should be used for testing/benchmarking only +inline AlignedVector64 GenerateInsecureUniformRealRandomValues( + uint64_t size, double min_value, double max_value) { + AlignedVector64 values(size); + auto generator = [&]() { + return GenerateInsecureUniformRealRandomValue(min_value, max_value); + }; + std::generate(values.begin(), values.end(), generator); + return values; +} + +/// Generates a vector of size random values drawn uniformly from [min_value, +/// max_value) +/// NOTE: this function is not a cryptographically secure random +/// number generator and should be used for testing/benchmarking only +inline AlignedVector64 GenerateInsecureUniformIntRandomValues( + uint64_t size, uint64_t min_value, uint64_t max_value) { + AlignedVector64 values(size); + auto generator = [&]() { + return GenerateInsecureUniformIntRandomValue(min_value, max_value); + }; + std::generate(values.begin(), values.end(), generator); + return values; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_ser_out_0824_1431.csv b/hexl_ser_out_0824_1431.csv new file mode 100644 index 00000000..3c5ec97e --- /dev/null +++ b/hexl_ser_out_0824_1431.csv @@ -0,0 +1,10 @@ +Method Input_size=4096 Input_size=65536 Input_size=1048576 Input_size=16777216 Input_size=268435456 +BM_EltwiseCmpAdd 0.0015148 0.0183375 0.24938 10.5679 109.842 +BM_EltwiseCmpSubMod 0.0311302 0.342652 3.65048 44.5498 783.263 +BM_EltwiseFMAModAdd 0.0254424 0.381004 3.13449 52.2436 900.251 +BM_EltwiseMultMod 0.0198413 0.278771 2.21727 56.6321 672.338 +BM_EltwiseReduceModInPlace 0.0202927 0.305697 2.73887 41.9916 664.917 +BM_EltwiseVectorScalarAddMod 0.0016041 0.0354446 0.495221 16.3091 174.137 +BM_EltwiseVectorVectorAddMod 0.0022032 0.0516952 0.801214 19.9865 227.267 +BM_EltwiseVectorVectorSubMod 0.0021632 0.0509133 0.851897 20.7146 215.027 +BM_NTTInPlace 0.000456 0.0013347 0.0093004 0.097884 1.98 \ No newline at end of file diff --git a/hexl_serial.sh b/hexl_serial.sh new file mode 100644 index 00000000..e95dbdf1 --- /dev/null +++ b/hexl_serial.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Define constants +iteration_number="10" +thread_numbers="4096,65536,1048576,16777216,268435456" +file_name="serial_result.csv" + +# Check if the CSV file exists. If not, create it and add the header. +if [ ! 
-f $file_name ]; then + echo "Method,Threads=4096,Threads=65536,Threads=1048576,Threads=16777216,Threads=268435456" > $file_name +fi + +# Loop over iterations + # Run the command and append its output to the CSV file +./build/example $iteration $thread_numbers >> $file_name diff --git a/hexl_time.sh b/hexl_time.sh new file mode 100755 index 00000000..5b0445a3 --- /dev/null +++ b/hexl_time.sh @@ -0,0 +1,68 @@ +#!/bin/bash + +# Initialize the parameters +iterations=100 +thread_numbers="1 2 4 6 8" +# thread_numbers="1 2 4" + +method_number=4 +# INPUT_SIZES=( $((2**12)) $((2**16)) $((2**20)) $((2**24)) $((2**28)) ) +INPUT_SIZES=( $((2**12)) $((2**16)) $((2**20)) $((2**24))) + +# Initialize CSV file and write header +csv_file="omp_time.csv" +tmp_file="temp_file.csv" + +echo "Method_Number,Thread_Count,Input_Size,Average_Elapsed_Time" > $csv_file + +# Loop through the input sizes +for input_size in "${INPUT_SIZES[@]}"; do + echo "Running test with input_size=${input_size}" + + # Loop through each thread number + for thread in $thread_numbers; do + echo "Running with thread_count=${thread}" + + # # Run the command along with the binary and parameters + # # Capture the output to a temporary file + ./time_example/build/time_example $iterations $thread $input_size $method_number > $tmp_file + + # Initialize variables to calculate the average + total_time=0 + count=0 + error_count=0 + + + # Read each line from the temporary file + while read -r line; do + # Extract thread count and elapsed time from the line + read -r out_thread out_time <<< "$line" + + # Check if thread counts match + if [ "$out_thread" -ne "$thread" ]; then + echo "Error: Mismatch in thread count (expected $thread, got $out_thread)" + error_count=$((error_count + 1)) + continue + fi + + # Update total time and count + # echo "$total_time + $out_time" | bc -l + total_time=$(echo "$total_time + $out_time" | bc -l) + count=$((count + 1)) + + done < $tmp_file + + # Calculate and append the average to the CSV file + if [ "$count" -ne 0 ]; then + average_time=$(echo "$total_time / $count" | bc -l) + echo "$method_number,$thread,$input_size,$average_time" >> $csv_file + fi + + # Remove the temporary file + rm -f $tmp_file + + # ./time_example/build/time_example $iterations $thread $input_size $method_number + done +done + +echo "All tests completed." diff --git a/hexl_v0/include/hexl/eltwise/eltwise-add-mod.hpp b/hexl_v0/include/hexl/eltwise/eltwise-add-mod.hpp new file mode 100644 index 00000000..cb2df110 --- /dev/null +++ b/hexl_v0/include/hexl/eltwise/eltwise-add-mod.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Adds two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] operand2 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. 
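/// @par Example
/// (Editorial, illustrative sketch only; it assumes the vector-vector
/// EltwiseAddMod declaration below and a small word-sized modulus of 5.)
/// @code
///   #include <vector>
///   std::vector<uint64_t> op1{1, 2, 3, 4};
///   std::vector<uint64_t> op2{1, 1, 1, 1};
///   std::vector<uint64_t> out(op1.size());
///   uint64_t modulus = 5;
///   intel::hexl::EltwiseAddMod(out.data(), op1.data(), op2.data(),
///                              op1.size(), modulus);
///   // out == {2, 3, 4, 0}
/// @endcode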
+void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus); + +/// @brief Adds a vector and scalar elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] operand2 Scalar to add. Must be less +/// than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/eltwise/eltwise-cmp-add.hpp b/hexl_v0/include/hexl/eltwise/eltwise-cmp-add.hpp new file mode 100644 index 00000000..27e514ff --- /dev/null +++ b/hexl_v0/include/hexl/eltwise/eltwise-cmp-add.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional addition. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare; stores result +/// @param[in] n Number of elements in \p operand1 +/// @param[in] cmp Comparison operation +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to conditionally add +/// @details Computes result[i] = cmp(operand1[i], bound) ? operand1[i] + +/// diff : operand1[i] for all \f$i=0, ..., n-1\f$. +void EltwiseCmpAdd(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp b/hexl_v0/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp new file mode 100644 index 00000000..07ba3d23 --- /dev/null +++ b/hexl_v0/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional modular subtraction. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] modulus Modulus to reduce by +/// @param[in] cmp Comparison function +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to subtract by +/// @details Computes \p operand1[i] = (\p cmp(\p operand1, \p bound)) ? 
(\p +/// operand1 - \p diff) mod \p modulus : \p operand1 mod \p modulus for all i=0, +/// ..., n-1 +void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, uint64_t bound, + uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/eltwise/eltwise-fma-mod.hpp b/hexl_v0/include/hexl/eltwise/eltwise-fma-mod.hpp new file mode 100644 index 00000000..03651a42 --- /dev/null +++ b/hexl_v0/include/hexl/eltwise/eltwise-fma-mod.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes fused multiply-add (\p arg1 * \p arg2 + \p arg3) mod \p +/// modulus element-wise, broadcasting scalars to vectors. +/// @param[out] result Stores the result +/// @param[in] arg1 Vector to multiply +/// @param[in] arg2 Scalar to multiply +/// @param[in] arg3 Vector to add. Will not add if \p arg3 == nullptr +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$ [2, 2^{61} - 1]\f$ +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * modulus). Must be 1, 2, 4, or 8. +void EltwiseFMAMod(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/eltwise/eltwise-mult-mod.hpp b/hexl_v0/include/hexl/eltwise/eltwise-mult-mod.hpp new file mode 100644 index 00000000..e4d2dbd7 --- /dev/null +++ b/hexl_v0/include/hexl/eltwise/eltwise-mult-mod.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. 
+/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +void EltwiseMultMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/eltwise/eltwise-reduce-mod.hpp b/hexl_v0/include/hexl/eltwise/eltwise-reduce-mod.hpp new file mode 100644 index 00000000..c23abde2 --- /dev/null +++ b/hexl_v0/include/hexl/eltwise/eltwise-reduce-mod.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Performs elementwise modular reduction +/// @param[out] result Stores the result +/// @param[in] operand Data on which to compute the elementwise modular +/// reduction +/// @param[in] n Number of elements in operand +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be modulus, 1, 2 or 4. input_mod_factor=modulus +/// means, input range is [0, p * p]. Barrett reduction will be used in this +/// case. input_mod_factor > output_mod_factor +/// @param[in] output_mod_factor output elements will be in [0, +/// output_mod_factor * modulus) Must be 1 or 2. For input_mod_factor=0, +/// output_mod_factor will be set to 1. +void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, + uint64_t modulus, uint64_t input_mod_factor, + uint64_t output_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/eltwise/eltwise-sub-mod.hpp b/hexl_v0/include/hexl/eltwise/eltwise-sub-mod.hpp new file mode 100644 index 00000000..bd286e47 --- /dev/null +++ b/hexl_v0/include/hexl/eltwise/eltwise-sub-mod.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Subtracts two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Vector of elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus); + +/// @brief Subtracts a scalar from a vector elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. 
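/// @par Example
/// (Editorial, illustrative sketch only; it assumes the vector-scalar
/// EltwiseSubMod declaration below and a modulus of 5.)
/// @code
///   #include <vector>
///   std::vector<uint64_t> op1{0, 1, 2, 3};
///   std::vector<uint64_t> out(op1.size());
///   intel::hexl::EltwiseSubMod(out.data(), op1.data(), /*operand2=*/2,
///                              op1.size(), /*modulus=*/5);
///   // out == {3, 4, 0, 1}
/// @endcode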
+void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp b/hexl_v0/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp new file mode 100644 index 00000000..28a2dddf --- /dev/null +++ b/hexl_v0/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp @@ -0,0 +1,402 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +// ************************************ T1 ************************************ + +// ComplexLoadFwdInterleavedT1: +// Assumes ComplexLoadFwdInterleavedT2 was used before. +// Given input: 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0 +// Returns +// *out1 = (14, 12, 10, 8, 6, 4, 2, 0); +// *out2 = (15, 13, 11, 9, 7, 5, 3, 1); +// +// Given output: 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0 +inline void ComplexLoadFwdInterleavedT1(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512i vperm_idx = _mm512_set_epi64(6, 7, 4, 5, 2, 3, 0, 1); + + const __m512d* arg_512 = reinterpret_cast(arg); + + // 13 12 9 8 5 4 1 0 + __m512d v_7to0 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15 14 11 10 7 6 3 2 + __m512d v_15to8 = _mm512_loadu_pd(arg_512); + + // 12, 13, 8, 9, 4, 5, 0, 1 + __m512d perm_1 = _mm512_permutexvar_pd(vperm_idx, v_7to0); + // 14, 15, 10, 11, 6, 7, 2, 3 + __m512d perm_2 = _mm512_permutexvar_pd(vperm_idx, v_15to8); + + // 14, 12, 10, 8, 6, 4, 2, 0 + *out1 = _mm512_mask_blend_pd(0xaa, v_7to0, perm_2); + // 15, 13, 11, 9, 7, 5, 3, 1 + *out2 = _mm512_mask_blend_pd(0x55, v_15to8, perm_1); +} + +// ComplexWriteFwdInterleavedT1: +// Assumes ComplexLoadFwdInterleavedT1 was used before. 
+// Given inputs: +// 15i, 13i, 11i, 9i, 7i, 5i, 3i, 1i, 15r, 13r, 11r, 9r, 7r, 5r, 3r, 1r, +// 14i, 12i, 10i, 8i, 6i, 4i, 2i, 0i, 14r, 12r, 10r, 8r, 6r, 4r, 2r, 0r +// As seen with internal indexes: +// @param arg_yr = (15r, 14r, 13r, 12r, 11r, 10r, 9r, 8r); +// @param arg_xr = ( 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r); +// @param arg_yi = (15i, 14i, 13i, 12i, 11i, 10i, 9i, 8i); +// @param arg_xi = ( 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i); +// Writes out = +// {15i, 15r, 7i, 7r, 14i, 14r, 6i, 6r, 13i, 13r, 5i, 5r, 12i, 12r, 4i, 4r, +// 11i, 11r, 3i, 3r, 10i, 10r, 2i, 2r, 9i, 9r, 1i, 1r, 8i, 8r, 0i, 0r} +// +// Given output: +// 15i, 15r, 14i, 14r, 13i, 13r, 12i, 12r, 11i, 11r, 10i, 10r, 9i, 9r, 8i, 8r, +// 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +inline void ComplexWriteFwdInterleavedT1(__m512d arg_xr, __m512d arg_yr, + __m512d arg_xi, __m512d arg_yi, + __m512d* out) { + const __m512i vperm_4hi_4lo_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i v_X_out_idx = _mm512_set_epi64(3, 1, 7, 5, 2, 0, 6, 4); + const __m512i v_Y_out_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + // Real part + // in: 14r, 12r, 10r, 8r, 6r, 4r, 2r, 0r + // -> 6r, 4r, 2r, 0r, 14r, 12r, 10r, 8r + arg_xr = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, arg_xr); + + // arg_yr: 15r, 13r, 11r, 9r, 7r, 5r, 3r, 1r + // -> 6r, 4r, 2r, 0r, 7r, 5r, 3r, 1r + __m512d perm_1 = _mm512_mask_blend_pd(0x0f, arg_xr, arg_yr); + // -> 15r, 13r, 11r, 9r, 14r, 12r, 10r, 8r + __m512d perm_2 = _mm512_mask_blend_pd(0xf0, arg_xr, arg_yr); + + // 7r, 3r, 6r, 2r, 5r, 1r, 4r, 0r + arg_xr = _mm512_permutexvar_pd(v_X_out_idx, perm_1); + // 15r, 11r, 14r, 10r, 13r, 9r, 12r, 8r + arg_yr = _mm512_permutexvar_pd(v_Y_out_idx, perm_2); + + // Imaginary part + // in: 14i, 12i, 10i, 8i, 6i, 4i, 2i, 0i + // -> 6i, 4i, 2i, 0i, 14i, 12i, 10i, 8i + arg_xi = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, arg_xi); + + // arg_yr: 15i, 13i, 11i, 9i, 7i, 5i, 3i, 1i + // -> 6i, 4i, 2i, 0i, 7i, 5i, 3i, 1i + perm_1 = _mm512_mask_blend_pd(0x0f, arg_xi, arg_yi); + // -> 15i, 13i, 11i, 9i, 14i, 12i, 10i, 8i + perm_2 = _mm512_mask_blend_pd(0xf0, arg_xi, arg_yi); + + // 7i, 3i, 6i, 2i, 5i, 1i, 4i, 0i + arg_xi = _mm512_permutexvar_pd(v_X_out_idx, perm_1); + // 15i, 11i, 14i, 10i, 13i, 9i, 12i, 8i + arg_yi = _mm512_permutexvar_pd(v_Y_out_idx, perm_2); + + // Merge + // 00000000 > 3i 3r 2i 2r 1i 1r 0i 0r + __m512d out1 = _mm512_shuffle_pd(arg_xr, arg_xi, 0x00); + // 11111111 > 7i 7r 6i 6r 5i 5r 4i 4r + __m512d out2 = _mm512_shuffle_pd(arg_xr, arg_xi, 0xff); + + // 00000000 > 11i 11r 10i 10r 9i 9r 8i 8r + __m512d out3 = _mm512_shuffle_pd(arg_yr, arg_yi, 0x00); + // 11111111 > 15i 15r 14i 14r 13i 13r 12i 12r + __m512d out4 = _mm512_shuffle_pd(arg_yr, arg_yi, 0xff); + + _mm512_storeu_pd(out++, out1); + _mm512_storeu_pd(out++, out2); + _mm512_storeu_pd(out++, out3); + _mm512_storeu_pd(out++, out4); +} + +// ComplexLoadInvInterleavedT1: +// Given input: 15i 15r 14i 14r 13i 13r 12i 12r 11i 11r 10i 10r 9i 9r 8i 8r +// 7i 7r 6i 6r 5i 5r 4i 4r 3i 3r 2i 2r 1i 1r 0i 0r +// Returns +// *out1_r = (14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r); +// *out1_i = (14i, 10i, 6i, 2i, 12i, 8i, 4i, 0i); +// *out2_r = (15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r); +// *out2_i = (15i, 11i, 7i, 3i, 13i, 9i, 5i, 1i); +// +// Given output: +// 15i, 11i, 7i, 3i, 13i, 9i, 5i, 1i, 15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r, +// 14i, 10i, 6i, 2i, 12i, 8i, 4i, 0i, 14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r +inline void ComplexLoadInvInterleavedT1(const double* arg, __m512d* out1_r, + __m512d* out1_i, __m512d* out2_r, + __m512d* 
out2_i) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 3i 3r 2i 2r 1i 1r 0i 0r + __m512d v_3to0 = _mm512_loadu_pd(arg_512++); + // 7i 7r 6i 6r 5i 5r 4i 4r + __m512d v_7to4 = _mm512_loadu_pd(arg_512++); + // 11i 11r 10i 10r 9i 9r 8i 8r + __m512d v_11to8 = _mm512_loadu_pd(arg_512++); + // 15i 15r 14i 14r 13i 13r 12i 12r + __m512d v_15to12 = _mm512_loadu_pd(arg_512++); + + // 00000000 > 7r 3r 6r 2r 5r 1r 4r 0r + __m512d v_7to0_r = _mm512_shuffle_pd(v_3to0, v_7to4, 0x00); + // 11111111 > 7i 3i 6i 2i 5i 1i 4i 0i + __m512d v_7to0_i = _mm512_shuffle_pd(v_3to0, v_7to4, 0xff); + // 00000000 > 15r 11r 14r 10r 13r 9r 12r 8r + __m512d v_15to8_r = _mm512_shuffle_pd(v_11to8, v_15to12, 0x00); + // 11111111 > 15i 11i 14i 10i 13i 9i 12i 8i + __m512d v_15to8_i = _mm512_shuffle_pd(v_11to8, v_15to12, 0xff); + + // real + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + // 6 2 7 3 4 0 5 1 + __m512d v1r = _mm512_permutexvar_pd(v1_perm_idx, v_7to0_r); + // 14 10 15 11 12 8 13 9 + __m512d v2r = _mm512_permutexvar_pd(v1_perm_idx, v_15to8_r); + // 11001100 > 14 10 6 2 12 8 4 0 + *out1_r = _mm512_mask_blend_pd(0xcc, v_7to0_r, v2r); + // 11001100 > 15 11 7 3 13 9 5 1 + *out2_r = _mm512_mask_blend_pd(0xcc, v1r, v_15to8_r); + + // imag + // 6 2 7 3 4 0 5 1 + __m512d v1i = _mm512_permutexvar_pd(v1_perm_idx, v_7to0_i); + // 14 10 15 11 12 8 13 9 + __m512d v2i = _mm512_permutexvar_pd(v1_perm_idx, v_15to8_i); + // 11001100 > 14 10 6 2 12 8 4 0 + *out1_i = _mm512_mask_blend_pd(0xcc, v_7to0_i, v2i); + // 11001100 > 15 11 7 3 13 9 5 1 + *out2_i = _mm512_mask_blend_pd(0xcc, v1i, v_15to8_i); +} + +// ************************************ T2 ************************************ + +// ComplexLoadFwdInterleavedT2: +// Assumes ComplexLoadFwdInterleavedT4 was used before. +// Given input: 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0 +// Returns +// *out1 = (13, 12, 9, 8, 5, 4, 1, 0) +// *out2 = (15, 14, 11, 10, 7, 6, 3, 2) +// +// Given output: 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0 +inline void ComplexLoadFwdInterleavedT2(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // Values were swapped in T4 + // 11, 10, 9, 8, 3, 2, 1, 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15, 14, 13, 12, 7, 6, 5, 4 + __m512d v2 = _mm512_loadu_pd(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + + __m512d v1_perm = _mm512_permutexvar_pd(v1_perm_idx, v1); + __m512d v2_perm = _mm512_permutexvar_pd(v1_perm_idx, v2); + + *out1 = _mm512_mask_blend_pd(0xcc, v1, v2_perm); + *out2 = _mm512_mask_blend_pd(0xcc, v1_perm, v2); +} + +// ComplexLoadInvInterleavedT2: +// Assumes ComplexLoadInvInterleavedT1 was used before. 
+// Given input: 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0 +// Returns +// *out1 = (13, 9, 5, 1, 12, 8, 4, 0) +// *out2 = (15, 11, 7, 3, 14, 10, 6, 2) +// +// Given output: 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0 +inline void ComplexLoadInvInterleavedT2(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 14 10 6 2 12 8 4 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15 11 7 3 13 9 5 1 + __m512d v2 = _mm512_loadu_pd(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + + // 12 8 4 0 14 10 6 2 + __m512d v1_perm = _mm512_permutexvar_pd(v1_perm_idx, v1); + // 13 9 5 1 15 11 7 3 + __m512d v2_perm = _mm512_permutexvar_pd(v1_perm_idx, v2); + + // 11110000 > 13 9 5 1 12 8 4 0 + *out1 = _mm512_mask_blend_pd(0xf0, v1, v2_perm); + // 11110000 > 15 11 7 3 14 10 6 2 + *out2 = _mm512_mask_blend_pd(0xf0, v1_perm, v2); +} + +// ************************************ T4 ************************************ + +// Complex LoadFwdInterleavedT4: +// Given input: 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +// Returns +// *out1 = (11, 10, 9, 8, 3, 2, 1, 0) +// *out2 = (15, 14, 13, 12, 7, 6, 5, 4) +// +// Given output: 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0 +inline void ComplexLoadFwdInterleavedT4(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + __m512d v_7to0 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + __m512d v_15to8 = _mm512_loadu_pd(arg_512); + __m512d perm_hi = _mm512_permutexvar_pd(vperm2_idx, v_15to8); + *out1 = _mm512_mask_blend_pd(0x0f, perm_hi, v_7to0); + *out2 = _mm512_mask_blend_pd(0xf0, perm_hi, v_7to0); + *out2 = _mm512_permutexvar_pd(vperm2_idx, *out2); +} + +// ComplexLoadInvInterleavedT4: +// Assumes ComplexLoadInvInterleavedT2 was used before. +// Given input: 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0 +// Returns +// *out1 = (11, 9, 3, 1, 10, 8, 2, 0) +// *out2 = (15, 13, 7, 5, 14, 12, 6, 4) +// +// Given output: 15, 13, 7, 5, 14, 12, 6, 4, 11, 9, 3, 1, 10, 8, 2, 0 + +inline void ComplexLoadInvInterleavedT4(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 13, 9, 5, 1, 12, 8, 4, 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15, 11, 7, 3, 14, 10, 6, 2 + __m512d v2 = _mm512_loadu_pd(arg_512); + + // 00000000 > 11 9 3 1 10 8 2 0 + *out1 = _mm512_shuffle_pd(v1, v2, 0x00); + // 11111111 > 15 13 7 5 14 12 6 4 + *out2 = _mm512_shuffle_pd(v1, v2, 0xff); +} + +// ComplexWriteInvInterleavedT4: +// Assuming ComplexLoadInvInterleavedT4 was used before. 
+// Given inputs: 15, 13, 7, 5, 14, 12, 6, 4, 11, 9, 3, 1, 10, 8, 2, 0 +// Seen Internally: +// @param arg1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// @param arg2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 11, 14, 10, 7, 3, 6, 2, +// 13, 9, 12, 8, 5, 1, 4, 0} +// +// Given output: 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +inline void ComplexWriteInvInterleavedT4(__m512d arg1, __m512d arg2, + __m512d* out) { + const __m512i vperm_4hi_4lo_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i vperm1 = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + const __m512i vperm2 = _mm512_set_epi64(5, 1, 4, 0, 7, 3, 6, 2); + + // in: 11 9 3 1 10 8 2 0 + // -> 11 10 9 8 3 2 1 0 + arg1 = _mm512_permutexvar_pd(vperm1, arg1); + // in: 15 13 7 5 14 12 6 4 + // -> 7 6 5 4 15 14 13 12 + arg2 = _mm512_permutexvar_pd(vperm2, arg2); + + // 7 6 5 4 3 2 1 0 + __m512d out1 = _mm512_mask_blend_pd(0xf0, arg1, arg2); + // 11 10 9 8 15 14 13 12 + __m512d out2 = _mm512_mask_blend_pd(0x0f, arg1, arg2); + // 15 14 13 12 11 10 9 8 + out2 = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, out2); + + _mm512_storeu_pd(out, out1); + out += 2; + _mm512_storeu_pd(out, out2); +} + +// ************************************ T8 ************************************ + +// ComplexLoadFwdInterleavedT8: +// Given inputs: 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +// Seen Internally: +// v_X1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// v_X2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 13, 11, 9, 7, 5, 3, 1, +// 14, 12, 10, 8, 6, 4, 2, 0} +// +// Given output: 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i, 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r +inline void ComplexLoadFwdInterleavedT8(const __m512d* arg_x, + const __m512d* arg_y, __m512d* out1_r, + __m512d* out1_i, __m512d* out2_r, + __m512d* out2_i) { + const __m512i v_perm_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + // 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r + __m512d v_X1 = _mm512_loadu_pd(arg_x++); + // 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r + __m512d v_X2 = _mm512_loadu_pd(arg_x); + // 7r, 3r, 6r, 2r, 5r, 1r, 4r, 0r + *out1_r = _mm512_shuffle_pd(v_X1, v_X2, 0x00); + // 7i, 3i, 6i, 2i, 5i, 1i, 4i, 0i + *out1_i = _mm512_shuffle_pd(v_X1, v_X2, 0xff); + // 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r + *out1_r = _mm512_permutexvar_pd(v_perm_idx, *out1_r); + // 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i + *out1_i = _mm512_permutexvar_pd(v_perm_idx, *out1_i); + + __m512d v_Y1 = _mm512_loadu_pd(arg_y++); + __m512d v_Y2 = _mm512_loadu_pd(arg_y); + *out2_r = _mm512_shuffle_pd(v_Y1, v_Y2, 0x00); + *out2_i = _mm512_shuffle_pd(v_Y1, v_Y2, 0xff); + *out2_r = _mm512_permutexvar_pd(v_perm_idx, *out2_r); + *out2_i = _mm512_permutexvar_pd(v_perm_idx, *out2_i); +} + +// ComplexWriteInvInterleavedT8: +// Assuming ComplexLoadInvInterleavedT4 was used before. 
+// Given inputs: 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i, 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r +// Seen Internally: +// @param arg1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// @param arg2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 7, 14, 6, 13, 5, 12, 4, +// 11, 3, 10, 2, 9, 1, 8, 0} +// +// Given output: 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +inline void ComplexWriteInvInterleavedT8(__m512d* v_X_real, __m512d* v_X_imag, + __m512d* v_Y_real, __m512d* v_Y_imag, + __m512d* v_X_pt, __m512d* v_Y_pt) { + const __m512i vperm = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + // in: 7r 6r 5r 4r 3r 2r 1r 0r + // -> 7r 3r 6r 2r 5r 1r 4r 0r + *v_X_real = _mm512_permutexvar_pd(vperm, *v_X_real); + // in: 7i 6i 5i 4i 3i 2i 1i 0i + // -> 7i 3i 6i 2i 5i 1i 4i 0i + *v_X_imag = _mm512_permutexvar_pd(vperm, *v_X_imag); + // in: 15r 14r 13r 12r 11r 10r 9r 8r + // -> 15r 11r 14r 10r 13r 9r 12r 8r + *v_Y_real = _mm512_permutexvar_pd(vperm, *v_Y_real); + // in: 15i 14i 13i 12i 11i 10i 9i 8i + // -> 15i 11i 14i 10i 13i 9i 12i 8i + *v_Y_imag = _mm512_permutexvar_pd(vperm, *v_Y_imag); + + // 00000000 > 3i 3r 2i 2r 1i 1r 0i 0r + __m512d v_X1 = _mm512_shuffle_pd(*v_X_real, *v_X_imag, 0x00); + // 11111111 > 7i 7r 6i 6r 5i 5r 4i 4r + __m512d v_X2 = _mm512_shuffle_pd(*v_X_real, *v_X_imag, 0xff); + // 00000000 > 11i 11r 10i 10r 9i 9r 8i 8r + __m512d v_Y1 = _mm512_shuffle_pd(*v_Y_real, *v_Y_imag, 0x00); + // 11111111 > 15i 15r 14i 14r 13i 13r 12i 12r + __m512d v_Y2 = _mm512_shuffle_pd(*v_Y_real, *v_Y_imag, 0xff); + + _mm512_storeu_pd(v_X_pt++, v_X1); + _mm512_storeu_pd(v_X_pt, v_X2); + _mm512_storeu_pd(v_Y_pt++, v_Y1); + _mm512_storeu_pd(v_Y_pt, v_Y2); +} +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/experimental/fft-like/fft-like-native.hpp b/hexl_v0/include/hexl/experimental/fft-like/fft-like-native.hpp new file mode 100644 index 00000000..7e02492d --- /dev/null +++ b/hexl_v0/include/hexl/experimental/fft-like/fft-like-native.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Radix-2 native C++ FFT like implementation of the forward FFT like +/// @param[out] result Output data. Overwritten with FFT like output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity. In +/// bit-reversed order +/// @param[in] scale Scale applied to output data +void Forward_FFTLike_ToBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* root_of_unity_powers, const uint64_t n, + const double* scale = nullptr); + +/// @brief Radix-2 native C++ FFT like implementation of the inverse FFT like +/// @param[out] result Output data. Overwritten with FFT like output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity. +/// In bit-reversed order. 
+/// @param[in] scale Scale applied to output data +void Inverse_FFTLike_FromBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* inv_root_of_unity_powers, const uint64_t n, + const double* scale = nullptr); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/experimental/fft-like/fft-like.hpp b/hexl_v0/include/hexl/experimental/fft-like/fft-like.hpp new file mode 100644 index 00000000..334de246 --- /dev/null +++ b/hexl_v0/include/hexl/experimental/fft-like/fft-like.hpp @@ -0,0 +1,147 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/experimental/fft-like/fwd-fft-like-avx512.hpp" +#include "hexl/experimental/fft-like/inv-fft-like-avx512.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/allocator.hpp" + +namespace intel { +namespace hexl { + +/// @brief Performs linear forward and inverse FFT like transform +/// for CKKS encoding and decoding. +class FFTLike { + public: + /// @brief Helper class for custom memory allocation + template + struct AllocatorAdapter + : public AllocatorInterface> { + explicit AllocatorAdapter(Adaptee&& _a, Args&&... args); + AllocatorAdapter(const Adaptee& _a, Args&... args); + + // interface implementation + void* allocate_impl(size_t bytes_count); + void deallocate_impl(void* p, size_t n); + + private: + Adaptee alloc; + }; + + /// @brief Initializes an empty CKKS_FTT object + FFTLike() = default; + + /// @brief Destructs the CKKS_FTT object + ~FFTLike() = default; + + /// @brief Initializes an FFTLike object with degree \p degree and scalar + /// \p in_scalar. + /// @param[in] degree also known as N. Size of the FFT like transform. Must be + /// a power of 2 + /// @param[in] in_scalar Scalar value to calculate scale and inv scale + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @details Performs pre-computation necessary for forward and inverse + /// transforms + FFTLike(uint64_t degree, double* in_scalar, + std::shared_ptr alloc_ptr = {}); + + template + FFTLike(uint64_t degree, double* in_scalar, Allocator&& a, + AllocatorArgs&&... args) + : FFTLike( + degree, in_scalar, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Compute forward FFT like. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the FFT like + /// @param[in] in_scale Scale applied to output values + void ComputeForwardFFTLike(std::complex* result, + const std::complex* operand, + const double* in_scale = nullptr); + + /// @brief Compute inverse FFT like. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the FFT like + /// @param[in] in_scale Scale applied to output values + void ComputeInverseFFTLike(std::complex* result, + const std::complex* operand, + const double* in_scale = nullptr); + + /// @brief Construct floating-point values from CRT-composed polynomial with + /// integer coefficients. 
+ /// @param[out] res Stores the result + /// @param[in] plain Plaintext + /// @param[in] threshold Upper half threshold with respect to the total + /// coefficient modulus + /// @param[in] decryption_modulus Product of all primes in the coefficient + /// modulus + /// @param[in] inv_scale Scale applied to output values + /// @param[in] mod_size Size of coefficient modulus parameter + /// @param[in] coeff_count Degree of the polynomial modulus parameter + void BuildFloatingPoints(std::complex* res, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, size_t mod_size, + size_t coeff_count); + + /// @brief Returns the root of unity power at bit-reversed index i. + /// @param[in] i Index + std::complex GetComplexRootOfUnity(size_t i) { + return GetComplexRootsOfUnity()[i]; + } + + /// @brief Returns the root of unity in bit-reversed order + const AlignedVector64>& GetComplexRootsOfUnity() const { + return m_complex_roots_of_unity; + } + + /// @brief Returns the root of unity power at bit-reversed index i. + /// @param[in] i Index + std::complex GetInvComplexRootOfUnity(size_t i) { + return GetInvComplexRootsOfUnity()[i]; + } + + /// @brief Returns the inverse root of unity in bit-reversed order + const AlignedVector64>& GetInvComplexRootsOfUnity() + const { + return m_inv_complex_roots_of_unity; + } + + /// @brief Returns the degree N + uint64_t GetDegree() const { return m_degree; } + + private: + // Computes 1~(n-1)-th powers and inv powers of the primitive 2n-th root + void ComputeComplexRootsOfUnity(); + + uint64_t m_degree; // N: size of FFT like transform, should be power of 2 + + double* scalar; // Pointer to scalar used for scale/inv_scale calculation + + double scale; // Scale value use for encoding (inv fft-like) + + double inv_scale; // Scale value use in decoding (fwd fft-like) + + std::shared_ptr m_alloc; + + AlignedAllocator m_aligned_alloc; + + uint64_t m_degree_bits; // log_2(m_degree) + + // Contains 0~(n-1)-th powers of the 2n-th primitive root. + AlignedVector64> m_complex_roots_of_unity; + + // Contains 0~(n-1)-th inv powers of the 2n-th primitive inv root. + AlignedVector64> m_inv_complex_roots_of_unity; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp b/hexl_v0/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp new file mode 100644 index 00000000..aba4ca4d --- /dev/null +++ b/hexl_v0/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp @@ -0,0 +1,59 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/fft-like/fft-like.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the forward FFT like +/// @param[out] result_cmplx_intrlvd Output data. Overwritten with FFT like +/// output. Result is a vector of double with interleaved real and imaginary +/// numbers. +/// @param[in] operand_cmplx_intrlvd Input data. A vector of double with +/// interleaved real and imaginary numbers. +/// @param[in] roots_of_unity_cmplx_intrlvd Powers of 2n'th root of unity. In +/// bit-reversed order. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. 
+/// @param[in] scale Scale applied to output values +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// FFT like, where all the butterflies in a given stage are processed before +/// any butterflies in the next stage. The base case is small enough to fit in +/// the smallest cache. Larger FFTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. This reduces the number of cache misses, improving +/// performance on larger transform sizes. +void Forward_FFTLike_ToBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* roots_of_unity_cmplx_intrlvd, const uint64_t n, + const double* scale = nullptr, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +/// @brief Construct floating-point values from CRT-composed polynomial with +/// integer coefficients in AVX512. +/// @param[out] res_cmplx_intrlvd Stores the result +/// @param[in] plain Plaintext +/// @param[in] threshold Upper half threshold with respect to the total +/// coefficient modulus +/// @param[in] decryption_modulus Product of all primes in the coefficient +/// modulus +/// @param[in] inv_scale Scale applied to output values +/// @param[in] mod_size Size of coefficient modulus parameter +/// @param[in] coeff_count Degree of the polynomial modulus parameter +void BuildFloatingPointsAVX512(double* res_cmplx_intrlvd, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, const size_t mod_size, + const size_t coeff_count); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp b/hexl_v0/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp new file mode 100644 index 00000000..487e2828 --- /dev/null +++ b/hexl_v0/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/fft-like/fft-like.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the inverse FFT like +/// @param[out] result_cmplx_intrlvd Output data. Overwritten with FFT like +/// output. Result is a vector of double with interleaved real and imaginary +/// numbers. +/// @param[in] operand_cmplx_intrlvd Input data. A vector of double with +/// interleaved real and imaginary numbers. +/// @param[in] inv_roots_of_unity_cmplx_intrlvd Powers of 2n'th root of unity. +/// In bit-reversed order. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] scale Scale applied to output values +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// FFT like, where all the butterflies in a given stage are processed before +/// any butterflies in the next stage. The base case is small enough to fit in +/// the smallest cache. Larger FFTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. 
This reduces the number of cache misses, improving +/// performance on larger transform sizes. +void Inverse_FFTLike_FromBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* inv_root_of_unity_cmplxintrlvd, const uint64_t n, + const double* scale = nullptr, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/experimental/misc/lr-mat-vec-mult.hpp b/hexl_v0/include/hexl/experimental/misc/lr-mat-vec-mult.hpp new file mode 100644 index 00000000..df03df92 --- /dev/null +++ b/hexl_v0/include/hexl/experimental/misc/lr-mat-vec-mult.hpp @@ -0,0 +1,31 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes transposed linear regression +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (3 * n * num_moduli) elements +/// @param[in] operand1 Vector of ciphertext representing a matrix that encodes +/// a transposed logistic regression model. Has (num_weights * 2 * n * +/// num_moduli) elements. +/// @param[in] operand2 Vector of ciphertext representing a matrix that encodes +/// at most n/2 input samples with feature size num_weights. Has (num_weights * +/// 2 * n * num_moduli) elements. +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +/// @param[in] num_weights Feature size of the linear/logistic regression model +void LinRegMatrixVectorMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli, + uint64_t num_weights); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/experimental/seal/dyadic-multiply-internal.hpp b/hexl_v0/include/hexl/experimental/seal/dyadic-multiply-internal.hpp new file mode 100644 index 00000000..310a46b0 --- /dev/null +++ b/hexl_v0/include/hexl/experimental/seal/dyadic-multiply-internal.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { +namespace internal { + +/// @brief Computes dyadic multiplication +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (2 * n * num_moduli) elements +/// @param[in] operand1 First ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] operand2 Second ciphertext argument. Has (2 * n * num_moduli) +/// elements. 
+/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli); + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/experimental/seal/dyadic-multiply.hpp b/hexl_v0/include/hexl/experimental/seal/dyadic-multiply.hpp new file mode 100644 index 00000000..f7eacfdf --- /dev/null +++ b/hexl_v0/include/hexl/experimental/seal/dyadic-multiply.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes dyadic multiplication +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (2 * n * num_moduli) elements +/// @param[in] operand1 First ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] operand2 Second ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/experimental/seal/key-switch-internal.hpp b/hexl_v0/include/hexl/experimental/seal/key-switch-internal.hpp new file mode 100644 index 00000000..8fc9d53e --- /dev/null +++ b/hexl_v0/include/hexl/experimental/seal/key-switch-internal.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { +namespace internal { + +/// @brief Computes key switching in-place +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (n * decomp_modulus_size * key_component_count) elements +/// @param[in] t_target_iter_ptr Pointer to the last component of the input +/// ciphertext +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] decomp_modulus_size Number of moduli in the ciphertext at its +/// current level, excluding one auxiliary prime. +/// @param[in] key_modulus_size Number of moduli in the ciphertext at its top +/// level, including one auxiliary prime. +/// @param[in] rns_modulus_size Number of moduli in the ciphertext at its +/// current level, including one auxiliary prime. rns_modulus_size == +/// decomp_modulus_size + 1 +/// @param[in] key_component_count Number of components in the resulting +/// ciphertext, e.g. key_component_count == 2. +/// @param[in] moduli Array of word-sized coefficient moduli. There must be +/// key_modulus_size moduli in the array +/// @param[in] k_switch_keys Array of evaluation key data. 
Has +/// decomp_modulus_size entries, each with +/// coeff_count * ((key_modulus_size - 1)+ (key_component_count - 1) * +/// (key_modulus_size) + 1) entries +/// @param[in] modswitch_factors Array of modulus switch factors +/// @param[in] root_of_unity_powers_ptr Array of root of unity powers +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr = nullptr); + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/experimental/seal/key-switch.hpp b/hexl_v0/include/hexl/experimental/seal/key-switch.hpp new file mode 100644 index 00000000..9eda159c --- /dev/null +++ b/hexl_v0/include/hexl/experimental/seal/key-switch.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes key switching in-place +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (n * decomp_modulus_size * key_component_count) elements +/// @param[in] t_target_iter_ptr Pointer to the last component of the input +/// ciphertext +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] decomp_modulus_size Number of moduli in the ciphertext at its +/// current level, excluding one auxiliary prime. +/// @param[in] key_modulus_size Number of moduli in the ciphertext at its top +/// level, including one auxiliary prime. +/// @param[in] rns_modulus_size Number of moduli in the ciphertext at its +/// current level, including one auxiliary prime. rns_modulus_size == +/// decomp_modulus_size + 1 +/// @param[in] key_component_count Number of components in the resulting +/// ciphertext, e.g. key_component_count == 2. +/// @param[in] moduli Array of word-sized coefficient moduli. There must be +/// key_modulus_size moduli in the array +/// @param[in] k_switch_keys Array of evaluation key data. 
Has +/// decomp_modulus_size entries, each with +/// coeff_count * ((key_modulus_size - 1)+ (key_component_count - 1) * +/// (key_modulus_size) + 1) entries +/// @param[in] modswitch_factors Array of modulus switch factors +/// @param[in] root_of_unity_powers_ptr Array of root of unity powers +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr = nullptr); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/experimental/seal/locks.hpp b/hexl_v0/include/hexl/experimental/seal/locks.hpp new file mode 100644 index 00000000..4595f4e5 --- /dev/null +++ b/hexl_v0/include/hexl/experimental/seal/locks.hpp @@ -0,0 +1,35 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +namespace intel { +namespace hexl { + +using Lock = std::shared_mutex; +using WriteLock = std::unique_lock; +using ReadLock = std::shared_lock; + +class RWLock { + public: + RWLock() = default; + inline ReadLock AcquireRead() { return ReadLock(rw_mutex); } + inline WriteLock AcquireWrite() { return WriteLock(rw_mutex); } + inline ReadLock TryAcquireRead() noexcept { + return ReadLock(rw_mutex, std::try_to_lock); + } + inline WriteLock TryAcquireWrite() noexcept { + return WriteLock(rw_mutex, std::try_to_lock); + } + + private: + RWLock(const RWLock& copy) = delete; + RWLock& operator=(const RWLock& assign) = delete; + Lock rw_mutex{}; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/experimental/seal/ntt-cache.hpp b/hexl_v0/include/hexl/experimental/seal/ntt-cache.hpp new file mode 100644 index 00000000..8f6c1046 --- /dev/null +++ b/hexl_v0/include/hexl/experimental/seal/ntt-cache.hpp @@ -0,0 +1,56 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/seal/locks.hpp" +#include "ntt/ntt-internal.hpp" + +namespace intel { +namespace hexl { + +struct HashPair { + template + std::size_t operator()(const std::pair& p) const { + auto hash1 = std::hash{}(p.first); + auto hash2 = std::hash{}(p.second); + return hash_combine(hash1, hash2); + } + + // Golden Ratio Hashing with seeds + static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { + lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + return lhs; + } +}; + +NTT& GetNTT(size_t N, uint64_t modulus) { + static std::unordered_map, NTT, HashPair> + ntt_cache; + static RWLock ntt_cache_locker; + + std::pair key{N, modulus}; + + // Enable shared access to NTT already present + { + ReadLock reader_lock(ntt_cache_locker.AcquireRead()); + auto ntt_it = ntt_cache.find(key); + if (ntt_it != ntt_cache.end()) { + return ntt_it->second; + } + } + + // Deal with NTT not yet present + WriteLock write_lock(ntt_cache_locker.AcquireWrite()); + + // Check ntt_cache for value (may be added by another thread) + auto ntt_it = ntt_cache.find(key); + if (ntt_it == ntt_cache.end()) { + NTT ntt(N, modulus); + ntt_it = ntt_cache.emplace(std::move(key), std::move(ntt)).first; + } + return ntt_it->second; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/hexl.hpp b/hexl_v0/include/hexl/hexl.hpp new file mode 100644 index 00000000..6f07ae57 --- /dev/null +++ 
b/hexl_v0/include/hexl/hexl.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/eltwise/eltwise-cmp-add.hpp" +#include "hexl/eltwise/eltwise-cmp-sub-mod.hpp" +#include "hexl/eltwise/eltwise-fma-mod.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/eltwise/eltwise-sub-mod.hpp" +#include "hexl/experimental/fft-like/fft-like.hpp" +#include "hexl/experimental/misc/lr-mat-vec-mult.hpp" +#include "hexl/experimental/seal/dyadic-multiply-internal.hpp" +#include "hexl/experimental/seal/dyadic-multiply.hpp" +#include "hexl/experimental/seal/key-switch-internal.hpp" +#include "hexl/experimental/seal/key-switch.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" +#include "hexl/util/defines.hpp" +#include "hexl/util/types.hpp" +#include "hexl/util/util.hpp" diff --git a/hexl_v0/include/hexl/logging/logging.hpp b/hexl_v0/include/hexl/logging/logging.hpp new file mode 100644 index 00000000..af5bfcd8 --- /dev/null +++ b/hexl_v0/include/hexl/logging/logging.hpp @@ -0,0 +1,43 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "hexl/util/defines.hpp" + +// Wrap HEXL_VLOG with HEXL_DEBUG; this ensures no logging overhead in +// release mode +#ifdef HEXL_DEBUG + +// TODO(fboemer) Enable if needed +// #define ELPP_THREAD_SAFE +#define ELPP_CUSTOM_COUT std::cerr +#define ELPP_STL_LOGGING +#define ELPP_LOG_STD_ARRAY +#define ELPP_LOG_UNORDERED_MAP +#define ELPP_LOG_UNORDERED_SET +#define ELPP_NO_LOG_TO_FILE +#define ELPP_DISABLE_DEFAULT_CRASH_HANDLING +#define ELPP_WINSOCK2 + +#include + +#define HEXL_VLOG(N, rest) \ + do { \ + if (VLOG_IS_ON(N)) { \ + VLOG(N) << rest; \ + } \ + } while (0); + +#else + +#define HEXL_VLOG(N, rest) \ + {} + +#define START_EASYLOGGINGPP(X, Y) \ + {} + +#endif diff --git a/hexl_v0/include/hexl/ntt/ntt.hpp b/hexl_v0/include/hexl/ntt/ntt.hpp new file mode 100644 index 00000000..93ccba72 --- /dev/null +++ b/hexl_v0/include/hexl/ntt/ntt.hpp @@ -0,0 +1,296 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/allocator.hpp" + +namespace intel { +namespace hexl { + +/// @brief Performs negacyclic forward and inverse number-theoretic transform +/// (NTT), commonly used in RLWE cryptography. +/// @details The number-theoretic transform (NTT) specializes the discrete +/// Fourier transform (DFT) to the finite field \f$ \mathbb{Z}_q[X] / (X^N + 1) +/// \f$. +class NTT { + public: + /// @brief Helper class for custom memory allocation + template + struct AllocatorAdapter + : public AllocatorInterface> { + explicit AllocatorAdapter(Adaptee&& _a, Args&&... args); + AllocatorAdapter(const Adaptee& _a, Args&... args); + + // interface implementation + void* allocate_impl(size_t bytes_count); + void deallocate_impl(void* p, size_t n); + + private: + Adaptee alloc; + }; + + /// @brief Initializes an empty NTT object + NTT() = default; + + /// @brief Destructs the NTT object + ~NTT() = default; + + /// @brief Initializes an NTT object with degree \p degree and modulus \p q. + /// @param[in] degree also known as N. Size of the NTT transform. 
Must be a + /// power of + /// 2 + /// @param[in] q Prime modulus. Must satisfy \f$ q == 1 \mod 2N \f$ + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @brief Performs pre-computation necessary for forward and inverse + /// transforms + NTT(uint64_t degree, uint64_t q, + std::shared_ptr alloc_ptr = {}); + + template + NTT(uint64_t degree, uint64_t q, Allocator&& a, AllocatorArgs&&... args) + : NTT(degree, q, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Initializes an NTT object with degree \p degree and modulus + /// \p q. + /// @param[in] degree also known as N. Size of the NTT transform. Must be a + /// power of + /// 2 + /// @param[in] q Prime modulus. Must satisfy \f$ q == 1 \mod 2N \f$ + /// @param[in] root_of_unity 2N'th root of unity in \f$ \mathbb{Z_q} \f$. + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @details Performs pre-computation necessary for forward and inverse + /// transforms + NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, + std::shared_ptr alloc_ptr = {}); + + template + NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, Allocator&& a, + AllocatorArgs&&... args) + : NTT(degree, q, root_of_unity, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Returns true if arguments satisfy constraints for negacyclic NTT + /// @param[in] degree N. Size of the transform, i.e. the polynomial degree. + /// Must be a power of two. + /// @param[in] modulus Prime modulus q. Must satisfy q mod 2N = 1 + static bool CheckArguments(uint64_t degree, uint64_t modulus); + + /// @brief Compute forward NTT. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the NTT + /// @param[in] input_mod_factor Assume input \p operand are in [0, + /// input_mod_factor * q). Must be 1, 2 or 4. + /// @param[in] output_mod_factor Returns output \p result in [0, + /// output_mod_factor * q). Must be 1 or 4. + void ComputeForward(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, uint64_t output_mod_factor); + + /// Compute inverse NTT. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the NTT + /// @param[in] input_mod_factor Assume input \p operand are in [0, + /// input_mod_factor * q). Must be 1 or 2. + /// @param[in] output_mod_factor Returns output \p result in [0, + /// output_mod_factor * q). Must be 1 or 2. + void ComputeInverse(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, uint64_t output_mod_factor); + + /// @brief Returns the minimal 2N'th root of unity + uint64_t GetMinimalRootOfUnity() const { return m_w; } + + /// @brief Returns the degree N + uint64_t GetDegree() const { return m_degree; } + + /// @brief Returns the word-sized prime modulus + uint64_t GetModulus() const { return m_q; } + + /// @brief Returns the root of unity powers in bit-reversed order + const AlignedVector64& GetRootOfUnityPowers() const { + return m_root_of_unity_powers; + } + + /// @brief Returns the root of unity power at bit-reversed index i. 
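// ---------------------------------------------------------------------------
// Illustrative usage sketch (added for this write-up; not part of the patched
// header above). It shows a forward/inverse round trip with the NTT interface
// declared here. The degree, modulus, and input values are made-up examples
// chosen only to satisfy the documented constraints (N a power of two, prime
// q with q == 1 mod 2N). When the same (N, modulus) pair is reused, the
// experimental GetNTT(N, modulus) cache declared earlier in ntt-cache.hpp
// returns a shared instance instead of reconstructing one.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <vector>

#include "hexl/hexl.hpp"

inline void NttRoundTripSketch() {
  const uint64_t n = 8;          // transform size; must be a power of two
  const uint64_t modulus = 769;  // prime satisfying 769 % (2 * 8) == 1

  intel::hexl::NTT ntt(n, modulus);

  std::vector<uint64_t> data{1, 2, 3, 4, 5, 6, 7, 8};  // values in [0, modulus)

  // Forward transform (in place); input_mod_factor = 1, output_mod_factor = 1
  ntt.ComputeForward(data.data(), data.data(), 1, 1);

  // Inverse transform restores the original coefficients
  ntt.ComputeInverse(data.data(), data.data(), 1, 1);
}
// --------------------------- end of usage sketch ----------------------------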
+ uint64_t GetRootOfUnityPower(size_t i) { return GetRootOfUnityPowers()[i]; } + + /// @brief Returns 32-bit pre-conditioned root of unity powers in + /// bit-reversed order + const AlignedVector64& GetPrecon32RootOfUnityPowers() const { + return m_precon32_root_of_unity_powers; + } + + /// @brief Returns 64-bit pre-conditioned root of unity powers in + /// bit-reversed order + const AlignedVector64& GetPrecon64RootOfUnityPowers() const { + return m_precon64_root_of_unity_powers; + } + + /// @brief Returns the root of unity powers in bit-reversed order with + /// modifications for use by AVX512 implementation + const AlignedVector64& GetAVX512RootOfUnityPowers() const { + return m_avx512_root_of_unity_powers; + } + + /// @brief Returns 32-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon32RootOfUnityPowers() const { + return m_avx512_precon32_root_of_unity_powers; + } + + /// @brief Returns 52-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon52RootOfUnityPowers() const { + return m_avx512_precon52_root_of_unity_powers; + } + + /// @brief Returns 64-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon64RootOfUnityPowers() const { + return m_avx512_precon64_root_of_unity_powers; + } + + /// @brief Returns the inverse root of unity powers in bit-reversed order + const AlignedVector64& GetInvRootOfUnityPowers() const { + return m_inv_root_of_unity_powers; + } + + /// @brief Returns the inverse root of unity power at bit-reversed index i. + uint64_t GetInvRootOfUnityPower(size_t i) { + return GetInvRootOfUnityPowers()[i]; + } + + /// @brief Returns the vector of 32-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. + const AlignedVector64& GetPrecon32InvRootOfUnityPowers() const { + return m_precon32_inv_root_of_unity_powers; + } + + /// @brief Returns the vector of 52-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. + const AlignedVector64& GetPrecon52InvRootOfUnityPowers() const { + return m_precon52_inv_root_of_unity_powers; + } + + /// @brief Returns the vector of 64-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. 
+ const AlignedVector64& GetPrecon64InvRootOfUnityPowers() const { + return m_precon64_inv_root_of_unity_powers; + } + + /// @brief Maximum power of 2 in degree + static size_t MaxDegreeBits() { return 20; } + + /// @brief Maximum number of bits in modulus; + static size_t MaxModulusBits() { return 62; } + + /// @brief Default bit shift used in Barrett precomputation + static const size_t s_default_shift_bits{64}; + + /// @brief Bit shift used in Barrett precomputation when AVX512-IFMA + /// acceleration is enabled + static const size_t s_ifma_shift_bits{52}; + + /// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the + /// forward transform + static const size_t s_max_fwd_32_modulus{1ULL << (32 - 2)}; + + /// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the + /// inverse transform + static const size_t s_max_inv_32_modulus{1ULL << (32 - 2)}; + + /// @brief Maximum modulus to use AVX512-IFMA acceleration for the forward + /// transform + static const size_t s_max_fwd_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)}; + + /// @brief Maximum modulus to use AVX512-IFMA acceleration for the inverse + /// transform + static const size_t s_max_inv_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)}; + + /// @brief Maximum modulus to use AVX512-DQ acceleration for the inverse + /// transform + static const size_t s_max_inv_dq_modulus{1ULL << (s_default_shift_bits - 2)}; + + static size_t s_max_fwd_modulus(int bit_shift) { + if (bit_shift == 32) { + return s_max_fwd_32_modulus; + } else if (bit_shift == 52) { + return s_max_fwd_ifma_modulus; + } else if (bit_shift == 64) { + return 1ULL << MaxModulusBits(); + } + HEXL_CHECK(false, "Invalid bit_shift " << bit_shift); + return 0; + } + + static size_t s_max_inv_modulus(int bit_shift) { + if (bit_shift == 32) { + return s_max_inv_32_modulus; + } else if (bit_shift == 52) { + return s_max_inv_ifma_modulus; + } else if (bit_shift == 64) { + return 1ULL << MaxModulusBits(); + } + HEXL_CHECK(false, "Invalid bit_shift " << bit_shift); + return 0; + } + + private: + void ComputeRootOfUnityPowers(); + + uint64_t m_degree; // N: size of NTT transform, should be power of 2 + uint64_t m_q; // prime modulus. 
Must satisfy q == 1 mod 2n + + uint64_t m_degree_bits; // log_2(m_degree) + + uint64_t m_w_inv; // Inverse of minimal root of unity + uint64_t m_w; // A 2N'th root of unity + + std::shared_ptr m_alloc; + + AlignedAllocator m_aligned_alloc; + + // powers of the minimal root of unity + AlignedVector64 m_root_of_unity_powers; + // vector of floor(W * 2**32 / m_q), with W the root of unity powers + AlignedVector64 m_precon32_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the root of unity powers + AlignedVector64 m_precon64_root_of_unity_powers; + + // powers of the minimal root of unity adjusted for use in AVX512 + // implementations + AlignedVector64 m_avx512_root_of_unity_powers; + // vector of floor(W * 2**32 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon32_root_of_unity_powers; + // vector of floor(W * 2**52 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon52_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon64_root_of_unity_powers; + + // vector of floor(W * 2**32 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon32_inv_root_of_unity_powers; + // vector of floor(W * 2**52 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon52_inv_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon64_inv_root_of_unity_powers; + + AlignedVector64 m_inv_root_of_unity_powers; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/number-theory/number-theory.hpp b/hexl_v0/include/hexl/number-theory/number-theory.hpp new file mode 100644 index 00000000..da8d1d2a --- /dev/null +++ b/hexl_v0/include/hexl/number-theory/number-theory.hpp @@ -0,0 +1,342 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" + +namespace intel { +namespace hexl { + +/// @brief Pre-computes a Barrett factor with which modular multiplication can +/// be performed more efficiently +class MultiplyFactor { + public: + MultiplyFactor() = default; + + /// @brief Computes and stores the Barrett factor floor((operand << bit_shift) + /// / modulus). This is useful when modular multiplication of the form + /// (x * operand) mod modulus is performed with same modulus and operand + /// several times. Note, passing operand=1 can be used to pre-compute a + /// Barrett factor for multiplications of the form (x * y) mod modulus, where + /// only the modulus is re-used across calls to modular multiplication. + MultiplyFactor(uint64_t operand, uint64_t bit_shift, uint64_t modulus) + : m_operand(operand) { + HEXL_CHECK(operand <= modulus, "operand " << operand + << " must be less than modulus " + << modulus); + HEXL_CHECK(bit_shift == 32 || bit_shift == 52 || bit_shift == 64, + "Unsupported BitShift " << bit_shift); + uint64_t op_hi = operand >> (64 - bit_shift); + uint64_t op_lo = (bit_shift == 64) ? 
0 : (operand << bit_shift); + + m_barrett_factor = DivideUInt128UInt64Lo(op_hi, op_lo, modulus); + } + + /// @brief Returns the pre-computed Barrett factor + inline uint64_t BarrettFactor() const { return m_barrett_factor; } + + /// @brief Returns the operand corresponding to the Barrett factor + inline uint64_t Operand() const { return m_operand; } + + private: + uint64_t m_operand; + uint64_t m_barrett_factor; +}; + +/// @brief Returns whether or not num is a power of two +inline bool IsPowerOfTwo(uint64_t num) { return num && !(num & (num - 1)); } + +/// @brief Returns floor(log2(x)) +inline uint64_t Log2(uint64_t x) { return MSB(x); } + +inline bool IsPowerOfFour(uint64_t num) { + return IsPowerOfTwo(num) && (Log2(num) % 2 == 0); +} + +/// @brief Returns the maximum value that can be represented using \p bits bits +inline uint64_t MaximumValue(uint64_t bits) { + HEXL_CHECK(bits <= 64, "MaximumValue requires bits <= 64; got " << bits); + if (bits == 64) { + return (std::numeric_limits::max)(); + } + return (1ULL << bits) - 1; +} + +/// @brief Reverses the bits +/// @param[in] x Input to reverse +/// @param[in] bit_width Number of bits in the input; must be >= MSB(x) +/// @return The bit-reversed representation of \p x using \p bit_width bits +uint64_t ReverseBits(uint64_t x, uint64_t bit_width); + +/// @brief Returns x^{-1} mod modulus +/// @details Requires x % modulus != 0 +uint64_t InverseMod(uint64_t x, uint64_t modulus); + +/// @brief Returns (x * y) mod modulus +/// @details Assumes x, y < modulus +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns (x * y) mod modulus +/// @param[in] x +/// @param[in] y +/// @param[in] y_precon 64-bit precondition factor floor(2**64 / modulus) +/// @param[in] modulus +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon, + uint64_t modulus); + +/// @brief Returns (x + y) mod modulus +/// @details Assumes x, y < modulus +uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns (x - y) mod modulus +/// @details Assumes x, y < modulus +uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns base^exp mod modulus +uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus); + +/// @brief Returns whether or not root is a degree-th root of unity mod modulus +/// @param[in] root Root of unity to check +/// @param[in] degree Degree of root of unity; must be a power of two +/// @param[in] modulus Modulus of finite field +bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus); + +/// @brief Tries to return a primitive degree-th root of unity +/// @details Returns 0 or throws an error if no root is found +uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus); + +/// @brief Returns whether or not root is a degree-th root of unity +/// @param[in] degree Must be a power of two +/// @param[in] modulus Modulus of finite field +uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus); + +/// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 * +/// modulus] +/// @param[in] x +/// @param[in] y_operand also denoted y +/// @param[in] modulus +/// @param[in] y_barrett_factor Pre-computed Barrett reduction factor floor((y +/// << BitShift) / modulus) +template +inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y_operand, + uint64_t y_barrett_factor, uint64_t modulus) { + HEXL_CHECK(y_operand < modulus, "y_operand " << y_operand + << " must be less than modulus " + << modulus); + HEXL_CHECK( + modulus <= 
MaximumValue(BitShift), + "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift)); + HEXL_CHECK(x <= MaximumValue(BitShift), + "Operand " << x << " exceeds bound " << MaximumValue(BitShift)); + + uint64_t Q = MultiplyUInt64Hi(x, y_barrett_factor); + return y_operand * x - Q * modulus; +} + +/// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 * +/// modulus] +/// @param[in] x +/// @param[in] y +/// @param[in] modulus +template +inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y, uint64_t modulus) { + HEXL_CHECK(BitShift == 64 || BitShift == 52, + "Unsupported BitShift " << BitShift); + HEXL_CHECK(x <= MaximumValue(BitShift), + "Operand " << x << " exceeds bound " << MaximumValue(BitShift)); + HEXL_CHECK(y < modulus, + "y " << y << " must be less than modulus " << modulus); + HEXL_CHECK( + modulus <= MaximumValue(BitShift), + "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift)); + + uint64_t y_barrett = MultiplyFactor(y, BitShift, modulus).BarrettFactor(); + return MultiplyModLazy(x, y, y_barrett, modulus); +} + +/// @brief Adds two unsigned 64-bit integers +/// @param operand1 Number to add +/// @param operand2 Number to add +/// @param result Stores the sum +/// @return The carry bit +inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2, + uint64_t* result) { + *result = operand1 + operand2; + return static_cast(*result < operand1); +} + +/// @brief Returns whether or not the input is prime +bool IsPrime(uint64_t n); + +/// @brief Generates a list of num_primes primes in the range [2^(bit_size), +// 2^(bit_size+1)]. Ensures each prime q satisfies +// q % (2*ntt_size+1)) == 1 +/// @param[in] num_primes Number of primes to generate +/// @param[in] bit_size Bit size of each prime +/// @param[in] prefer_small_primes When true, returns primes starting from +/// 2^(bit_size); when false, returns primes starting from 2^(bit_size+1) +/// @param[in] ntt_size N such that each prime q satisfies q % (2N) == 1. N must +/// be a power of two less than 2^bit_size. +std::vector GeneratePrimes(size_t num_primes, size_t bit_size, + bool prefer_small_primes, + size_t ntt_size = 1); + +/// @brief Returns input mod modulus, computed via 64-bit Barrett reduction +/// @param[in] input +/// @param[in] modulus +/// @param[in] q_barr floor(2^64 / modulus) +template +uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr) { + HEXL_CHECK(modulus != 0, "modulus == 0"); + uint64_t q = MultiplyUInt64Hi<64>(input, q_barr); + uint64_t q_times_input = input - q * modulus; + if (OutputModFactor == 2) { + return q_times_input; + } else { + return (q_times_input >= modulus) ? 
q_times_input - modulus : q_times_input; + } +} + +/// @brief Returns x mod modulus, assuming x < InputModFactor * modulus +/// @param[in] x +/// @param[in] modulus also denoted q +/// @param[in] twice_modulus 2 * q; must not be nullptr if InputModFactor == 4 +/// or 8 +/// @param[in] four_times_modulus 4 * q; must not be nullptr if InputModFactor +/// == 8 +template +uint64_t ReduceMod(uint64_t x, uint64_t modulus, + const uint64_t* twice_modulus = nullptr, + const uint64_t* four_times_modulus = nullptr) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || + InputModFactor == 4 || InputModFactor == 8, + "InputModFactor should be 1, 2, 4, or 8"); + if (InputModFactor == 1) { + return x; + } + if (InputModFactor == 2) { + if (x >= modulus) { + x -= modulus; + } + return x; + } + if (InputModFactor == 4) { + HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr"); + if (x >= *twice_modulus) { + x -= *twice_modulus; + } + if (x >= modulus) { + x -= modulus; + } + return x; + } + if (InputModFactor == 8) { + HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr"); + HEXL_CHECK(four_times_modulus != nullptr, + "four_times_modulus should not be nullptr"); + + if (x >= *four_times_modulus) { + x -= *four_times_modulus; + } + if (x >= *twice_modulus) { + x -= *twice_modulus; + } + if (x >= modulus) { + x -= modulus; + } + return x; + } + HEXL_CHECK(false, "Should be unreachable"); + return x; +} + +/// @brief Returns Montgomery form of ab mod q, computed via the REDC algorithm, +/// also known as Montgomery reduction. +/// @param[in] r +/// @param[in] q with R = 2^r such that gcd(R, q) = 1. R > q. +/// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R. +/// @param[in] mod_R_msk take r last bits to apply mod R. +/// @param[in] T_hi of T = ab in the range [0, Rq − 1]. +/// @param[in] T_lo of T. +/// @return Unsigned long int in the range [0, q − 1] such that S ≡ TR^−1 mod q +template +inline uint64_t MontgomeryReduce(uint64_t T_hi, uint64_t T_lo, uint64_t q, + int r, uint64_t mod_R_msk, uint64_t inv_mod) { + HEXL_CHECK(BitShift == 64 || BitShift == 52, + "Unsupported BitShift " << BitShift); + HEXL_CHECK((1ULL << r) > static_cast(q), + "R value should be greater than q = " << static_cast(q)); + + uint64_t mq_hi; + uint64_t mq_lo; + + uint64_t m = ((T_lo & mod_R_msk) * inv_mod) & mod_R_msk; + MultiplyUInt64(m, q, &mq_hi, &mq_lo); + + if (BitShift == 52) { + mq_hi = (mq_hi << 12) | (mq_lo >> 52); + mq_lo &= (1ULL << 52) - 1; + } + + uint64_t t_hi; + uint64_t t_lo; + + // first 64bit block + t_lo = T_lo + mq_lo; + unsigned int carry = static_cast(t_lo < T_lo); + t_hi = T_hi + mq_hi + carry; + + t_hi = t_hi << (BitShift - r); + t_lo = t_lo >> r; + t_lo = t_hi + t_lo; + + return (t_lo >= q) ? 
(t_lo - q) : t_lo; +} + +/// @brief Hensel's Lemma for 2-adic numbers +/// Find solution for qX + 1 = 0 mod 2^r +/// @param[in] r +/// @param[in] q such that gcd(2, q) = 1 +/// @return Unsigned long int in [0, 2^r − 1] such that q*x ≡ −1 mod 2^r +inline uint64_t HenselLemma2adicRoot(uint32_t r, uint64_t q) { + uint64_t a_prev = 1; + uint64_t c = 2; + uint64_t mod_mask = 3; + + // Root: + // f(x) = qX + 1 and a_(0) = 1 then f(1) ≡ 0 mod 2 + // General Case: + // - a_(n) ≡ a_(n-1) mod 2^(n) + // => a_(n) = a_(n-1) + 2^(n)*t + // - Find 't' such that f(a_(n)) = 0 mod 2^(n+1) + // First case in for: + // - a_(1) ≡ 1 mod 2 or a_(1) = 1 + 2t + // - Find 't' so f(a_(1)) ≡ 0 mod 4 => q(1 + 2t) + 1 ≡ 0 mod 4 + for (uint64_t k = 2; k <= r; k++) { + uint64_t f = 0; + uint64_t t = 0; + uint64_t a = 0; + + do { + a = a_prev + c * t++; + f = q * a + 1ULL; + } while (f & mod_mask); // f(a) ≡ 0 mod 2^(k) + + // Update vars + mod_mask = mod_mask * 2 + 1ULL; + c *= 2; + a_prev = a; + } + + return a_prev; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/util/aligned-allocator.hpp b/hexl_v0/include/hexl/util/aligned-allocator.hpp new file mode 100644 index 00000000..d175c734 --- /dev/null +++ b/hexl_v0/include/hexl/util/aligned-allocator.hpp @@ -0,0 +1,110 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/allocator.hpp" +#include "hexl/util/defines.hpp" + +namespace intel { +namespace hexl { + +/// @brief Allocater implementation using malloc and free +struct MallocStrategy : AllocatorBase { + void* allocate(size_t bytes_count) final { return std::malloc(bytes_count); } + + void deallocate(void* p, size_t n) final { + HEXL_UNUSED(n); + std::free(p); + } +}; + +using AllocatorStrategyPtr = std::shared_ptr; +extern AllocatorStrategyPtr mallocStrategy; + +/// @brief Allocates memory aligned to Alignment-byte sized boundaries +/// @details Alignment must be a power of two +template +class AlignedAllocator { + public: + template + friend class AlignedAllocator; + + using value_type = T; + + explicit AlignedAllocator(AllocatorStrategyPtr strategy = nullptr) noexcept + : m_alloc_impl((strategy != nullptr) ? 
strategy : mallocStrategy) {} + + AlignedAllocator(const AlignedAllocator& src) = default; + AlignedAllocator& operator=(const AlignedAllocator& src) = default; + + template + AlignedAllocator(const AlignedAllocator& src) + : m_alloc_impl(src.m_alloc_impl) {} + + ~AlignedAllocator() {} + + template + struct rebind { + using other = AlignedAllocator; + }; + + bool operator==(const AlignedAllocator&) { return true; } + + bool operator!=(const AlignedAllocator&) { return false; } + + /// @brief Allocates \p n elements aligned to Alignment-byte boundaries + /// @return Pointer to the aligned allocated memory + T* allocate(size_t n) { + if (!IsPowerOfTwo(Alignment)) { + return nullptr; + } + // Allocate enough space to ensure the alignment can be satisfied + size_t buffer_size = sizeof(T) * n + Alignment; + // Additionally, allocate a prefix to store the memory location of the + // unaligned buffer + size_t alloc_size = buffer_size + sizeof(void*); + void* buffer = m_alloc_impl->allocate(alloc_size); + if (!buffer) { + return nullptr; + } + + // Reserve first location for pointer to originally-allocated space + void* aligned_buffer = static_cast(buffer) + sizeof(void*); + std::align(Alignment, sizeof(T) * n, aligned_buffer, buffer_size); + if (!aligned_buffer) { + return nullptr; + } + + // Store allocated buffer address at aligned_buffer - sizeof(void*). + void* store_buffer_addr = + static_cast(aligned_buffer) - sizeof(void*); + *(static_cast(store_buffer_addr)) = buffer; + + return static_cast(aligned_buffer); + } + + void deallocate(T* p, size_t n) { + if (!p) { + return; + } + void* store_buffer_addr = (reinterpret_cast(p) - sizeof(void*)); + void* free_address = *(static_cast(store_buffer_addr)); + m_alloc_impl->deallocate(free_address, n); + } + + private: + AllocatorStrategyPtr m_alloc_impl; +}; + +/// @brief 64-byte aligned memory allocator +template +using AlignedVector64 = std::vector >; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/util/allocator.hpp b/hexl_v0/include/hexl/util/allocator.hpp new file mode 100644 index 00000000..5f4a7a31 --- /dev/null +++ b/hexl_v0/include/hexl/util/allocator.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Base class for custom memory allocator +struct AllocatorBase { + virtual ~AllocatorBase() noexcept {} + + /// @brief Allocates byte_count bytes of memory + /// @param[in] bytes_count Number of bytes to allocate + /// @return A pointer to the allocated memory + virtual void* allocate(size_t bytes_count) = 0; + + /// @brief Deallocate memory + /// @param[in] p Pointer to memory to deallocate + /// @param[in] n Number of bytes to deallocate + virtual void deallocate(void* p, size_t n) = 0; +}; + +/// @brief Helper memory allocation struct which delegates implementation to +/// AllocatorImpl +template +struct AllocatorInterface : public AllocatorBase { + /// @brief Override interface and delegate implementation to AllocatorImpl + void* allocate(size_t bytes_count) override { + return static_cast(this)->allocate_impl(bytes_count); + } + + /// @brief Override interface and delegate implementation to AllocatorImpl + void deallocate(void* p, size_t n) override { + static_cast(this)->deallocate_impl(p, n); + } + + private: + // in case AllocatorImpl doesn't provide implementations, use default null + // behavior + void* allocate_impl(size_t bytes_count) { + HEXL_UNUSED(bytes_count); + return 
nullptr; + } + void deallocate_impl(void* p, size_t n) { + HEXL_UNUSED(p); + HEXL_UNUSED(n); + } +}; +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/util/check.hpp b/hexl_v0/include/hexl/util/check.hpp new file mode 100644 index 00000000..386eba89 --- /dev/null +++ b/hexl_v0/include/hexl/util/check.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/types.hpp" + +// Create logging/debug macros with no run-time overhead unless HEXL_DEBUG is +// enabled +#ifdef HEXL_DEBUG +#include "hexl/logging/logging.hpp" + +/// @brief If input condition is not true, logs the expression and throws an +/// error +/// @param[in] cond A boolean indication the condition +/// @param[in] expr The expression to be logged +#define HEXL_CHECK(cond, expr) \ + if (!(cond)) { \ + LOG(ERROR) << expr << " in function: " << __FUNCTION__ \ + << " in file: " __FILE__ << ":" << __LINE__; \ + throw std::runtime_error("Error. Check log output"); \ + } + +/// @brief If input has an element >= bound, logs the expression and throws an +/// error +/// @param[in] arg Input container which supports the [] operator. +/// @param[in] n Size of input +/// @param[in] bound Upper bound on the input +/// @param[in] expr The expression to be logged +#define HEXL_CHECK_BOUNDS(arg, n, bound, expr) \ + for (size_t hexl_check_idx = 0; hexl_check_idx < n; ++hexl_check_idx) { \ + HEXL_CHECK((arg)[hexl_check_idx] < bound, expr); \ + } + +#else // HEXL_DEBUG=OFF + +#define HEXL_CHECK(cond, expr) \ + {} +#define HEXL_CHECK_BOUNDS(...) \ + {} + +#endif // HEXL_DEBUG diff --git a/hexl_v0/include/hexl/util/clang.hpp b/hexl_v0/include/hexl/util/clang.hpp new file mode 100644 index 00000000..958bea7b --- /dev/null +++ b/hexl_v0/include/hexl/util/clang.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/types.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_USE_CLANG +// Return x * y as 128-bit integer +// Correctness if x * y < 128 bits +inline uint128_t MultiplyUInt64(uint64_t x, uint64_t y) { + return uint128_t(x) * uint128_t(y); +} + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint128_t n = (static_cast(input_hi) << 64) | + (static_cast(input_lo)); + + return n % modulus; + // TODO(fboemer): actually use barrett reduction if performance-critical +} + +// Returns low 64bit of 128b/64b where x1=high 64b, x0=low 64b +inline uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y) { + uint128_t n = + (static_cast(x1) << 64) | (static_cast(x0)); + uint128_t q = n / y; + + return static_cast(q); +} + +// Multiplies x * y as 128-bit integer. 
+// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + uint128_t prod = MultiplyUInt64(x, y); + *prod_hi = static_cast(prod >> 64); + *prod_lo = static_cast(prod); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + uint128_t product = static_cast(x) * y; + return static_cast(product >> BitShift); +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + return static_cast(std::log2l(input)); +} + +#define HEXL_LOOP_UNROLL_4 _Pragma("clang loop unroll_count(4)") +#define HEXL_LOOP_UNROLL_8 _Pragma("clang loop unroll_count(8)") + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/util/compiler.hpp b/hexl_v0/include/hexl/util/compiler.hpp new file mode 100644 index 00000000..7dd077df --- /dev/null +++ b/hexl_v0/include/hexl/util/compiler.hpp @@ -0,0 +1,14 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/util/defines.hpp" + +#ifdef HEXL_USE_MSVC +#include "hexl/util/msvc.hpp" +#elif defined HEXL_USE_GNU +#include "hexl/util/gcc.hpp" +#elif defined HEXL_USE_CLANG +#include "hexl/util/clang.hpp" +#endif diff --git a/hexl_v0/include/hexl/util/defines.hpp b/hexl_v0/include/hexl/util/defines.hpp new file mode 100644 index 00000000..93db376e --- /dev/null +++ b/hexl_v0/include/hexl/util/defines.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +/* #undef HEXL_USE_MSVC */ +#define HEXL_USE_GNU +/* #undef HEXL_USE_CLANG */ + +/* #undef HEXL_DEBUG */ + +// Avoid unused variable warnings +#define HEXL_UNUSED(x) (void)(x) diff --git a/hexl_v0/include/hexl/util/gcc.hpp b/hexl_v0/include/hexl/util/gcc.hpp new file mode 100644 index 00000000..828e3836 --- /dev/null +++ b/hexl_v0/include/hexl/util/gcc.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/types.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_USE_GNU +// Return x * y as 128-bit integer +// Correctness if x * y < 128 bits +inline uint128_t MultiplyUInt64(uint64_t x, uint64_t y) { + return uint128_t(x) * uint128_t(y); +} + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint128_t n = (static_cast(input_hi) << 64) | + (static_cast(input_lo)); + + return static_cast(n % modulus); + // TODO(fboemer): actually use barrett reduction if performance-critical +} + +// Returns low 64bit of 128b/64b where x1=high 64b, x0=low 64b +inline uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y) { + uint128_t n = + (static_cast(x1) << 64) | (static_cast(x0)); + uint128_t q = n / y; + + return static_cast(q); +} + +// Multiplies x * y as 128-bit integer. 
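// ---------------------------------------------------------------------------
// Illustrative usage sketch (added for this write-up; not part of the patched
// header). Assuming the GNU-path helpers declared in this file, (x * y) mod q
// can be computed by forming the full 128-bit product and reducing it. The
// operands and modulus below are arbitrary example values.
// ---------------------------------------------------------------------------
#include <cstdint>

#include "hexl/util/compiler.hpp"

inline uint64_t MulModViaBarrett128Sketch() {
  const uint64_t x = 0xFFFFFFFF12345678ULL;
  const uint64_t y = 0xABCDEF0123456789ULL;
  const uint64_t q = (1ULL << 61) - 1;  // example modulus; only q != 0 required

  uint64_t prod_hi = 0;
  uint64_t prod_lo = 0;
  // Full 128-bit product x * y, split into high and low 64-bit halves
  intel::hexl::MultiplyUInt64(x, y, &prod_hi, &prod_lo);

  // Reduce the 128-bit value modulo q
  return intel::hexl::BarrettReduce128(prod_hi, prod_lo, q);
}
// --------------------------- end of usage sketch ----------------------------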
+// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + uint128_t prod = MultiplyUInt64(x, y); + *prod_hi = static_cast(prod >> 64); + *prod_lo = static_cast(prod); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + uint128_t product = MultiplyUInt64(x, y); + return static_cast(product >> BitShift); +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + return static_cast(std::log2l(input)); +} + +#define HEXL_LOOP_UNROLL_4 _Pragma("GCC unroll 4") +#define HEXL_LOOP_UNROLL_8 _Pragma("GCC unroll 8") + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/util/msvc.hpp b/hexl_v0/include/hexl/util/msvc.hpp new file mode 100644 index 00000000..0ada2d45 --- /dev/null +++ b/hexl_v0/include/hexl/util/msvc.hpp @@ -0,0 +1,289 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#ifdef HEXL_USE_MSVC + +#define NOMINMAX // Avoid errors with std::min/std::max +#undef min +#undef max + +#include +#include +#include + +#include + +#include "hexl/util/check.hpp" + +#pragma intrinsic(_addcarry_u64, _BitScanReverse64, _subborrow_u64, _udiv128, \ + _umul128) + +#undef TRUE +#undef FALSE + +namespace intel { +namespace hexl { + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint64_t remainder; + _udiv128(input_hi, input_lo, modulus, &remainder); + + return remainder; +} + +// Multiplies x * y as 128-bit integer. +// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + *prod_lo = _umul128(x, y, prod_hi); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid BitShift " << BitShift << "; expected 52 or 64"); + uint64_t prod_hi; + uint64_t prod_lo = _umul128(x, y, &prod_hi); + uint64_t result_hi; + uint64_t result_lo; + RightShift128(&result_hi, &result_lo, prod_hi, prod_lo, BitShift); + return result_lo; +} + +/// @brief Computes Left Shift op as 128-bit unsigned integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +/// @param[in] op_hi Stores high 64 bits of input +/// @param[in] op_lo Stores low 64 bits of input +inline void LeftShift128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op_hi, const uint64_t op_lo, + const uint64_t shift_value) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + HEXL_CHECK(shift_value <= 128, + "shift_value cannot be greater than 128 " << shift_value); + + if (shift_value == 0) { + *result_hi = op_hi; + *result_lo = op_lo; + } else if (shift_value == 64) { + *result_hi = op_lo; + *result_lo = 0ULL; + } else if (shift_value == 128) { + *result_hi = 0ULL; + *result_lo = 0ULL; + } else if (shift_value >= 1 && shift_value <= 63) { + *result_hi = (op_hi << shift_value) | (op_lo >> (64 - shift_value)); + *result_lo = op_lo << shift_value; + } else if (shift_value >= 65 && shift_value < 128) { + 
*result_hi = op_lo << (shift_value - 64); + *result_lo = 0ULL; + } +} + +/// @brief Computes Right Shift op as 128-bit unsigned integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +/// @param[in] op_hi Stores high 64 bits of input +/// @param[in] op_lo Stores low 64 bits of input +inline void RightShift128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op_hi, const uint64_t op_lo, + const uint64_t shift_value) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + HEXL_CHECK(shift_value <= 128, + "shift_value cannot be greater than 128 " << shift_value); + + if (shift_value == 0) { + *result_hi = op_hi; + *result_lo = op_lo; + } else if (shift_value == 64) { + *result_hi = 0ULL; + *result_lo = op_hi; + } else if (shift_value == 128) { + *result_hi = 0ULL; + *result_lo = 0ULL; + } else if (shift_value >= 1 && shift_value <= 63) { + *result_hi = op_hi >> shift_value; + *result_lo = (op_hi << (64 - shift_value)) | (op_lo >> shift_value); + } else if (shift_value >= 65 && shift_value < 128) { + *result_hi = 0ULL; + *result_lo = op_hi >> (shift_value - 64); + } +} + +/// @brief Adds op1 + op2 as 128-bit integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +inline void AddWithCarry128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op1_hi, const uint64_t op1_lo, + const uint64_t op2_hi, const uint64_t op2_lo) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + + // first 64bit block + *result_lo = op1_lo + op2_lo; + unsigned char carry = static_cast(*result_lo < op1_lo); + + // second 64bit block + _addcarry_u64(carry, op1_hi, op2_hi, result_hi); +} + +/// @brief Subtracts op1 - op2 as 128-bit integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +inline void SubWithCarry128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op1_hi, const uint64_t op1_lo, + const uint64_t op2_hi, const uint64_t op2_lo) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + + unsigned char borrow; + + // first 64bit block + *result_lo = op1_lo - op2_lo; + borrow = static_cast(op2_lo > op1_lo); + + // second 64bit block + _subborrow_u64(borrow, op1_hi, op2_hi, result_hi); +} + +/// @brief Computes and returns significant bit count +/// @param[in] value Input element at most 128 bits long +inline uint64_t SignificantBitLength(const uint64_t* value) { + HEXL_CHECK(value != nullptr, "Require value != nullptr"); + + unsigned long count = 0; // NOLINT(runtime/int) + + // second 64bit block + _BitScanReverse64(&count, *(value + 1)); + if (count >= 0 && *(value + 1) > 0) { + return static_cast(count) + 1 + 64; + } + + // first 64bit block + _BitScanReverse64(&count, *value); + if (count >= 0 && *(value) > 0) { + return static_cast(count) + 1; + } + return 0; +} + +/// @brief Checks if input is negative number +/// @param[in] input Input element to check for sign +inline bool CheckSign(const uint64_t* input) { + HEXL_CHECK(input != nullptr, "Require input != nullptr"); + + uint64_t input_temp[2]{0, 0}; + RightShift128(&input_temp[1], &input_temp[0], input[1], input[0], 127); + return (input_temp[0] == 1); +} + +/// @brief Divides 
numerator by denominator +/// @param[out] quotient Stores quotient as two 64-bit blocks after division +/// @param[in] numerator +/// @param[in] denominator +inline void DivideUInt128UInt64(uint64_t* quotient, const uint64_t* numerator, + const uint64_t denominator) { + HEXL_CHECK(quotient != nullptr, "Require quotient != nullptr"); + HEXL_CHECK(numerator != nullptr, "Require numerator != nullptr"); + HEXL_CHECK(denominator != 0, "denominator cannot be 0 " << denominator); + + // get bit count of divisor + uint64_t numerator_bits = SignificantBitLength(numerator); + const uint64_t numerator_bits_const = numerator_bits; + const uint64_t uint_128_bit = 128ULL; + + uint64_t MASK[2]{0x0000000000000001, 0x0000000000000000}; + uint64_t remainder[2]{0, 0}; + uint64_t quotient_temp[2]{0, 0}; + uint64_t denominator_temp[2]{denominator, 0}; + + quotient[0] = numerator[0]; + quotient[1] = numerator[1]; + + // align numerator + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], + (uint_128_bit - numerator_bits_const)); + + while (numerator_bits) { + // if remainder is negative + if (CheckSign(remainder)) { + LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1); + RightShift128("ient_temp[1], "ient_temp[0], quotient[1], + quotient[0], (uint_128_bit - 1)); + remainder[0] = remainder[0] | quotient_temp[0]; + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1); + // remainder=remainder+denominator_temp + AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } else { // if remainder is positive + LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1); + RightShift128("ient_temp[1], "ient_temp[0], quotient[1], + quotient[0], (uint_128_bit - 1)); + remainder[0] = remainder[0] | quotient_temp[0]; + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1); + // remainder=remainder-denominator_temp + SubWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } + + // if remainder is positive set MSB of quotient[0]=1 + if (!CheckSign(remainder)) { + MASK[0] = 0x0000000000000001; + MASK[1] = 0x0000000000000000; + LeftShift128(&MASK[1], &MASK[0], MASK[1], MASK[0], + (uint_128_bit - numerator_bits_const)); + quotient[0] = quotient[0] | MASK[0]; + quotient[1] = quotient[1] | MASK[1]; + } + quotient_temp[0] = 0; + quotient_temp[1] = 0; + numerator_bits--; + } + + if (CheckSign(remainder)) { + // remainder=remainder+denominator_temp + AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } + RightShift128("ient[1], "ient[0], quotient[1], quotient[0], + (uint_128_bit - numerator_bits_const)); +} + +/// @brief Returns low of dividing numerator by denominator +/// @param[in] numerator_hi Stores high 64 bit of numerator +/// @param[in] numerator_lo Stores low 64 bit of numerator +/// @param[in] denominator Stores denominator +inline uint64_t DivideUInt128UInt64Lo(const uint64_t numerator_hi, + const uint64_t numerator_lo, + const uint64_t denominator) { + uint64_t numerator[2]{numerator_lo, numerator_hi}; + uint64_t quotient[2]{0, 0}; + + DivideUInt128UInt64(quotient, numerator, denominator); + return quotient[0]; +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + unsigned long index{0}; // NOLINT(runtime/int) + _BitScanReverse64(&index, input); + return index; +} + +#define HEXL_LOOP_UNROLL_4 \ + {} +#define HEXL_LOOP_UNROLL_8 
\ + {} + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/include/hexl/util/types.hpp b/hexl_v0/include/hexl/util/types.hpp new file mode 100644 index 00000000..2d2d8551 --- /dev/null +++ b/hexl_v0/include/hexl/util/types.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/defines.hpp" + +#if defined(HEXL_USE_GNU) || defined(HEXL_USE_CLANG) +__extension__ typedef __int128 int128_t; +__extension__ typedef unsigned __int128 uint128_t; +#endif diff --git a/hexl_v0/include/hexl/util/util.hpp b/hexl_v0/include/hexl/util/util.hpp new file mode 100644 index 00000000..bf878a98 --- /dev/null +++ b/hexl_v0/include/hexl/util/util.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +#undef TRUE // MSVC defines TRUE +#undef FALSE // MSVC defines FALSE + +/// @enum CMPINT +/// @brief Represents binary operations between two boolean values +enum class CMPINT { + EQ = 0, ///< Equal + LT = 1, ///< Less than + LE = 2, ///< Less than or equal + FALSE = 3, ///< False + NE = 4, ///< Not equal + NLT = 5, ///< Not less than + NLE = 6, ///< Not less than or equal + TRUE = 7 ///< True +}; + +/// @brief Returns the logical negation of a binary operation +/// @param[in] cmp The binary operation to negate +inline CMPINT Not(CMPINT cmp) { + switch (cmp) { + case CMPINT::EQ: + return CMPINT::NE; + case CMPINT::LT: + return CMPINT::NLT; + case CMPINT::LE: + return CMPINT::NLE; + case CMPINT::FALSE: + return CMPINT::TRUE; + case CMPINT::NE: + return CMPINT::EQ; + case CMPINT::NLT: + return CMPINT::LT; + case CMPINT::NLE: + return CMPINT::LE; + case CMPINT::TRUE: + return CMPINT::FALSE; + default: + return CMPINT::FALSE; + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_v0/lib/cmake/hexl-1.2.5/HEXLConfig.cmake b/hexl_v0/lib/cmake/hexl-1.2.5/HEXLConfig.cmake new file mode 100644 index 00000000..d3c012b5 --- /dev/null +++ b/hexl_v0/lib/cmake/hexl-1.2.5/HEXLConfig.cmake @@ -0,0 +1,59 @@ +# Copyright (C) 2020 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# This will define the following variables: +# +# HEXL_FOUND - True if the system has the Intel HEXL library +# HEXL_VERSION - The full major.minor.patch version number +# HEXL_VERSION_MAJOR - The major version number +# HEXL_VERSION_MINOR - The minor version number +# HEXL_VERSION_PATCH - The patch version number + + +####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() ####### +####### Any changes to this file will be overwritten by the next CMake run #### +####### The input file was HEXLConfig.cmake.in ######## + +get_filename_component(PACKAGE_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) + +macro(set_and_check _var _file) + set(${_var} "${_file}") + if(NOT EXISTS "${_file}") + message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !") + endif() +endmacro() + +macro(check_required_components _NAME) + foreach(comp ${${_NAME}_FIND_COMPONENTS}) + if(NOT ${_NAME}_${comp}_FOUND) + if(${_NAME}_FIND_REQUIRED_${comp}) + set(${_NAME}_FOUND FALSE) + endif() + endif() + endforeach() +endmacro() + +#################################################################################### + +include(CMakeFindDependencyMacro) +find_package(CpuFeatures CONFIG) +if(NOT CpuFeatures_FOUND) + message(WARNING "Could not find pre-installed CpuFeatures; 
using CpuFeatures packaged with HEXL") +endif() + +include(${CMAKE_CURRENT_LIST_DIR}/HEXLTargets.cmake) + +# Defines HEXL_FOUND: If Intel HEXL library was found +if(TARGET HEXL::hexl) + set(HEXL_FOUND TRUE) + message(STATUS "Intel HEXL found") +else() + message(STATUS "Intel HEXL not found") +endif() + +set(HEXL_VERSION "1.2.5") +set(HEXL_VERSION_MAJOR "1") +set(HEXL_VERSION_MINOR "2") +set(HEXL_VERSION_PATCH "5") + +set(HEXL_DEBUG "OFF") diff --git a/hexl_v0/lib/cmake/hexl-1.2.5/HEXLConfigVersion.cmake b/hexl_v0/lib/cmake/hexl-1.2.5/HEXLConfigVersion.cmake new file mode 100644 index 00000000..98b46110 --- /dev/null +++ b/hexl_v0/lib/cmake/hexl-1.2.5/HEXLConfigVersion.cmake @@ -0,0 +1,88 @@ +# This is a basic version file for the Config-mode of find_package(). +# It is used by write_basic_package_version_file() as input file for configure_file() +# to create a version-file which can be installed along a config.cmake file. +# +# The created file sets PACKAGE_VERSION_EXACT if the current version string and +# the requested version string are exactly the same and it sets +# PACKAGE_VERSION_COMPATIBLE if the current version is equal to the requested version. +# The tweak version component is ignored. +# The variable CVF_VERSION must be set before calling configure_file(). + + +if (PACKAGE_FIND_VERSION_RANGE) + message(AUTHOR_WARNING + "`find_package()` specify a version range but the version strategy " + "(ExactVersion) of the module `${PACKAGE_FIND_NAME}` is incompatible " + "with this request. Only the lower endpoint of the range will be used.") +endif() + +set(PACKAGE_VERSION "1.2.5") + +if("1.2.5" MATCHES "^([0-9]+)\\.([0-9]+)\\.([0-9]+)") # strip the tweak version + set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(CVF_VERSION_MINOR "${CMAKE_MATCH_2}") + set(CVF_VERSION_PATCH "${CMAKE_MATCH_3}") + + if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}") + endif() + if(NOT CVF_VERSION_MINOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" CVF_VERSION_MINOR "${CVF_VERSION_MINOR}") + endif() + if(NOT CVF_VERSION_PATCH VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" CVF_VERSION_PATCH "${CVF_VERSION_PATCH}") + endif() + + set(CVF_VERSION_NO_TWEAK "${CVF_VERSION_MAJOR}.${CVF_VERSION_MINOR}.${CVF_VERSION_PATCH}") +else() + set(CVF_VERSION_NO_TWEAK "1.2.5") +endif() + +if(PACKAGE_FIND_VERSION MATCHES "^([0-9]+)\\.([0-9]+)\\.([0-9]+)") # strip the tweak version + set(REQUESTED_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(REQUESTED_VERSION_MINOR "${CMAKE_MATCH_2}") + set(REQUESTED_VERSION_PATCH "${CMAKE_MATCH_3}") + + if(NOT REQUESTED_VERSION_MAJOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" REQUESTED_VERSION_MAJOR "${REQUESTED_VERSION_MAJOR}") + endif() + if(NOT REQUESTED_VERSION_MINOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" REQUESTED_VERSION_MINOR "${REQUESTED_VERSION_MINOR}") + endif() + if(NOT REQUESTED_VERSION_PATCH VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" REQUESTED_VERSION_PATCH "${REQUESTED_VERSION_PATCH}") + endif() + + set(REQUESTED_VERSION_NO_TWEAK + "${REQUESTED_VERSION_MAJOR}.${REQUESTED_VERSION_MINOR}.${REQUESTED_VERSION_PATCH}") +else() + set(REQUESTED_VERSION_NO_TWEAK "${PACKAGE_FIND_VERSION}") +endif() + +if(REQUESTED_VERSION_NO_TWEAK STREQUAL CVF_VERSION_NO_TWEAK) + set(PACKAGE_VERSION_COMPATIBLE TRUE) +else() + set(PACKAGE_VERSION_COMPATIBLE FALSE) +endif() + +if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION) + set(PACKAGE_VERSION_EXACT TRUE) +endif() + + +# if the installed project requested no 
architecture check, don't perform the check +if("FALSE") + return() +endif() + +# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it: +if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "") + return() +endif() + +# check that the installed version has the same 32/64bit-ness as the one which is currently searching: +if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8") + math(EXPR installedBits "8 * 8") + set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)") + set(PACKAGE_VERSION_UNSUITABLE TRUE) +endif() diff --git a/hexl_v0/lib/cmake/hexl-1.2.5/HEXLTargets-release.cmake b/hexl_v0/lib/cmake/hexl-1.2.5/HEXLTargets-release.cmake new file mode 100644 index 00000000..c8aefe49 --- /dev/null +++ b/hexl_v0/lib/cmake/hexl-1.2.5/HEXLTargets-release.cmake @@ -0,0 +1,19 @@ +#---------------------------------------------------------------- +# Generated CMake target import file for configuration "Release". +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Import target "HEXL::hexl" for configuration "Release" +set_property(TARGET HEXL::hexl APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE) +set_target_properties(HEXL::hexl PROPERTIES + IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "CXX" + IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libhexl.a" + ) + +list(APPEND _IMPORT_CHECK_TARGETS HEXL::hexl ) +list(APPEND _IMPORT_CHECK_FILES_FOR_HEXL::hexl "${_IMPORT_PREFIX}/lib/libhexl.a" ) + +# Commands beyond this point should not need to know the version. +set(CMAKE_IMPORT_FILE_VERSION) diff --git a/hexl_v0/lib/cmake/hexl-1.2.5/HEXLTargets.cmake b/hexl_v0/lib/cmake/hexl-1.2.5/HEXLTargets.cmake new file mode 100644 index 00000000..de9a8bd8 --- /dev/null +++ b/hexl_v0/lib/cmake/hexl-1.2.5/HEXLTargets.cmake @@ -0,0 +1,95 @@ +# Generated by CMake + +if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.6) + message(FATAL_ERROR "CMake >= 2.6.0 required") +endif() +cmake_policy(PUSH) +cmake_policy(VERSION 2.6...3.20) +#---------------------------------------------------------------- +# Generated CMake target import file. +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Protect against multiple inclusion, which would fail when already imported targets are added once more. +set(_targetsDefined) +set(_targetsNotDefined) +set(_expectedTargets) +foreach(_expectedTarget HEXL::hexl) + list(APPEND _expectedTargets ${_expectedTarget}) + if(NOT TARGET ${_expectedTarget}) + list(APPEND _targetsNotDefined ${_expectedTarget}) + endif() + if(TARGET ${_expectedTarget}) + list(APPEND _targetsDefined ${_expectedTarget}) + endif() +endforeach() +if("${_targetsDefined}" STREQUAL "${_expectedTargets}") + unset(_targetsDefined) + unset(_targetsNotDefined) + unset(_expectedTargets) + set(CMAKE_IMPORT_FILE_VERSION) + cmake_policy(POP) + return() +endif() +if(NOT "${_targetsDefined}" STREQUAL "") + message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_targetsDefined}\nTargets not yet defined: ${_targetsNotDefined}\n") +endif() +unset(_targetsDefined) +unset(_targetsNotDefined) +unset(_expectedTargets) + + +# Compute the installation prefix relative to this file. 
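# ------------------------------------------------------------------------------
# Illustrative sketch (added for this write-up; not part of the generated
# HEXLTargets.cmake above): a minimal downstream CMakeLists.txt that consumes
# the installed package through these config files. The project name, source
# file, and install prefix are placeholder examples.
# ------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.13)
project(hexl_consumer LANGUAGES CXX)

# Help find_package() locate HEXLConfig.cmake under the chosen install prefix,
# e.g. cmake -S . -B build -DCMAKE_PREFIX_PATH=/path/to/install/hexl_v0
find_package(HEXL 1.2.5 REQUIRED)

add_executable(consumer main.cpp)
target_link_libraries(consumer PRIVATE HEXL::hexl)
# --------------------------------- end of sketch -----------------------------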
+get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +if(_IMPORT_PREFIX STREQUAL "/") + set(_IMPORT_PREFIX "") +endif() + +# Create imported target HEXL::hexl +add_library(HEXL::hexl STATIC IMPORTED) + +set_target_properties(HEXL::hexl PROPERTIES + INTERFACE_COMPILE_OPTIONS "-Wno-unknown-warning;-Wno-unknown-warning-option" + INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include" +) + +# Load information for each installed configuration. +get_filename_component(_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) +file(GLOB CONFIG_FILES "${_DIR}/HEXLTargets-*.cmake") +foreach(f ${CONFIG_FILES}) + include(${f}) +endforeach() + +# Cleanup temporary variables. +set(_IMPORT_PREFIX) + +# Loop over all imported files and verify that they actually exist +foreach(target ${_IMPORT_CHECK_TARGETS} ) + foreach(file ${_IMPORT_CHECK_FILES_FOR_${target}} ) + if(NOT EXISTS "${file}" ) + message(FATAL_ERROR "The imported target \"${target}\" references the file + \"${file}\" +but this file does not exist. Possible reasons include: +* The file was deleted, renamed, or moved to another location. +* An install or uninstall procedure did not complete successfully. +* The installation package was faulty and contained + \"${CMAKE_CURRENT_LIST_FILE}\" +but not all the files it references. +") + endif() + endforeach() + unset(_IMPORT_CHECK_FILES_FOR_${target}) +endforeach() +unset(_IMPORT_CHECK_TARGETS) + +# This file does not depend on other imported targets which have +# been exported from the same project but in a separate export set. + +# Commands beyond this point should not need to know the version. +set(CMAKE_IMPORT_FILE_VERSION) +cmake_policy(POP) diff --git a/hexl_v0/lib/libhexl.a b/hexl_v0/lib/libhexl.a new file mode 100644 index 00000000..a33f2d37 Binary files /dev/null and b/hexl_v0/lib/libhexl.a differ diff --git a/hexl_v0/lib/pkgconfig/hexl.pc b/hexl_v0/lib/pkgconfig/hexl.pc new file mode 100644 index 00000000..5b481213 --- /dev/null +++ b/hexl_v0/lib/pkgconfig/hexl.pc @@ -0,0 +1,13 @@ +# Copyright (C) 2020 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +prefix=/home/eidf018/eidf018/s1820742psd/hexl/hexl_v0 +libdir=${prefix}/lib +includedir=${prefix}/include + +Name: Intel HEXL +Version: 1.2.5 +Description: Intel® HEXL is an open-source library which provides efficient implementations of integer arithmetic on Galois fields. + +Libs: -L${libdir} -lhexl +Cflags: -I${includedir} diff --git a/hexl_v1/include/hexl/eltwise/eltwise-add-mod.hpp b/hexl_v1/include/hexl/eltwise/eltwise-add-mod.hpp new file mode 100644 index 00000000..cb2df110 --- /dev/null +++ b/hexl_v1/include/hexl/eltwise/eltwise-add-mod.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Adds two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] operand2 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. 
Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus); + +/// @brief Adds a vector and scalar elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] operand2 Scalar to add. Must be less +/// than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/eltwise/eltwise-cmp-add.hpp b/hexl_v1/include/hexl/eltwise/eltwise-cmp-add.hpp new file mode 100644 index 00000000..27e514ff --- /dev/null +++ b/hexl_v1/include/hexl/eltwise/eltwise-cmp-add.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional addition. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare; stores result +/// @param[in] n Number of elements in \p operand1 +/// @param[in] cmp Comparison operation +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to conditionally add +/// @details Computes result[i] = cmp(operand1[i], bound) ? operand1[i] + +/// diff : operand1[i] for all \f$i=0, ..., n-1\f$. +void EltwiseCmpAdd(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp b/hexl_v1/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp new file mode 100644 index 00000000..07ba3d23 --- /dev/null +++ b/hexl_v1/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional modular subtraction. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] modulus Modulus to reduce by +/// @param[in] cmp Comparison function +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to subtract by +/// @details Computes \p operand1[i] = (\p cmp(\p operand1, \p bound)) ? 
(\p +/// operand1 - \p diff) mod \p modulus : \p operand1 mod \p modulus for all i=0, +/// ..., n-1 +void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, uint64_t bound, + uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/eltwise/eltwise-fma-mod.hpp b/hexl_v1/include/hexl/eltwise/eltwise-fma-mod.hpp new file mode 100644 index 00000000..03651a42 --- /dev/null +++ b/hexl_v1/include/hexl/eltwise/eltwise-fma-mod.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes fused multiply-add (\p arg1 * \p arg2 + \p arg3) mod \p +/// modulus element-wise, broadcasting scalars to vectors. +/// @param[out] result Stores the result +/// @param[in] arg1 Vector to multiply +/// @param[in] arg2 Scalar to multiply +/// @param[in] arg3 Vector to add. Will not add if \p arg3 == nullptr +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$ [2, 2^{61} - 1]\f$ +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * modulus). Must be 1, 2, 4, or 8. +void EltwiseFMAMod(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/eltwise/eltwise-mult-mod.hpp b/hexl_v1/include/hexl/eltwise/eltwise-mult-mod.hpp new file mode 100644 index 00000000..e4d2dbd7 --- /dev/null +++ b/hexl_v1/include/hexl/eltwise/eltwise-mult-mod.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. 
+/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +void EltwiseMultMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/eltwise/eltwise-reduce-mod.hpp b/hexl_v1/include/hexl/eltwise/eltwise-reduce-mod.hpp new file mode 100644 index 00000000..c23abde2 --- /dev/null +++ b/hexl_v1/include/hexl/eltwise/eltwise-reduce-mod.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Performs elementwise modular reduction +/// @param[out] result Stores the result +/// @param[in] operand Data on which to compute the elementwise modular +/// reduction +/// @param[in] n Number of elements in operand +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be modulus, 1, 2 or 4. input_mod_factor=modulus +/// means, input range is [0, p * p]. Barrett reduction will be used in this +/// case. input_mod_factor > output_mod_factor +/// @param[in] output_mod_factor output elements will be in [0, +/// output_mod_factor * modulus) Must be 1 or 2. For input_mod_factor=0, +/// output_mod_factor will be set to 1. +void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, + uint64_t modulus, uint64_t input_mod_factor, + uint64_t output_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/eltwise/eltwise-sub-mod.hpp b/hexl_v1/include/hexl/eltwise/eltwise-sub-mod.hpp new file mode 100644 index 00000000..bd286e47 --- /dev/null +++ b/hexl_v1/include/hexl/eltwise/eltwise-sub-mod.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Subtracts two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Vector of elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus); + +/// @brief Subtracts a scalar from a vector elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. 
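As an editorial aside (not part of the diff itself), the element-wise routines declared in these headers can be exercised with a short sketch like the one below. The include path `hexl/hexl.hpp`, the `intel::hexl` namespace, and linking against `libhexl` follow the headers added by this patch; the modulus and data values are illustrative only.

```cpp
// Minimal usage sketch for the element-wise modular API (illustrative only).
#include <cstdint>
#include <iostream>
#include <vector>

#include "hexl/hexl.hpp"

int main() {
  const uint64_t n = 8;          // number of elements in each vector
  const uint64_t modulus = 769;  // small prime, for illustration

  std::vector<uint64_t> op1{1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<uint64_t> op2{768, 767, 766, 765, 764, 763, 762, 761};
  std::vector<uint64_t> out(n, 0);

  // out[i] = (op1[i] + op2[i]) mod modulus; inputs must already be < modulus
  intel::hexl::EltwiseAddMod(out.data(), op1.data(), op2.data(), n, modulus);

  // out[i] = (op1[i] - op2[i]) mod modulus
  intel::hexl::EltwiseSubMod(out.data(), op1.data(), op2.data(), n, modulus);

  // out[i] = (op1[i] * op2[i]) mod modulus; inputs are < modulus, so
  // input_mod_factor = 1 is valid
  intel::hexl::EltwiseMultMod(out.data(), op1.data(), op2.data(), n, modulus, 1);

  for (uint64_t v : out) std::cout << v << ' ';
  std::cout << '\n';
}
```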
+void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp b/hexl_v1/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp new file mode 100644 index 00000000..28a2dddf --- /dev/null +++ b/hexl_v1/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp @@ -0,0 +1,402 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +// ************************************ T1 ************************************ + +// ComplexLoadFwdInterleavedT1: +// Assumes ComplexLoadFwdInterleavedT2 was used before. +// Given input: 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0 +// Returns +// *out1 = (14, 12, 10, 8, 6, 4, 2, 0); +// *out2 = (15, 13, 11, 9, 7, 5, 3, 1); +// +// Given output: 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0 +inline void ComplexLoadFwdInterleavedT1(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512i vperm_idx = _mm512_set_epi64(6, 7, 4, 5, 2, 3, 0, 1); + + const __m512d* arg_512 = reinterpret_cast(arg); + + // 13 12 9 8 5 4 1 0 + __m512d v_7to0 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15 14 11 10 7 6 3 2 + __m512d v_15to8 = _mm512_loadu_pd(arg_512); + + // 12, 13, 8, 9, 4, 5, 0, 1 + __m512d perm_1 = _mm512_permutexvar_pd(vperm_idx, v_7to0); + // 14, 15, 10, 11, 6, 7, 2, 3 + __m512d perm_2 = _mm512_permutexvar_pd(vperm_idx, v_15to8); + + // 14, 12, 10, 8, 6, 4, 2, 0 + *out1 = _mm512_mask_blend_pd(0xaa, v_7to0, perm_2); + // 15, 13, 11, 9, 7, 5, 3, 1 + *out2 = _mm512_mask_blend_pd(0x55, v_15to8, perm_1); +} + +// ComplexWriteFwdInterleavedT1: +// Assumes ComplexLoadFwdInterleavedT1 was used before. 
+// Given inputs: +// 15i, 13i, 11i, 9i, 7i, 5i, 3i, 1i, 15r, 13r, 11r, 9r, 7r, 5r, 3r, 1r, +// 14i, 12i, 10i, 8i, 6i, 4i, 2i, 0i, 14r, 12r, 10r, 8r, 6r, 4r, 2r, 0r +// As seen with internal indexes: +// @param arg_yr = (15r, 14r, 13r, 12r, 11r, 10r, 9r, 8r); +// @param arg_xr = ( 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r); +// @param arg_yi = (15i, 14i, 13i, 12i, 11i, 10i, 9i, 8i); +// @param arg_xi = ( 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i); +// Writes out = +// {15i, 15r, 7i, 7r, 14i, 14r, 6i, 6r, 13i, 13r, 5i, 5r, 12i, 12r, 4i, 4r, +// 11i, 11r, 3i, 3r, 10i, 10r, 2i, 2r, 9i, 9r, 1i, 1r, 8i, 8r, 0i, 0r} +// +// Given output: +// 15i, 15r, 14i, 14r, 13i, 13r, 12i, 12r, 11i, 11r, 10i, 10r, 9i, 9r, 8i, 8r, +// 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +inline void ComplexWriteFwdInterleavedT1(__m512d arg_xr, __m512d arg_yr, + __m512d arg_xi, __m512d arg_yi, + __m512d* out) { + const __m512i vperm_4hi_4lo_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i v_X_out_idx = _mm512_set_epi64(3, 1, 7, 5, 2, 0, 6, 4); + const __m512i v_Y_out_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + // Real part + // in: 14r, 12r, 10r, 8r, 6r, 4r, 2r, 0r + // -> 6r, 4r, 2r, 0r, 14r, 12r, 10r, 8r + arg_xr = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, arg_xr); + + // arg_yr: 15r, 13r, 11r, 9r, 7r, 5r, 3r, 1r + // -> 6r, 4r, 2r, 0r, 7r, 5r, 3r, 1r + __m512d perm_1 = _mm512_mask_blend_pd(0x0f, arg_xr, arg_yr); + // -> 15r, 13r, 11r, 9r, 14r, 12r, 10r, 8r + __m512d perm_2 = _mm512_mask_blend_pd(0xf0, arg_xr, arg_yr); + + // 7r, 3r, 6r, 2r, 5r, 1r, 4r, 0r + arg_xr = _mm512_permutexvar_pd(v_X_out_idx, perm_1); + // 15r, 11r, 14r, 10r, 13r, 9r, 12r, 8r + arg_yr = _mm512_permutexvar_pd(v_Y_out_idx, perm_2); + + // Imaginary part + // in: 14i, 12i, 10i, 8i, 6i, 4i, 2i, 0i + // -> 6i, 4i, 2i, 0i, 14i, 12i, 10i, 8i + arg_xi = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, arg_xi); + + // arg_yr: 15i, 13i, 11i, 9i, 7i, 5i, 3i, 1i + // -> 6i, 4i, 2i, 0i, 7i, 5i, 3i, 1i + perm_1 = _mm512_mask_blend_pd(0x0f, arg_xi, arg_yi); + // -> 15i, 13i, 11i, 9i, 14i, 12i, 10i, 8i + perm_2 = _mm512_mask_blend_pd(0xf0, arg_xi, arg_yi); + + // 7i, 3i, 6i, 2i, 5i, 1i, 4i, 0i + arg_xi = _mm512_permutexvar_pd(v_X_out_idx, perm_1); + // 15i, 11i, 14i, 10i, 13i, 9i, 12i, 8i + arg_yi = _mm512_permutexvar_pd(v_Y_out_idx, perm_2); + + // Merge + // 00000000 > 3i 3r 2i 2r 1i 1r 0i 0r + __m512d out1 = _mm512_shuffle_pd(arg_xr, arg_xi, 0x00); + // 11111111 > 7i 7r 6i 6r 5i 5r 4i 4r + __m512d out2 = _mm512_shuffle_pd(arg_xr, arg_xi, 0xff); + + // 00000000 > 11i 11r 10i 10r 9i 9r 8i 8r + __m512d out3 = _mm512_shuffle_pd(arg_yr, arg_yi, 0x00); + // 11111111 > 15i 15r 14i 14r 13i 13r 12i 12r + __m512d out4 = _mm512_shuffle_pd(arg_yr, arg_yi, 0xff); + + _mm512_storeu_pd(out++, out1); + _mm512_storeu_pd(out++, out2); + _mm512_storeu_pd(out++, out3); + _mm512_storeu_pd(out++, out4); +} + +// ComplexLoadInvInterleavedT1: +// Given input: 15i 15r 14i 14r 13i 13r 12i 12r 11i 11r 10i 10r 9i 9r 8i 8r +// 7i 7r 6i 6r 5i 5r 4i 4r 3i 3r 2i 2r 1i 1r 0i 0r +// Returns +// *out1_r = (14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r); +// *out1_i = (14i, 10i, 6i, 2i, 12i, 8i, 4i, 0i); +// *out2_r = (15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r); +// *out2_i = (15i, 11i, 7i, 3i, 13i, 9i, 5i, 1i); +// +// Given output: +// 15i, 11i, 7i, 3i, 13i, 9i, 5i, 1i, 15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r, +// 14i, 10i, 6i, 2i, 12i, 8i, 4i, 0i, 14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r +inline void ComplexLoadInvInterleavedT1(const double* arg, __m512d* out1_r, + __m512d* out1_i, __m512d* out2_r, + __m512d* 
out2_i) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 3i 3r 2i 2r 1i 1r 0i 0r + __m512d v_3to0 = _mm512_loadu_pd(arg_512++); + // 7i 7r 6i 6r 5i 5r 4i 4r + __m512d v_7to4 = _mm512_loadu_pd(arg_512++); + // 11i 11r 10i 10r 9i 9r 8i 8r + __m512d v_11to8 = _mm512_loadu_pd(arg_512++); + // 15i 15r 14i 14r 13i 13r 12i 12r + __m512d v_15to12 = _mm512_loadu_pd(arg_512++); + + // 00000000 > 7r 3r 6r 2r 5r 1r 4r 0r + __m512d v_7to0_r = _mm512_shuffle_pd(v_3to0, v_7to4, 0x00); + // 11111111 > 7i 3i 6i 2i 5i 1i 4i 0i + __m512d v_7to0_i = _mm512_shuffle_pd(v_3to0, v_7to4, 0xff); + // 00000000 > 15r 11r 14r 10r 13r 9r 12r 8r + __m512d v_15to8_r = _mm512_shuffle_pd(v_11to8, v_15to12, 0x00); + // 11111111 > 15i 11i 14i 10i 13i 9i 12i 8i + __m512d v_15to8_i = _mm512_shuffle_pd(v_11to8, v_15to12, 0xff); + + // real + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + // 6 2 7 3 4 0 5 1 + __m512d v1r = _mm512_permutexvar_pd(v1_perm_idx, v_7to0_r); + // 14 10 15 11 12 8 13 9 + __m512d v2r = _mm512_permutexvar_pd(v1_perm_idx, v_15to8_r); + // 11001100 > 14 10 6 2 12 8 4 0 + *out1_r = _mm512_mask_blend_pd(0xcc, v_7to0_r, v2r); + // 11001100 > 15 11 7 3 13 9 5 1 + *out2_r = _mm512_mask_blend_pd(0xcc, v1r, v_15to8_r); + + // imag + // 6 2 7 3 4 0 5 1 + __m512d v1i = _mm512_permutexvar_pd(v1_perm_idx, v_7to0_i); + // 14 10 15 11 12 8 13 9 + __m512d v2i = _mm512_permutexvar_pd(v1_perm_idx, v_15to8_i); + // 11001100 > 14 10 6 2 12 8 4 0 + *out1_i = _mm512_mask_blend_pd(0xcc, v_7to0_i, v2i); + // 11001100 > 15 11 7 3 13 9 5 1 + *out2_i = _mm512_mask_blend_pd(0xcc, v1i, v_15to8_i); +} + +// ************************************ T2 ************************************ + +// ComplexLoadFwdInterleavedT2: +// Assumes ComplexLoadFwdInterleavedT4 was used before. +// Given input: 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0 +// Returns +// *out1 = (13, 12, 9, 8, 5, 4, 1, 0) +// *out2 = (15, 14, 11, 10, 7, 6, 3, 2) +// +// Given output: 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0 +inline void ComplexLoadFwdInterleavedT2(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // Values were swapped in T4 + // 11, 10, 9, 8, 3, 2, 1, 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15, 14, 13, 12, 7, 6, 5, 4 + __m512d v2 = _mm512_loadu_pd(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + + __m512d v1_perm = _mm512_permutexvar_pd(v1_perm_idx, v1); + __m512d v2_perm = _mm512_permutexvar_pd(v1_perm_idx, v2); + + *out1 = _mm512_mask_blend_pd(0xcc, v1, v2_perm); + *out2 = _mm512_mask_blend_pd(0xcc, v1_perm, v2); +} + +// ComplexLoadInvInterleavedT2: +// Assumes ComplexLoadInvInterleavedT1 was used before. 
+// Given input: 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0 +// Returns +// *out1 = (13, 9, 5, 1, 12, 8, 4, 0) +// *out2 = (15, 11, 7, 3, 14, 10, 6, 2) +// +// Given output: 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0 +inline void ComplexLoadInvInterleavedT2(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 14 10 6 2 12 8 4 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15 11 7 3 13 9 5 1 + __m512d v2 = _mm512_loadu_pd(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + + // 12 8 4 0 14 10 6 2 + __m512d v1_perm = _mm512_permutexvar_pd(v1_perm_idx, v1); + // 13 9 5 1 15 11 7 3 + __m512d v2_perm = _mm512_permutexvar_pd(v1_perm_idx, v2); + + // 11110000 > 13 9 5 1 12 8 4 0 + *out1 = _mm512_mask_blend_pd(0xf0, v1, v2_perm); + // 11110000 > 15 11 7 3 14 10 6 2 + *out2 = _mm512_mask_blend_pd(0xf0, v1_perm, v2); +} + +// ************************************ T4 ************************************ + +// Complex LoadFwdInterleavedT4: +// Given input: 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +// Returns +// *out1 = (11, 10, 9, 8, 3, 2, 1, 0) +// *out2 = (15, 14, 13, 12, 7, 6, 5, 4) +// +// Given output: 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0 +inline void ComplexLoadFwdInterleavedT4(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + __m512d v_7to0 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + __m512d v_15to8 = _mm512_loadu_pd(arg_512); + __m512d perm_hi = _mm512_permutexvar_pd(vperm2_idx, v_15to8); + *out1 = _mm512_mask_blend_pd(0x0f, perm_hi, v_7to0); + *out2 = _mm512_mask_blend_pd(0xf0, perm_hi, v_7to0); + *out2 = _mm512_permutexvar_pd(vperm2_idx, *out2); +} + +// ComplexLoadInvInterleavedT4: +// Assumes ComplexLoadInvInterleavedT2 was used before. +// Given input: 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0 +// Returns +// *out1 = (11, 9, 3, 1, 10, 8, 2, 0) +// *out2 = (15, 13, 7, 5, 14, 12, 6, 4) +// +// Given output: 15, 13, 7, 5, 14, 12, 6, 4, 11, 9, 3, 1, 10, 8, 2, 0 + +inline void ComplexLoadInvInterleavedT4(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 13, 9, 5, 1, 12, 8, 4, 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15, 11, 7, 3, 14, 10, 6, 2 + __m512d v2 = _mm512_loadu_pd(arg_512); + + // 00000000 > 11 9 3 1 10 8 2 0 + *out1 = _mm512_shuffle_pd(v1, v2, 0x00); + // 11111111 > 15 13 7 5 14 12 6 4 + *out2 = _mm512_shuffle_pd(v1, v2, 0xff); +} + +// ComplexWriteInvInterleavedT4: +// Assuming ComplexLoadInvInterleavedT4 was used before. 
+// Given inputs: 15, 13, 7, 5, 14, 12, 6, 4, 11, 9, 3, 1, 10, 8, 2, 0 +// Seen Internally: +// @param arg1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// @param arg2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 11, 14, 10, 7, 3, 6, 2, +// 13, 9, 12, 8, 5, 1, 4, 0} +// +// Given output: 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +inline void ComplexWriteInvInterleavedT4(__m512d arg1, __m512d arg2, + __m512d* out) { + const __m512i vperm_4hi_4lo_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i vperm1 = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + const __m512i vperm2 = _mm512_set_epi64(5, 1, 4, 0, 7, 3, 6, 2); + + // in: 11 9 3 1 10 8 2 0 + // -> 11 10 9 8 3 2 1 0 + arg1 = _mm512_permutexvar_pd(vperm1, arg1); + // in: 15 13 7 5 14 12 6 4 + // -> 7 6 5 4 15 14 13 12 + arg2 = _mm512_permutexvar_pd(vperm2, arg2); + + // 7 6 5 4 3 2 1 0 + __m512d out1 = _mm512_mask_blend_pd(0xf0, arg1, arg2); + // 11 10 9 8 15 14 13 12 + __m512d out2 = _mm512_mask_blend_pd(0x0f, arg1, arg2); + // 15 14 13 12 11 10 9 8 + out2 = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, out2); + + _mm512_storeu_pd(out, out1); + out += 2; + _mm512_storeu_pd(out, out2); +} + +// ************************************ T8 ************************************ + +// ComplexLoadFwdInterleavedT8: +// Given inputs: 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +// Seen Internally: +// v_X1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// v_X2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 13, 11, 9, 7, 5, 3, 1, +// 14, 12, 10, 8, 6, 4, 2, 0} +// +// Given output: 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i, 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r +inline void ComplexLoadFwdInterleavedT8(const __m512d* arg_x, + const __m512d* arg_y, __m512d* out1_r, + __m512d* out1_i, __m512d* out2_r, + __m512d* out2_i) { + const __m512i v_perm_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + // 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r + __m512d v_X1 = _mm512_loadu_pd(arg_x++); + // 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r + __m512d v_X2 = _mm512_loadu_pd(arg_x); + // 7r, 3r, 6r, 2r, 5r, 1r, 4r, 0r + *out1_r = _mm512_shuffle_pd(v_X1, v_X2, 0x00); + // 7i, 3i, 6i, 2i, 5i, 1i, 4i, 0i + *out1_i = _mm512_shuffle_pd(v_X1, v_X2, 0xff); + // 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r + *out1_r = _mm512_permutexvar_pd(v_perm_idx, *out1_r); + // 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i + *out1_i = _mm512_permutexvar_pd(v_perm_idx, *out1_i); + + __m512d v_Y1 = _mm512_loadu_pd(arg_y++); + __m512d v_Y2 = _mm512_loadu_pd(arg_y); + *out2_r = _mm512_shuffle_pd(v_Y1, v_Y2, 0x00); + *out2_i = _mm512_shuffle_pd(v_Y1, v_Y2, 0xff); + *out2_r = _mm512_permutexvar_pd(v_perm_idx, *out2_r); + *out2_i = _mm512_permutexvar_pd(v_perm_idx, *out2_i); +} + +// ComplexWriteInvInterleavedT8: +// Assuming ComplexLoadInvInterleavedT4 was used before. 
+// Given inputs: 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i, 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r +// Seen Internally: +// @param arg1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// @param arg2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 7, 14, 6, 13, 5, 12, 4, +// 11, 3, 10, 2, 9, 1, 8, 0} +// +// Given output: 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +inline void ComplexWriteInvInterleavedT8(__m512d* v_X_real, __m512d* v_X_imag, + __m512d* v_Y_real, __m512d* v_Y_imag, + __m512d* v_X_pt, __m512d* v_Y_pt) { + const __m512i vperm = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + // in: 7r 6r 5r 4r 3r 2r 1r 0r + // -> 7r 3r 6r 2r 5r 1r 4r 0r + *v_X_real = _mm512_permutexvar_pd(vperm, *v_X_real); + // in: 7i 6i 5i 4i 3i 2i 1i 0i + // -> 7i 3i 6i 2i 5i 1i 4i 0i + *v_X_imag = _mm512_permutexvar_pd(vperm, *v_X_imag); + // in: 15r 14r 13r 12r 11r 10r 9r 8r + // -> 15r 11r 14r 10r 13r 9r 12r 8r + *v_Y_real = _mm512_permutexvar_pd(vperm, *v_Y_real); + // in: 15i 14i 13i 12i 11i 10i 9i 8i + // -> 15i 11i 14i 10i 13i 9i 12i 8i + *v_Y_imag = _mm512_permutexvar_pd(vperm, *v_Y_imag); + + // 00000000 > 3i 3r 2i 2r 1i 1r 0i 0r + __m512d v_X1 = _mm512_shuffle_pd(*v_X_real, *v_X_imag, 0x00); + // 11111111 > 7i 7r 6i 6r 5i 5r 4i 4r + __m512d v_X2 = _mm512_shuffle_pd(*v_X_real, *v_X_imag, 0xff); + // 00000000 > 11i 11r 10i 10r 9i 9r 8i 8r + __m512d v_Y1 = _mm512_shuffle_pd(*v_Y_real, *v_Y_imag, 0x00); + // 11111111 > 15i 15r 14i 14r 13i 13r 12i 12r + __m512d v_Y2 = _mm512_shuffle_pd(*v_Y_real, *v_Y_imag, 0xff); + + _mm512_storeu_pd(v_X_pt++, v_X1); + _mm512_storeu_pd(v_X_pt, v_X2); + _mm512_storeu_pd(v_Y_pt++, v_Y1); + _mm512_storeu_pd(v_Y_pt, v_Y2); +} +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/experimental/fft-like/fft-like-native.hpp b/hexl_v1/include/hexl/experimental/fft-like/fft-like-native.hpp new file mode 100644 index 00000000..7e02492d --- /dev/null +++ b/hexl_v1/include/hexl/experimental/fft-like/fft-like-native.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Radix-2 native C++ FFT like implementation of the forward FFT like +/// @param[out] result Output data. Overwritten with FFT like output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity. In +/// bit-reversed order +/// @param[in] scale Scale applied to output data +void Forward_FFTLike_ToBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* root_of_unity_powers, const uint64_t n, + const double* scale = nullptr); + +/// @brief Radix-2 native C++ FFT like implementation of the inverse FFT like +/// @param[out] result Output data. Overwritten with FFT like output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity. +/// In bit-reversed order. 
+/// @param[in] scale Scale applied to output data +void Inverse_FFTLike_FromBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* inv_root_of_unity_powers, const uint64_t n, + const double* scale = nullptr); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/experimental/fft-like/fft-like.hpp b/hexl_v1/include/hexl/experimental/fft-like/fft-like.hpp new file mode 100644 index 00000000..334de246 --- /dev/null +++ b/hexl_v1/include/hexl/experimental/fft-like/fft-like.hpp @@ -0,0 +1,147 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/experimental/fft-like/fwd-fft-like-avx512.hpp" +#include "hexl/experimental/fft-like/inv-fft-like-avx512.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/allocator.hpp" + +namespace intel { +namespace hexl { + +/// @brief Performs linear forward and inverse FFT like transform +/// for CKKS encoding and decoding. +class FFTLike { + public: + /// @brief Helper class for custom memory allocation + template + struct AllocatorAdapter + : public AllocatorInterface> { + explicit AllocatorAdapter(Adaptee&& _a, Args&&... args); + AllocatorAdapter(const Adaptee& _a, Args&... args); + + // interface implementation + void* allocate_impl(size_t bytes_count); + void deallocate_impl(void* p, size_t n); + + private: + Adaptee alloc; + }; + + /// @brief Initializes an empty CKKS_FTT object + FFTLike() = default; + + /// @brief Destructs the CKKS_FTT object + ~FFTLike() = default; + + /// @brief Initializes an FFTLike object with degree \p degree and scalar + /// \p in_scalar. + /// @param[in] degree also known as N. Size of the FFT like transform. Must be + /// a power of 2 + /// @param[in] in_scalar Scalar value to calculate scale and inv scale + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @details Performs pre-computation necessary for forward and inverse + /// transforms + FFTLike(uint64_t degree, double* in_scalar, + std::shared_ptr alloc_ptr = {}); + + template + FFTLike(uint64_t degree, double* in_scalar, Allocator&& a, + AllocatorArgs&&... args) + : FFTLike( + degree, in_scalar, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Compute forward FFT like. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the FFT like + /// @param[in] in_scale Scale applied to output values + void ComputeForwardFFTLike(std::complex* result, + const std::complex* operand, + const double* in_scale = nullptr); + + /// @brief Compute inverse FFT like. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the FFT like + /// @param[in] in_scale Scale applied to output values + void ComputeInverseFFTLike(std::complex* result, + const std::complex* operand, + const double* in_scale = nullptr); + + /// @brief Construct floating-point values from CRT-composed polynomial with + /// integer coefficients. 
+ /// @param[out] res Stores the result + /// @param[in] plain Plaintext + /// @param[in] threshold Upper half threshold with respect to the total + /// coefficient modulus + /// @param[in] decryption_modulus Product of all primes in the coefficient + /// modulus + /// @param[in] inv_scale Scale applied to output values + /// @param[in] mod_size Size of coefficient modulus parameter + /// @param[in] coeff_count Degree of the polynomial modulus parameter + void BuildFloatingPoints(std::complex* res, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, size_t mod_size, + size_t coeff_count); + + /// @brief Returns the root of unity power at bit-reversed index i. + /// @param[in] i Index + std::complex GetComplexRootOfUnity(size_t i) { + return GetComplexRootsOfUnity()[i]; + } + + /// @brief Returns the root of unity in bit-reversed order + const AlignedVector64>& GetComplexRootsOfUnity() const { + return m_complex_roots_of_unity; + } + + /// @brief Returns the root of unity power at bit-reversed index i. + /// @param[in] i Index + std::complex GetInvComplexRootOfUnity(size_t i) { + return GetInvComplexRootsOfUnity()[i]; + } + + /// @brief Returns the inverse root of unity in bit-reversed order + const AlignedVector64>& GetInvComplexRootsOfUnity() + const { + return m_inv_complex_roots_of_unity; + } + + /// @brief Returns the degree N + uint64_t GetDegree() const { return m_degree; } + + private: + // Computes 1~(n-1)-th powers and inv powers of the primitive 2n-th root + void ComputeComplexRootsOfUnity(); + + uint64_t m_degree; // N: size of FFT like transform, should be power of 2 + + double* scalar; // Pointer to scalar used for scale/inv_scale calculation + + double scale; // Scale value use for encoding (inv fft-like) + + double inv_scale; // Scale value use in decoding (fwd fft-like) + + std::shared_ptr m_alloc; + + AlignedAllocator m_aligned_alloc; + + uint64_t m_degree_bits; // log_2(m_degree) + + // Contains 0~(n-1)-th powers of the 2n-th primitive root. + AlignedVector64> m_complex_roots_of_unity; + + // Contains 0~(n-1)-th inv powers of the 2n-th primitive inv root. + AlignedVector64> m_inv_complex_roots_of_unity; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp b/hexl_v1/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp new file mode 100644 index 00000000..aba4ca4d --- /dev/null +++ b/hexl_v1/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp @@ -0,0 +1,59 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/fft-like/fft-like.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the forward FFT like +/// @param[out] result_cmplx_intrlvd Output data. Overwritten with FFT like +/// output. Result is a vector of double with interleaved real and imaginary +/// numbers. +/// @param[in] operand_cmplx_intrlvd Input data. A vector of double with +/// interleaved real and imaginary numbers. +/// @param[in] roots_of_unity_cmplx_intrlvd Powers of 2n'th root of unity. In +/// bit-reversed order. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. 
+/// @param[in] scale Scale applied to output values +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// FFT like, where all the butterflies in a given stage are processed before +/// any butterflies in the next stage. The base case is small enough to fit in +/// the smallest cache. Larger FFTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. This reduces the number of cache misses, improving +/// performance on larger transform sizes. +void Forward_FFTLike_ToBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* roots_of_unity_cmplx_intrlvd, const uint64_t n, + const double* scale = nullptr, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +/// @brief Construct floating-point values from CRT-composed polynomial with +/// integer coefficients in AVX512. +/// @param[out] res_cmplx_intrlvd Stores the result +/// @param[in] plain Plaintext +/// @param[in] threshold Upper half threshold with respect to the total +/// coefficient modulus +/// @param[in] decryption_modulus Product of all primes in the coefficient +/// modulus +/// @param[in] inv_scale Scale applied to output values +/// @param[in] mod_size Size of coefficient modulus parameter +/// @param[in] coeff_count Degree of the polynomial modulus parameter +void BuildFloatingPointsAVX512(double* res_cmplx_intrlvd, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, const size_t mod_size, + const size_t coeff_count); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp b/hexl_v1/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp new file mode 100644 index 00000000..487e2828 --- /dev/null +++ b/hexl_v1/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/fft-like/fft-like.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the inverse FFT like +/// @param[out] result_cmplx_intrlvd Output data. Overwritten with FFT like +/// output. Result is a vector of double with interleaved real and imaginary +/// numbers. +/// @param[in] operand_cmplx_intrlvd Input data. A vector of double with +/// interleaved real and imaginary numbers. +/// @param[in] inv_roots_of_unity_cmplx_intrlvd Powers of 2n'th root of unity. +/// In bit-reversed order. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] scale Scale applied to output values +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// FFT like, where all the butterflies in a given stage are processed before +/// any butterflies in the next stage. The base case is small enough to fit in +/// the smallest cache. Larger FFTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. 
This reduces the number of cache misses, improving +/// performance on larger transform sizes. +void Inverse_FFTLike_FromBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* inv_root_of_unity_cmplxintrlvd, const uint64_t n, + const double* scale = nullptr, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/experimental/misc/lr-mat-vec-mult.hpp b/hexl_v1/include/hexl/experimental/misc/lr-mat-vec-mult.hpp new file mode 100644 index 00000000..df03df92 --- /dev/null +++ b/hexl_v1/include/hexl/experimental/misc/lr-mat-vec-mult.hpp @@ -0,0 +1,31 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes transposed linear regression +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (3 * n * num_moduli) elements +/// @param[in] operand1 Vector of ciphertext representing a matrix that encodes +/// a transposed logistic regression model. Has (num_weights * 2 * n * +/// num_moduli) elements. +/// @param[in] operand2 Vector of ciphertext representing a matrix that encodes +/// at most n/2 input samples with feature size num_weights. Has (num_weights * +/// 2 * n * num_moduli) elements. +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +/// @param[in] num_weights Feature size of the linear/logistic regression model +void LinRegMatrixVectorMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli, + uint64_t num_weights); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/experimental/seal/dyadic-multiply-internal.hpp b/hexl_v1/include/hexl/experimental/seal/dyadic-multiply-internal.hpp new file mode 100644 index 00000000..310a46b0 --- /dev/null +++ b/hexl_v1/include/hexl/experimental/seal/dyadic-multiply-internal.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { +namespace internal { + +/// @brief Computes dyadic multiplication +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (2 * n * num_moduli) elements +/// @param[in] operand1 First ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] operand2 Second ciphertext argument. Has (2 * n * num_moduli) +/// elements. 
+/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli); + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/experimental/seal/dyadic-multiply.hpp b/hexl_v1/include/hexl/experimental/seal/dyadic-multiply.hpp new file mode 100644 index 00000000..f7eacfdf --- /dev/null +++ b/hexl_v1/include/hexl/experimental/seal/dyadic-multiply.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes dyadic multiplication +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (2 * n * num_moduli) elements +/// @param[in] operand1 First ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] operand2 Second ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/experimental/seal/key-switch-internal.hpp b/hexl_v1/include/hexl/experimental/seal/key-switch-internal.hpp new file mode 100644 index 00000000..8fc9d53e --- /dev/null +++ b/hexl_v1/include/hexl/experimental/seal/key-switch-internal.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { +namespace internal { + +/// @brief Computes key switching in-place +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (n * decomp_modulus_size * key_component_count) elements +/// @param[in] t_target_iter_ptr Pointer to the last component of the input +/// ciphertext +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] decomp_modulus_size Number of moduli in the ciphertext at its +/// current level, excluding one auxiliary prime. +/// @param[in] key_modulus_size Number of moduli in the ciphertext at its top +/// level, including one auxiliary prime. +/// @param[in] rns_modulus_size Number of moduli in the ciphertext at its +/// current level, including one auxiliary prime. rns_modulus_size == +/// decomp_modulus_size + 1 +/// @param[in] key_component_count Number of components in the resulting +/// ciphertext, e.g. key_component_count == 2. +/// @param[in] moduli Array of word-sized coefficient moduli. There must be +/// key_modulus_size moduli in the array +/// @param[in] k_switch_keys Array of evaluation key data. 
Has +/// decomp_modulus_size entries, each with +/// coeff_count * ((key_modulus_size - 1)+ (key_component_count - 1) * +/// (key_modulus_size) + 1) entries +/// @param[in] modswitch_factors Array of modulus switch factors +/// @param[in] root_of_unity_powers_ptr Array of root of unity powers +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr = nullptr); + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/experimental/seal/key-switch.hpp b/hexl_v1/include/hexl/experimental/seal/key-switch.hpp new file mode 100644 index 00000000..9eda159c --- /dev/null +++ b/hexl_v1/include/hexl/experimental/seal/key-switch.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes key switching in-place +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (n * decomp_modulus_size * key_component_count) elements +/// @param[in] t_target_iter_ptr Pointer to the last component of the input +/// ciphertext +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] decomp_modulus_size Number of moduli in the ciphertext at its +/// current level, excluding one auxiliary prime. +/// @param[in] key_modulus_size Number of moduli in the ciphertext at its top +/// level, including one auxiliary prime. +/// @param[in] rns_modulus_size Number of moduli in the ciphertext at its +/// current level, including one auxiliary prime. rns_modulus_size == +/// decomp_modulus_size + 1 +/// @param[in] key_component_count Number of components in the resulting +/// ciphertext, e.g. key_component_count == 2. +/// @param[in] moduli Array of word-sized coefficient moduli. There must be +/// key_modulus_size moduli in the array +/// @param[in] k_switch_keys Array of evaluation key data. 
Has +/// decomp_modulus_size entries, each with +/// coeff_count * ((key_modulus_size - 1)+ (key_component_count - 1) * +/// (key_modulus_size) + 1) entries +/// @param[in] modswitch_factors Array of modulus switch factors +/// @param[in] root_of_unity_powers_ptr Array of root of unity powers +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr = nullptr); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/experimental/seal/locks.hpp b/hexl_v1/include/hexl/experimental/seal/locks.hpp new file mode 100644 index 00000000..4595f4e5 --- /dev/null +++ b/hexl_v1/include/hexl/experimental/seal/locks.hpp @@ -0,0 +1,35 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +namespace intel { +namespace hexl { + +using Lock = std::shared_mutex; +using WriteLock = std::unique_lock; +using ReadLock = std::shared_lock; + +class RWLock { + public: + RWLock() = default; + inline ReadLock AcquireRead() { return ReadLock(rw_mutex); } + inline WriteLock AcquireWrite() { return WriteLock(rw_mutex); } + inline ReadLock TryAcquireRead() noexcept { + return ReadLock(rw_mutex, std::try_to_lock); + } + inline WriteLock TryAcquireWrite() noexcept { + return WriteLock(rw_mutex, std::try_to_lock); + } + + private: + RWLock(const RWLock& copy) = delete; + RWLock& operator=(const RWLock& assign) = delete; + Lock rw_mutex{}; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/experimental/seal/ntt-cache.hpp b/hexl_v1/include/hexl/experimental/seal/ntt-cache.hpp new file mode 100644 index 00000000..8f6c1046 --- /dev/null +++ b/hexl_v1/include/hexl/experimental/seal/ntt-cache.hpp @@ -0,0 +1,56 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/seal/locks.hpp" +#include "ntt/ntt-internal.hpp" + +namespace intel { +namespace hexl { + +struct HashPair { + template + std::size_t operator()(const std::pair& p) const { + auto hash1 = std::hash{}(p.first); + auto hash2 = std::hash{}(p.second); + return hash_combine(hash1, hash2); + } + + // Golden Ratio Hashing with seeds + static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { + lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + return lhs; + } +}; + +NTT& GetNTT(size_t N, uint64_t modulus) { + static std::unordered_map, NTT, HashPair> + ntt_cache; + static RWLock ntt_cache_locker; + + std::pair key{N, modulus}; + + // Enable shared access to NTT already present + { + ReadLock reader_lock(ntt_cache_locker.AcquireRead()); + auto ntt_it = ntt_cache.find(key); + if (ntt_it != ntt_cache.end()) { + return ntt_it->second; + } + } + + // Deal with NTT not yet present + WriteLock write_lock(ntt_cache_locker.AcquireWrite()); + + // Check ntt_cache for value (may be added by another thread) + auto ntt_it = ntt_cache.find(key); + if (ntt_it == ntt_cache.end()) { + NTT ntt(N, modulus); + ntt_it = ntt_cache.emplace(std::move(key), std::move(ntt)).first; + } + return ntt_it->second; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/hexl.hpp b/hexl_v1/include/hexl/hexl.hpp new file mode 100644 index 00000000..6f07ae57 --- /dev/null +++ 
b/hexl_v1/include/hexl/hexl.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/eltwise/eltwise-cmp-add.hpp" +#include "hexl/eltwise/eltwise-cmp-sub-mod.hpp" +#include "hexl/eltwise/eltwise-fma-mod.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/eltwise/eltwise-sub-mod.hpp" +#include "hexl/experimental/fft-like/fft-like.hpp" +#include "hexl/experimental/misc/lr-mat-vec-mult.hpp" +#include "hexl/experimental/seal/dyadic-multiply-internal.hpp" +#include "hexl/experimental/seal/dyadic-multiply.hpp" +#include "hexl/experimental/seal/key-switch-internal.hpp" +#include "hexl/experimental/seal/key-switch.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" +#include "hexl/util/defines.hpp" +#include "hexl/util/types.hpp" +#include "hexl/util/util.hpp" diff --git a/hexl_v1/include/hexl/logging/logging.hpp b/hexl_v1/include/hexl/logging/logging.hpp new file mode 100644 index 00000000..af5bfcd8 --- /dev/null +++ b/hexl_v1/include/hexl/logging/logging.hpp @@ -0,0 +1,43 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "hexl/util/defines.hpp" + +// Wrap HEXL_VLOG with HEXL_DEBUG; this ensures no logging overhead in +// release mode +#ifdef HEXL_DEBUG + +// TODO(fboemer) Enable if needed +// #define ELPP_THREAD_SAFE +#define ELPP_CUSTOM_COUT std::cerr +#define ELPP_STL_LOGGING +#define ELPP_LOG_STD_ARRAY +#define ELPP_LOG_UNORDERED_MAP +#define ELPP_LOG_UNORDERED_SET +#define ELPP_NO_LOG_TO_FILE +#define ELPP_DISABLE_DEFAULT_CRASH_HANDLING +#define ELPP_WINSOCK2 + +#include + +#define HEXL_VLOG(N, rest) \ + do { \ + if (VLOG_IS_ON(N)) { \ + VLOG(N) << rest; \ + } \ + } while (0); + +#else + +#define HEXL_VLOG(N, rest) \ + {} + +#define START_EASYLOGGINGPP(X, Y) \ + {} + +#endif diff --git a/hexl_v1/include/hexl/ntt/ntt.hpp b/hexl_v1/include/hexl/ntt/ntt.hpp new file mode 100644 index 00000000..93ccba72 --- /dev/null +++ b/hexl_v1/include/hexl/ntt/ntt.hpp @@ -0,0 +1,296 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/allocator.hpp" + +namespace intel { +namespace hexl { + +/// @brief Performs negacyclic forward and inverse number-theoretic transform +/// (NTT), commonly used in RLWE cryptography. +/// @details The number-theoretic transform (NTT) specializes the discrete +/// Fourier transform (DFT) to the finite field \f$ \mathbb{Z}_q[X] / (X^N + 1) +/// \f$. +class NTT { + public: + /// @brief Helper class for custom memory allocation + template + struct AllocatorAdapter + : public AllocatorInterface> { + explicit AllocatorAdapter(Adaptee&& _a, Args&&... args); + AllocatorAdapter(const Adaptee& _a, Args&... args); + + // interface implementation + void* allocate_impl(size_t bytes_count); + void deallocate_impl(void* p, size_t n); + + private: + Adaptee alloc; + }; + + /// @brief Initializes an empty NTT object + NTT() = default; + + /// @brief Destructs the NTT object + ~NTT() = default; + + /// @brief Initializes an NTT object with degree \p degree and modulus \p q. + /// @param[in] degree also known as N. Size of the NTT transform. 
Must be a + /// power of + /// 2 + /// @param[in] q Prime modulus. Must satisfy \f$ q == 1 \mod 2N \f$ + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @brief Performs pre-computation necessary for forward and inverse + /// transforms + NTT(uint64_t degree, uint64_t q, + std::shared_ptr alloc_ptr = {}); + + template + NTT(uint64_t degree, uint64_t q, Allocator&& a, AllocatorArgs&&... args) + : NTT(degree, q, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Initializes an NTT object with degree \p degree and modulus + /// \p q. + /// @param[in] degree also known as N. Size of the NTT transform. Must be a + /// power of + /// 2 + /// @param[in] q Prime modulus. Must satisfy \f$ q == 1 \mod 2N \f$ + /// @param[in] root_of_unity 2N'th root of unity in \f$ \mathbb{Z_q} \f$. + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @details Performs pre-computation necessary for forward and inverse + /// transforms + NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, + std::shared_ptr alloc_ptr = {}); + + template + NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, Allocator&& a, + AllocatorArgs&&... args) + : NTT(degree, q, root_of_unity, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Returns true if arguments satisfy constraints for negacyclic NTT + /// @param[in] degree N. Size of the transform, i.e. the polynomial degree. + /// Must be a power of two. + /// @param[in] modulus Prime modulus q. Must satisfy q mod 2N = 1 + static bool CheckArguments(uint64_t degree, uint64_t modulus); + + /// @brief Compute forward NTT. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the NTT + /// @param[in] input_mod_factor Assume input \p operand are in [0, + /// input_mod_factor * q). Must be 1, 2 or 4. + /// @param[in] output_mod_factor Returns output \p result in [0, + /// output_mod_factor * q). Must be 1 or 4. + void ComputeForward(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, uint64_t output_mod_factor); + + /// Compute inverse NTT. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the NTT + /// @param[in] input_mod_factor Assume input \p operand are in [0, + /// input_mod_factor * q). Must be 1 or 2. + /// @param[in] output_mod_factor Returns output \p result in [0, + /// output_mod_factor * q). Must be 1 or 2. + void ComputeInverse(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, uint64_t output_mod_factor); + + /// @brief Returns the minimal 2N'th root of unity + uint64_t GetMinimalRootOfUnity() const { return m_w; } + + /// @brief Returns the degree N + uint64_t GetDegree() const { return m_degree; } + + /// @brief Returns the word-sized prime modulus + uint64_t GetModulus() const { return m_q; } + + /// @brief Returns the root of unity powers in bit-reversed order + const AlignedVector64& GetRootOfUnityPowers() const { + return m_root_of_unity_powers; + } + + /// @brief Returns the root of unity power at bit-reversed index i. 
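As an editorial aside (not part of the diff itself), a minimal round-trip through the forward and inverse transforms declared by this class might look as follows. The parameter choice (`n = 1024`, `q = 12289`, which satisfies `q % (2 * n) == 1`) and the build assumptions (headers installed under `hexl/`, linking against `libhexl`) are illustrative.

```cpp
// Minimal forward/inverse NTT round-trip sketch (illustrative only).
#include <cstdint>
#include <iostream>
#include <vector>

#include "hexl/hexl.hpp"

int main() {
  const uint64_t n = 1024;         // transform size; must be a power of two
  const uint64_t modulus = 12289;  // prime with modulus % (2 * n) == 1

  intel::hexl::NTT ntt(n, modulus);

  std::vector<uint64_t> input(n);
  for (uint64_t i = 0; i < n; ++i) {
    input[i] = i % modulus;  // keep inputs in [0, modulus)
  }

  std::vector<uint64_t> fwd(n, 0);
  std::vector<uint64_t> back(n, 0);

  // Forward negacyclic NTT; inputs in [0, modulus), outputs in [0, modulus)
  ntt.ComputeForward(fwd.data(), input.data(), /*input_mod_factor=*/1,
                     /*output_mod_factor=*/1);

  // Inverse NTT recovers the original coefficients
  ntt.ComputeInverse(back.data(), fwd.data(), /*input_mod_factor=*/1,
                     /*output_mod_factor=*/1);

  std::cout << "back[5] = " << back[5] << " (expected " << input[5] << ")\n";
}
```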
+ uint64_t GetRootOfUnityPower(size_t i) { return GetRootOfUnityPowers()[i]; } + + /// @brief Returns 32-bit pre-conditioned root of unity powers in + /// bit-reversed order + const AlignedVector64& GetPrecon32RootOfUnityPowers() const { + return m_precon32_root_of_unity_powers; + } + + /// @brief Returns 64-bit pre-conditioned root of unity powers in + /// bit-reversed order + const AlignedVector64& GetPrecon64RootOfUnityPowers() const { + return m_precon64_root_of_unity_powers; + } + + /// @brief Returns the root of unity powers in bit-reversed order with + /// modifications for use by AVX512 implementation + const AlignedVector64& GetAVX512RootOfUnityPowers() const { + return m_avx512_root_of_unity_powers; + } + + /// @brief Returns 32-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon32RootOfUnityPowers() const { + return m_avx512_precon32_root_of_unity_powers; + } + + /// @brief Returns 52-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon52RootOfUnityPowers() const { + return m_avx512_precon52_root_of_unity_powers; + } + + /// @brief Returns 64-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon64RootOfUnityPowers() const { + return m_avx512_precon64_root_of_unity_powers; + } + + /// @brief Returns the inverse root of unity powers in bit-reversed order + const AlignedVector64& GetInvRootOfUnityPowers() const { + return m_inv_root_of_unity_powers; + } + + /// @brief Returns the inverse root of unity power at bit-reversed index i. + uint64_t GetInvRootOfUnityPower(size_t i) { + return GetInvRootOfUnityPowers()[i]; + } + + /// @brief Returns the vector of 32-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. + const AlignedVector64& GetPrecon32InvRootOfUnityPowers() const { + return m_precon32_inv_root_of_unity_powers; + } + + /// @brief Returns the vector of 52-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. + const AlignedVector64& GetPrecon52InvRootOfUnityPowers() const { + return m_precon52_inv_root_of_unity_powers; + } + + /// @brief Returns the vector of 64-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. 
+ const AlignedVector64& GetPrecon64InvRootOfUnityPowers() const { + return m_precon64_inv_root_of_unity_powers; + } + + /// @brief Maximum power of 2 in degree + static size_t MaxDegreeBits() { return 20; } + + /// @brief Maximum number of bits in modulus; + static size_t MaxModulusBits() { return 62; } + + /// @brief Default bit shift used in Barrett precomputation + static const size_t s_default_shift_bits{64}; + + /// @brief Bit shift used in Barrett precomputation when AVX512-IFMA + /// acceleration is enabled + static const size_t s_ifma_shift_bits{52}; + + /// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the + /// forward transform + static const size_t s_max_fwd_32_modulus{1ULL << (32 - 2)}; + + /// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the + /// inverse transform + static const size_t s_max_inv_32_modulus{1ULL << (32 - 2)}; + + /// @brief Maximum modulus to use AVX512-IFMA acceleration for the forward + /// transform + static const size_t s_max_fwd_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)}; + + /// @brief Maximum modulus to use AVX512-IFMA acceleration for the inverse + /// transform + static const size_t s_max_inv_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)}; + + /// @brief Maximum modulus to use AVX512-DQ acceleration for the inverse + /// transform + static const size_t s_max_inv_dq_modulus{1ULL << (s_default_shift_bits - 2)}; + + static size_t s_max_fwd_modulus(int bit_shift) { + if (bit_shift == 32) { + return s_max_fwd_32_modulus; + } else if (bit_shift == 52) { + return s_max_fwd_ifma_modulus; + } else if (bit_shift == 64) { + return 1ULL << MaxModulusBits(); + } + HEXL_CHECK(false, "Invalid bit_shift " << bit_shift); + return 0; + } + + static size_t s_max_inv_modulus(int bit_shift) { + if (bit_shift == 32) { + return s_max_inv_32_modulus; + } else if (bit_shift == 52) { + return s_max_inv_ifma_modulus; + } else if (bit_shift == 64) { + return 1ULL << MaxModulusBits(); + } + HEXL_CHECK(false, "Invalid bit_shift " << bit_shift); + return 0; + } + + private: + void ComputeRootOfUnityPowers(); + + uint64_t m_degree; // N: size of NTT transform, should be power of 2 + uint64_t m_q; // prime modulus. 
Must satisfy q == 1 mod 2n + + uint64_t m_degree_bits; // log_2(m_degree) + + uint64_t m_w_inv; // Inverse of minimal root of unity + uint64_t m_w; // A 2N'th root of unity + + std::shared_ptr m_alloc; + + AlignedAllocator m_aligned_alloc; + + // powers of the minimal root of unity + AlignedVector64 m_root_of_unity_powers; + // vector of floor(W * 2**32 / m_q), with W the root of unity powers + AlignedVector64 m_precon32_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the root of unity powers + AlignedVector64 m_precon64_root_of_unity_powers; + + // powers of the minimal root of unity adjusted for use in AVX512 + // implementations + AlignedVector64 m_avx512_root_of_unity_powers; + // vector of floor(W * 2**32 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon32_root_of_unity_powers; + // vector of floor(W * 2**52 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon52_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon64_root_of_unity_powers; + + // vector of floor(W * 2**32 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon32_inv_root_of_unity_powers; + // vector of floor(W * 2**52 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon52_inv_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon64_inv_root_of_unity_powers; + + AlignedVector64 m_inv_root_of_unity_powers; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/number-theory/number-theory.hpp b/hexl_v1/include/hexl/number-theory/number-theory.hpp new file mode 100644 index 00000000..da8d1d2a --- /dev/null +++ b/hexl_v1/include/hexl/number-theory/number-theory.hpp @@ -0,0 +1,342 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" + +namespace intel { +namespace hexl { + +/// @brief Pre-computes a Barrett factor with which modular multiplication can +/// be performed more efficiently +class MultiplyFactor { + public: + MultiplyFactor() = default; + + /// @brief Computes and stores the Barrett factor floor((operand << bit_shift) + /// / modulus). This is useful when modular multiplication of the form + /// (x * operand) mod modulus is performed with same modulus and operand + /// several times. Note, passing operand=1 can be used to pre-compute a + /// Barrett factor for multiplications of the form (x * y) mod modulus, where + /// only the modulus is re-used across calls to modular multiplication. + MultiplyFactor(uint64_t operand, uint64_t bit_shift, uint64_t modulus) + : m_operand(operand) { + HEXL_CHECK(operand <= modulus, "operand " << operand + << " must be less than modulus " + << modulus); + HEXL_CHECK(bit_shift == 32 || bit_shift == 52 || bit_shift == 64, + "Unsupported BitShift " << bit_shift); + uint64_t op_hi = operand >> (64 - bit_shift); + uint64_t op_lo = (bit_shift == 64) ? 
+    uint64_t op_lo = (bit_shift == 64) ? 0 : (operand << bit_shift);
+
+    m_barrett_factor = DivideUInt128UInt64Lo(op_hi, op_lo, modulus);
+  }
+
+  /// @brief Returns the pre-computed Barrett factor
+  inline uint64_t BarrettFactor() const { return m_barrett_factor; }
+
+  /// @brief Returns the operand corresponding to the Barrett factor
+  inline uint64_t Operand() const { return m_operand; }
+
+ private:
+  uint64_t m_operand;
+  uint64_t m_barrett_factor;
+};
+
+/// @brief Returns whether or not num is a power of two
+inline bool IsPowerOfTwo(uint64_t num) { return num && !(num & (num - 1)); }
+
+/// @brief Returns floor(log2(x))
+inline uint64_t Log2(uint64_t x) { return MSB(x); }
+
+inline bool IsPowerOfFour(uint64_t num) {
+  return IsPowerOfTwo(num) && (Log2(num) % 2 == 0);
+}
+
+/// @brief Returns the maximum value that can be represented using \p bits bits
+inline uint64_t MaximumValue(uint64_t bits) {
+  HEXL_CHECK(bits <= 64, "MaximumValue requires bits <= 64; got " << bits);
+  if (bits == 64) {
+    return (std::numeric_limits<uint64_t>::max)();
+  }
+  return (1ULL << bits) - 1;
+}
+
+/// @brief Reverses the bits
+/// @param[in] x Input to reverse
+/// @param[in] bit_width Number of bits in the input; must be >= MSB(x)
+/// @return The bit-reversed representation of \p x using \p bit_width bits
+uint64_t ReverseBits(uint64_t x, uint64_t bit_width);
+
+/// @brief Returns x^{-1} mod modulus
+/// @details Requires x % modulus != 0
+uint64_t InverseMod(uint64_t x, uint64_t modulus);
+
+/// @brief Returns (x * y) mod modulus
+/// @details Assumes x, y < modulus
+uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t modulus);
+
+/// @brief Returns (x * y) mod modulus
+/// @param[in] x
+/// @param[in] y
+/// @param[in] y_precon 64-bit precondition factor floor(2**64 / modulus)
+/// @param[in] modulus
+uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon,
+                     uint64_t modulus);
+
+/// @brief Returns (x + y) mod modulus
+/// @details Assumes x, y < modulus
+uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus);
+
+/// @brief Returns (x - y) mod modulus
+/// @details Assumes x, y < modulus
+uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus);
+
+/// @brief Returns base^exp mod modulus
+uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus);
+
+/// @brief Returns whether or not root is a degree-th root of unity mod modulus
+/// @param[in] root Root of unity to check
+/// @param[in] degree Degree of root of unity; must be a power of two
+/// @param[in] modulus Modulus of finite field
+bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus);
+
+/// @brief Tries to return a primitive degree-th root of unity
+/// @details Returns 0 or throws an error if no root is found
+uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus);
+
+/// @brief Returns the minimal primitive degree-th root of unity
+/// @param[in] degree Must be a power of two
+/// @param[in] modulus Modulus of finite field
+uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus);
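// ---------------------------------------------------------------------------
// Editorial example, not part of the upstream HEXL sources: the declarations
// above compose directly; for instance, modular division by a nonzero x can
// be written as multiplication by InverseMod. The helper name is illustrative
// only, and it assumes modulus is prime and x % modulus != 0.
inline uint64_t ExampleDivideMod(uint64_t y, uint64_t x, uint64_t modulus) {
  // y * x^{-1} mod modulus
  return MultiplyMod(y % modulus, InverseMod(x, modulus), modulus);
}
// ---------------------------------------------------------------------------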
+
+/// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 *
+/// modulus]
+/// @param[in] x
+/// @param[in] y_operand also denoted y
+/// @param[in] modulus
+/// @param[in] y_barrett_factor Pre-computed Barrett reduction factor floor((y
+/// << BitShift) / modulus)
+template <int BitShift>
+inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y_operand,
+                                uint64_t y_barrett_factor, uint64_t modulus) {
+  HEXL_CHECK(y_operand < modulus, "y_operand " << y_operand
+                                               << " must be less than modulus "
+                                               << modulus);
+  HEXL_CHECK(
+      modulus <= MaximumValue(BitShift),
+      "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift));
+  HEXL_CHECK(x <= MaximumValue(BitShift),
+             "Operand " << x << " exceeds bound " << MaximumValue(BitShift));
+
+  uint64_t Q = MultiplyUInt64Hi<BitShift>(x, y_barrett_factor);
+  return y_operand * x - Q * modulus;
+}
+
+/// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 *
+/// modulus]
+/// @param[in] x
+/// @param[in] y
+/// @param[in] modulus
+template <int BitShift>
+inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y, uint64_t modulus) {
+  HEXL_CHECK(BitShift == 64 || BitShift == 52,
+             "Unsupported BitShift " << BitShift);
+  HEXL_CHECK(x <= MaximumValue(BitShift),
+             "Operand " << x << " exceeds bound " << MaximumValue(BitShift));
+  HEXL_CHECK(y < modulus,
+             "y " << y << " must be less than modulus " << modulus);
+  HEXL_CHECK(
+      modulus <= MaximumValue(BitShift),
+      "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift));
+
+  uint64_t y_barrett = MultiplyFactor(y, BitShift, modulus).BarrettFactor();
+  return MultiplyModLazy<BitShift>(x, y, y_barrett, modulus);
+}
+
+/// @brief Adds two unsigned 64-bit integers
+/// @param operand1 Number to add
+/// @param operand2 Number to add
+/// @param result Stores the sum
+/// @return The carry bit
+inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2,
+                               uint64_t* result) {
+  *result = operand1 + operand2;
+  return static_cast<unsigned char>(*result < operand1);
+}
+
+/// @brief Returns whether or not the input is prime
+bool IsPrime(uint64_t n);
+
+/// @brief Generates a list of num_primes primes in the range [2^(bit_size),
+/// 2^(bit_size+1)]. Ensures each prime q satisfies q % (2 * ntt_size) == 1
+/// @param[in] num_primes Number of primes to generate
+/// @param[in] bit_size Bit size of each prime
+/// @param[in] prefer_small_primes When true, returns primes starting from
+/// 2^(bit_size); when false, returns primes starting from 2^(bit_size+1)
+/// @param[in] ntt_size N such that each prime q satisfies q % (2N) == 1. N must
+/// be a power of two less than 2^bit_size.
+std::vector<uint64_t> GeneratePrimes(size_t num_primes, size_t bit_size,
+                                     bool prefer_small_primes,
+                                     size_t ntt_size = 1);
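// ---------------------------------------------------------------------------
// Editorial example, not part of the upstream HEXL sources: a minimal sketch
// of how MultiplyFactor and MultiplyModLazy are meant to be combined. The
// Barrett factor for a fixed multiplicand y is pre-computed once and reused
// across the whole loop; each lazy product lands in [0, 2 * modulus) and is
// finished with one conditional subtraction. The helper name is illustrative
// only; it assumes y < modulus and a modulus of at most 62 bits.
inline void ExampleScaleVectorMod(std::vector<uint64_t>& values, uint64_t y,
                                  uint64_t modulus) {
  // floor((y << 64) / modulus), computed once for the whole vector
  const uint64_t y_barrett = MultiplyFactor(y, 64, modulus).BarrettFactor();
  for (uint64_t& value : values) {
    uint64_t prod = MultiplyModLazy<64>(value, y, y_barrett, modulus);
    value = (prod >= modulus) ? prod - modulus : prod;
  }
}
// ---------------------------------------------------------------------------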
+
+/// @brief Returns input mod modulus, computed via 64-bit Barrett reduction
+/// @param[in] input
+/// @param[in] modulus
+/// @param[in] q_barr floor(2^64 / modulus)
+template <int OutputModFactor = 1>
+uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr) {
+  HEXL_CHECK(modulus != 0, "modulus == 0");
+  uint64_t q = MultiplyUInt64Hi<64>(input, q_barr);
+  uint64_t q_times_input = input - q * modulus;
+  if (OutputModFactor == 2) {
+    return q_times_input;
+  } else {
+    return (q_times_input >= modulus) ? q_times_input - modulus
+                                      : q_times_input;
+  }
+}
+
+/// @brief Returns x mod modulus, assuming x < InputModFactor * modulus
+/// @param[in] x
+/// @param[in] modulus also denoted q
+/// @param[in] twice_modulus 2 * q; must not be nullptr if InputModFactor == 4
+/// or 8
+/// @param[in] four_times_modulus 4 * q; must not be nullptr if InputModFactor
+/// == 8
+template <int InputModFactor>
+uint64_t ReduceMod(uint64_t x, uint64_t modulus,
+                   const uint64_t* twice_modulus = nullptr,
+                   const uint64_t* four_times_modulus = nullptr) {
+  HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 ||
+                 InputModFactor == 4 || InputModFactor == 8,
+             "InputModFactor should be 1, 2, 4, or 8");
+  if (InputModFactor == 1) {
+    return x;
+  }
+  if (InputModFactor == 2) {
+    if (x >= modulus) {
+      x -= modulus;
+    }
+    return x;
+  }
+  if (InputModFactor == 4) {
+    HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr");
+    if (x >= *twice_modulus) {
+      x -= *twice_modulus;
+    }
+    if (x >= modulus) {
+      x -= modulus;
+    }
+    return x;
+  }
+  if (InputModFactor == 8) {
+    HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr");
+    HEXL_CHECK(four_times_modulus != nullptr,
+               "four_times_modulus should not be nullptr");
+
+    if (x >= *four_times_modulus) {
+      x -= *four_times_modulus;
+    }
+    if (x >= *twice_modulus) {
+      x -= *twice_modulus;
+    }
+    if (x >= modulus) {
+      x -= modulus;
+    }
+    return x;
+  }
+  HEXL_CHECK(false, "Should be unreachable");
+  return x;
+}
+
+/// @brief Returns the Montgomery form of ab mod q, computed via the REDC
+/// algorithm, also known as Montgomery reduction
+/// @param[in] T_hi High 64 bits of T = ab, with T in the range [0, Rq − 1]
+/// @param[in] T_lo Low 64 bits of T
+/// @param[in] q Modulus, with R = 2^r such that gcd(R, q) = 1 and R > q
+/// @param[in] r Bit width of R
+/// @param[in] mod_R_msk Mask of the r low bits, used to reduce mod R
+/// @param[in] inv_mod Value in [0, R − 1] such that q * inv_mod ≡ −1 mod R
+/// @return Value S in the range [0, q − 1] such that S ≡ TR^−1 mod q
+template <int BitShift>
+inline uint64_t MontgomeryReduce(uint64_t T_hi, uint64_t T_lo, uint64_t q,
+                                 int r, uint64_t mod_R_msk, uint64_t inv_mod) {
+  HEXL_CHECK(BitShift == 64 || BitShift == 52,
+             "Unsupported BitShift " << BitShift);
+  HEXL_CHECK((1ULL << r) > static_cast<uint64_t>(q),
+             "R value should be greater than q = " << static_cast<uint64_t>(q));
+
+  uint64_t mq_hi;
+  uint64_t mq_lo;
+
+  uint64_t m = ((T_lo & mod_R_msk) * inv_mod) & mod_R_msk;
+  MultiplyUInt64(m, q, &mq_hi, &mq_lo);
+
+  if (BitShift == 52) {
+    mq_hi = (mq_hi << 12) | (mq_lo >> 52);
+    mq_lo &= (1ULL << 52) - 1;
+  }
+
+  uint64_t t_hi;
+  uint64_t t_lo;
+
+  // first 64bit block
+  t_lo = T_lo + mq_lo;
+  unsigned int carry = static_cast<unsigned int>(t_lo < T_lo);
+  t_hi = T_hi + mq_hi + carry;
+
+  t_hi = t_hi << (BitShift - r);
+  t_lo = t_lo >> r;
+  t_lo = t_hi + t_lo;
+
+  return (t_lo >= q) ? (t_lo - q) : t_lo;
+}
+
+/// @brief Hensel's Lemma for 2-adic numbers
+/// Finds a solution of qX + 1 = 0 mod 2^r
+/// @param[in] r
+/// @param[in] q such that gcd(2, q) = 1
+/// @return Value x in [0, 2^r − 1] such that q * x ≡ −1 mod 2^r
+inline uint64_t HenselLemma2adicRoot(uint32_t r, uint64_t q) {
+  uint64_t a_prev = 1;
+  uint64_t c = 2;
+  uint64_t mod_mask = 3;
+
+  // Root:
+  //    f(x) = qX + 1 and a_(0) = 1, then f(1) ≡ 0 mod 2
+  // General case:
+  //    - a_(n) ≡ a_(n-1) mod 2^(n)
+  //      => a_(n) = a_(n-1) + 2^(n)*t
+  //    - Find 't' such that f(a_(n)) = 0 mod 2^(n+1)
+  // First case in the loop:
+  //    - a_(1) ≡ 1 mod 2, i.e. a_(1) = 1 + 2t
+  //    - Find 't' so that f(a_(1)) ≡ 0 mod 4, i.e. q(1 + 2t) + 1 ≡ 0 mod 4
+  for (uint64_t k = 2; k <= r; k++) {
+    uint64_t f = 0;
+    uint64_t t = 0;
+    uint64_t a = 0;
+
+    do {
+      a = a_prev + c * t++;
+      f = q * a + 1ULL;
+    } while (f & mod_mask);  // f(a) ≡ 0 mod 2^(k)
+
+    // Update vars
+    mod_mask = mod_mask * 2 + 1ULL;
+    c *= 2;
+    a_prev = a;
+  }
+
+  return a_prev;
+}
+
+}  // namespace hexl
+}  // namespace intel
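The scalar primitives declared above operate on independent elements, which is exactly the kind of loop-level parallelism a multi-threaded build can exploit. As a rough sketch (not code from this repository; the function name is invented for illustration, and the translation unit must be compiled with `-fopenmp`), an element-wise modular multiplication can be split across OpenMP threads:

```cpp
#include <cstdint>
#include <vector>

#include "hexl/number-theory/number-theory.hpp"

// Hypothetical sketch: each output element depends only on x[i] and y[i],
// so the loop can be divided among OpenMP threads without synchronization.
void EltwiseMultModOmp(std::vector<uint64_t>& result,
                       const std::vector<uint64_t>& x,
                       const std::vector<uint64_t>& y, uint64_t modulus) {
#pragma omp parallel for
  for (int64_t i = 0; i < static_cast<int64_t>(result.size()); ++i) {
    // Assumes x[i], y[i] < modulus, as required by MultiplyMod
    result[i] = intel::hexl::MultiplyMod(x[i], y[i], modulus);
  }
}
```

Each iteration touches disjoint data, so the only synchronization is the implicit barrier at the end of the parallel loop.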
diff --git a/hexl_v1/include/hexl/util/aligned-allocator.hpp b/hexl_v1/include/hexl/util/aligned-allocator.hpp
new file mode 100644
index 00000000..d175c734
--- /dev/null
+++ b/hexl_v1/include/hexl/util/aligned-allocator.hpp
@@ -0,0 +1,110 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <cstdlib>
+#include <memory>
+#include <vector>
+
+#include "hexl/number-theory/number-theory.hpp"
+#include "hexl/util/allocator.hpp"
+#include "hexl/util/defines.hpp"
+
+namespace intel {
+namespace hexl {
+
+/// @brief Allocator implementation using malloc and free
+struct MallocStrategy : AllocatorBase {
+  void* allocate(size_t bytes_count) final { return std::malloc(bytes_count); }
+
+  void deallocate(void* p, size_t n) final {
+    HEXL_UNUSED(n);
+    std::free(p);
+  }
+};
+
+using AllocatorStrategyPtr = std::shared_ptr<AllocatorBase>;
+extern AllocatorStrategyPtr mallocStrategy;
+
+/// @brief Allocates memory aligned to Alignment-byte sized boundaries
+/// @details Alignment must be a power of two
+template <typename T, uint64_t Alignment>
+class AlignedAllocator {
+ public:
+  template <typename, uint64_t>
+  friend class AlignedAllocator;
+
+  using value_type = T;
+
+  explicit AlignedAllocator(AllocatorStrategyPtr strategy = nullptr) noexcept
+      : m_alloc_impl((strategy != nullptr) ? strategy : mallocStrategy) {}
+
+  AlignedAllocator(const AlignedAllocator& src) = default;
+  AlignedAllocator& operator=(const AlignedAllocator& src) = default;
+
+  template <typename U>
+  AlignedAllocator(const AlignedAllocator<U, Alignment>& src)
+      : m_alloc_impl(src.m_alloc_impl) {}
+
+  ~AlignedAllocator() {}
+
+  template <typename U>
+  struct rebind {
+    using other = AlignedAllocator<U, Alignment>;
+  };
+
+  bool operator==(const AlignedAllocator&) { return true; }
+
+  bool operator!=(const AlignedAllocator&) { return false; }
+
+  /// @brief Allocates \p n elements aligned to Alignment-byte boundaries
+  /// @return Pointer to the aligned allocated memory
+  T* allocate(size_t n) {
+    if (!IsPowerOfTwo(Alignment)) {
+      return nullptr;
+    }
+    // Allocate enough space to ensure the alignment can be satisfied
+    size_t buffer_size = sizeof(T) * n + Alignment;
+    // Additionally, allocate a prefix to store the memory location of the
+    // unaligned buffer
+    size_t alloc_size = buffer_size + sizeof(void*);
+    void* buffer = m_alloc_impl->allocate(alloc_size);
+    if (!buffer) {
+      return nullptr;
+    }
+
+    // Reserve first location for pointer to originally-allocated space
+    void* aligned_buffer = static_cast<char*>(buffer) + sizeof(void*);
+    std::align(Alignment, sizeof(T) * n, aligned_buffer, buffer_size);
+    if (!aligned_buffer) {
+      return nullptr;
+    }
+
+    // Store allocated buffer address at aligned_buffer - sizeof(void*).
+    void* store_buffer_addr =
+        static_cast<char*>(aligned_buffer) - sizeof(void*);
+    *(static_cast<void**>(store_buffer_addr)) = buffer;
+
+    return static_cast<T*>(aligned_buffer);
+  }
+
+  void deallocate(T* p, size_t n) {
+    if (!p) {
+      return;
+    }
+    void* store_buffer_addr = (reinterpret_cast<char*>(p) - sizeof(void*));
+    void* free_address = *(static_cast<void**>(store_buffer_addr));
+    m_alloc_impl->deallocate(free_address, n);
+  }
+
+ private:
+  AllocatorStrategyPtr m_alloc_impl;
+};
+
+/// @brief 64-byte aligned memory allocator
+template <typename T>
+using AlignedVector64 = std::vector<T, AlignedAllocator<T, 64> >;
+
+}  // namespace hexl
+}  // namespace intel
diff --git a/hexl_v1/include/hexl/util/allocator.hpp b/hexl_v1/include/hexl/util/allocator.hpp
new file mode 100644
index 00000000..5f4a7a31
--- /dev/null
+++ b/hexl_v1/include/hexl/util/allocator.hpp
@@ -0,0 +1,53 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <cstddef>
+
+namespace intel {
+namespace hexl {
+
+/// @brief Base class for custom memory allocator
+struct AllocatorBase {
+  virtual ~AllocatorBase() noexcept {}
+
+  /// @brief Allocates bytes_count bytes of memory
+  /// @param[in] bytes_count Number of bytes to allocate
+  /// @return A pointer to the allocated memory
+  virtual void* allocate(size_t bytes_count) = 0;
+
+  /// @brief Deallocates memory
+  /// @param[in] p Pointer to memory to deallocate
+  /// @param[in] n Number of bytes to deallocate
+  virtual void deallocate(void* p, size_t n) = 0;
+};
+
+/// @brief Helper memory allocation struct which delegates implementation to
+/// AllocatorImpl
+template <class AllocatorImpl>
+struct AllocatorInterface : public AllocatorBase {
+  /// @brief Override interface and delegate implementation to AllocatorImpl
+  void* allocate(size_t bytes_count) override {
+    return static_cast<AllocatorImpl*>(this)->allocate_impl(bytes_count);
+  }
+
+  /// @brief Override interface and delegate implementation to AllocatorImpl
+  void deallocate(void* p, size_t n) override {
+    static_cast<AllocatorImpl*>(this)->deallocate_impl(p, n);
+  }
+
+ private:
+  // in case AllocatorImpl doesn't provide implementations, use default null
+  // behavior
+  void* allocate_impl(size_t bytes_count) {
+    HEXL_UNUSED(bytes_count);
+    return
nullptr; + } + void deallocate_impl(void* p, size_t n) { + HEXL_UNUSED(p); + HEXL_UNUSED(n); + } +}; +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/util/check.hpp b/hexl_v1/include/hexl/util/check.hpp new file mode 100644 index 00000000..386eba89 --- /dev/null +++ b/hexl_v1/include/hexl/util/check.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/types.hpp" + +// Create logging/debug macros with no run-time overhead unless HEXL_DEBUG is +// enabled +#ifdef HEXL_DEBUG +#include "hexl/logging/logging.hpp" + +/// @brief If input condition is not true, logs the expression and throws an +/// error +/// @param[in] cond A boolean indication the condition +/// @param[in] expr The expression to be logged +#define HEXL_CHECK(cond, expr) \ + if (!(cond)) { \ + LOG(ERROR) << expr << " in function: " << __FUNCTION__ \ + << " in file: " __FILE__ << ":" << __LINE__; \ + throw std::runtime_error("Error. Check log output"); \ + } + +/// @brief If input has an element >= bound, logs the expression and throws an +/// error +/// @param[in] arg Input container which supports the [] operator. +/// @param[in] n Size of input +/// @param[in] bound Upper bound on the input +/// @param[in] expr The expression to be logged +#define HEXL_CHECK_BOUNDS(arg, n, bound, expr) \ + for (size_t hexl_check_idx = 0; hexl_check_idx < n; ++hexl_check_idx) { \ + HEXL_CHECK((arg)[hexl_check_idx] < bound, expr); \ + } + +#else // HEXL_DEBUG=OFF + +#define HEXL_CHECK(cond, expr) \ + {} +#define HEXL_CHECK_BOUNDS(...) \ + {} + +#endif // HEXL_DEBUG diff --git a/hexl_v1/include/hexl/util/clang.hpp b/hexl_v1/include/hexl/util/clang.hpp new file mode 100644 index 00000000..958bea7b --- /dev/null +++ b/hexl_v1/include/hexl/util/clang.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/types.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_USE_CLANG +// Return x * y as 128-bit integer +// Correctness if x * y < 128 bits +inline uint128_t MultiplyUInt64(uint64_t x, uint64_t y) { + return uint128_t(x) * uint128_t(y); +} + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint128_t n = (static_cast(input_hi) << 64) | + (static_cast(input_lo)); + + return n % modulus; + // TODO(fboemer): actually use barrett reduction if performance-critical +} + +// Returns low 64bit of 128b/64b where x1=high 64b, x0=low 64b +inline uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y) { + uint128_t n = + (static_cast(x1) << 64) | (static_cast(x0)); + uint128_t q = n / y; + + return static_cast(q); +} + +// Multiplies x * y as 128-bit integer. 
+// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + uint128_t prod = MultiplyUInt64(x, y); + *prod_hi = static_cast(prod >> 64); + *prod_lo = static_cast(prod); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + uint128_t product = static_cast(x) * y; + return static_cast(product >> BitShift); +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + return static_cast(std::log2l(input)); +} + +#define HEXL_LOOP_UNROLL_4 _Pragma("clang loop unroll_count(4)") +#define HEXL_LOOP_UNROLL_8 _Pragma("clang loop unroll_count(8)") + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/util/compiler.hpp b/hexl_v1/include/hexl/util/compiler.hpp new file mode 100644 index 00000000..7dd077df --- /dev/null +++ b/hexl_v1/include/hexl/util/compiler.hpp @@ -0,0 +1,14 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/util/defines.hpp" + +#ifdef HEXL_USE_MSVC +#include "hexl/util/msvc.hpp" +#elif defined HEXL_USE_GNU +#include "hexl/util/gcc.hpp" +#elif defined HEXL_USE_CLANG +#include "hexl/util/clang.hpp" +#endif diff --git a/hexl_v1/include/hexl/util/defines.hpp b/hexl_v1/include/hexl/util/defines.hpp new file mode 100644 index 00000000..b92dd24e --- /dev/null +++ b/hexl_v1/include/hexl/util/defines.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +/* #undef HEXL_USE_MSVC */ +#define HEXL_USE_GNU +/* #undef HEXL_USE_CLANG */ + +#define HEXL_DEBUG + +// Avoid unused variable warnings +#define HEXL_UNUSED(x) (void)(x) diff --git a/hexl_v1/include/hexl/util/gcc.hpp b/hexl_v1/include/hexl/util/gcc.hpp new file mode 100644 index 00000000..828e3836 --- /dev/null +++ b/hexl_v1/include/hexl/util/gcc.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/types.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_USE_GNU +// Return x * y as 128-bit integer +// Correctness if x * y < 128 bits +inline uint128_t MultiplyUInt64(uint64_t x, uint64_t y) { + return uint128_t(x) * uint128_t(y); +} + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint128_t n = (static_cast(input_hi) << 64) | + (static_cast(input_lo)); + + return static_cast(n % modulus); + // TODO(fboemer): actually use barrett reduction if performance-critical +} + +// Returns low 64bit of 128b/64b where x1=high 64b, x0=low 64b +inline uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y) { + uint128_t n = + (static_cast(x1) << 64) | (static_cast(x0)); + uint128_t q = n / y; + + return static_cast(q); +} + +// Multiplies x * y as 128-bit integer. 
+// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + uint128_t prod = MultiplyUInt64(x, y); + *prod_hi = static_cast(prod >> 64); + *prod_lo = static_cast(prod); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + uint128_t product = MultiplyUInt64(x, y); + return static_cast(product >> BitShift); +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + return static_cast(std::log2l(input)); +} + +#define HEXL_LOOP_UNROLL_4 _Pragma("GCC unroll 4") +#define HEXL_LOOP_UNROLL_8 _Pragma("GCC unroll 8") + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/util/msvc.hpp b/hexl_v1/include/hexl/util/msvc.hpp new file mode 100644 index 00000000..0ada2d45 --- /dev/null +++ b/hexl_v1/include/hexl/util/msvc.hpp @@ -0,0 +1,289 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#ifdef HEXL_USE_MSVC + +#define NOMINMAX // Avoid errors with std::min/std::max +#undef min +#undef max + +#include +#include +#include + +#include + +#include "hexl/util/check.hpp" + +#pragma intrinsic(_addcarry_u64, _BitScanReverse64, _subborrow_u64, _udiv128, \ + _umul128) + +#undef TRUE +#undef FALSE + +namespace intel { +namespace hexl { + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint64_t remainder; + _udiv128(input_hi, input_lo, modulus, &remainder); + + return remainder; +} + +// Multiplies x * y as 128-bit integer. +// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + *prod_lo = _umul128(x, y, prod_hi); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid BitShift " << BitShift << "; expected 52 or 64"); + uint64_t prod_hi; + uint64_t prod_lo = _umul128(x, y, &prod_hi); + uint64_t result_hi; + uint64_t result_lo; + RightShift128(&result_hi, &result_lo, prod_hi, prod_lo, BitShift); + return result_lo; +} + +/// @brief Computes Left Shift op as 128-bit unsigned integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +/// @param[in] op_hi Stores high 64 bits of input +/// @param[in] op_lo Stores low 64 bits of input +inline void LeftShift128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op_hi, const uint64_t op_lo, + const uint64_t shift_value) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + HEXL_CHECK(shift_value <= 128, + "shift_value cannot be greater than 128 " << shift_value); + + if (shift_value == 0) { + *result_hi = op_hi; + *result_lo = op_lo; + } else if (shift_value == 64) { + *result_hi = op_lo; + *result_lo = 0ULL; + } else if (shift_value == 128) { + *result_hi = 0ULL; + *result_lo = 0ULL; + } else if (shift_value >= 1 && shift_value <= 63) { + *result_hi = (op_hi << shift_value) | (op_lo >> (64 - shift_value)); + *result_lo = op_lo << shift_value; + } else if (shift_value >= 65 && shift_value < 128) { + 
*result_hi = op_lo << (shift_value - 64); + *result_lo = 0ULL; + } +} + +/// @brief Computes Right Shift op as 128-bit unsigned integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +/// @param[in] op_hi Stores high 64 bits of input +/// @param[in] op_lo Stores low 64 bits of input +inline void RightShift128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op_hi, const uint64_t op_lo, + const uint64_t shift_value) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + HEXL_CHECK(shift_value <= 128, + "shift_value cannot be greater than 128 " << shift_value); + + if (shift_value == 0) { + *result_hi = op_hi; + *result_lo = op_lo; + } else if (shift_value == 64) { + *result_hi = 0ULL; + *result_lo = op_hi; + } else if (shift_value == 128) { + *result_hi = 0ULL; + *result_lo = 0ULL; + } else if (shift_value >= 1 && shift_value <= 63) { + *result_hi = op_hi >> shift_value; + *result_lo = (op_hi << (64 - shift_value)) | (op_lo >> shift_value); + } else if (shift_value >= 65 && shift_value < 128) { + *result_hi = 0ULL; + *result_lo = op_hi >> (shift_value - 64); + } +} + +/// @brief Adds op1 + op2 as 128-bit integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +inline void AddWithCarry128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op1_hi, const uint64_t op1_lo, + const uint64_t op2_hi, const uint64_t op2_lo) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + + // first 64bit block + *result_lo = op1_lo + op2_lo; + unsigned char carry = static_cast(*result_lo < op1_lo); + + // second 64bit block + _addcarry_u64(carry, op1_hi, op2_hi, result_hi); +} + +/// @brief Subtracts op1 - op2 as 128-bit integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +inline void SubWithCarry128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op1_hi, const uint64_t op1_lo, + const uint64_t op2_hi, const uint64_t op2_lo) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + + unsigned char borrow; + + // first 64bit block + *result_lo = op1_lo - op2_lo; + borrow = static_cast(op2_lo > op1_lo); + + // second 64bit block + _subborrow_u64(borrow, op1_hi, op2_hi, result_hi); +} + +/// @brief Computes and returns significant bit count +/// @param[in] value Input element at most 128 bits long +inline uint64_t SignificantBitLength(const uint64_t* value) { + HEXL_CHECK(value != nullptr, "Require value != nullptr"); + + unsigned long count = 0; // NOLINT(runtime/int) + + // second 64bit block + _BitScanReverse64(&count, *(value + 1)); + if (count >= 0 && *(value + 1) > 0) { + return static_cast(count) + 1 + 64; + } + + // first 64bit block + _BitScanReverse64(&count, *value); + if (count >= 0 && *(value) > 0) { + return static_cast(count) + 1; + } + return 0; +} + +/// @brief Checks if input is negative number +/// @param[in] input Input element to check for sign +inline bool CheckSign(const uint64_t* input) { + HEXL_CHECK(input != nullptr, "Require input != nullptr"); + + uint64_t input_temp[2]{0, 0}; + RightShift128(&input_temp[1], &input_temp[0], input[1], input[0], 127); + return (input_temp[0] == 1); +} + +/// @brief Divides 
numerator by denominator +/// @param[out] quotient Stores quotient as two 64-bit blocks after division +/// @param[in] numerator +/// @param[in] denominator +inline void DivideUInt128UInt64(uint64_t* quotient, const uint64_t* numerator, + const uint64_t denominator) { + HEXL_CHECK(quotient != nullptr, "Require quotient != nullptr"); + HEXL_CHECK(numerator != nullptr, "Require numerator != nullptr"); + HEXL_CHECK(denominator != 0, "denominator cannot be 0 " << denominator); + + // get bit count of divisor + uint64_t numerator_bits = SignificantBitLength(numerator); + const uint64_t numerator_bits_const = numerator_bits; + const uint64_t uint_128_bit = 128ULL; + + uint64_t MASK[2]{0x0000000000000001, 0x0000000000000000}; + uint64_t remainder[2]{0, 0}; + uint64_t quotient_temp[2]{0, 0}; + uint64_t denominator_temp[2]{denominator, 0}; + + quotient[0] = numerator[0]; + quotient[1] = numerator[1]; + + // align numerator + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], + (uint_128_bit - numerator_bits_const)); + + while (numerator_bits) { + // if remainder is negative + if (CheckSign(remainder)) { + LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1); + RightShift128("ient_temp[1], "ient_temp[0], quotient[1], + quotient[0], (uint_128_bit - 1)); + remainder[0] = remainder[0] | quotient_temp[0]; + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1); + // remainder=remainder+denominator_temp + AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } else { // if remainder is positive + LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1); + RightShift128("ient_temp[1], "ient_temp[0], quotient[1], + quotient[0], (uint_128_bit - 1)); + remainder[0] = remainder[0] | quotient_temp[0]; + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1); + // remainder=remainder-denominator_temp + SubWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } + + // if remainder is positive set MSB of quotient[0]=1 + if (!CheckSign(remainder)) { + MASK[0] = 0x0000000000000001; + MASK[1] = 0x0000000000000000; + LeftShift128(&MASK[1], &MASK[0], MASK[1], MASK[0], + (uint_128_bit - numerator_bits_const)); + quotient[0] = quotient[0] | MASK[0]; + quotient[1] = quotient[1] | MASK[1]; + } + quotient_temp[0] = 0; + quotient_temp[1] = 0; + numerator_bits--; + } + + if (CheckSign(remainder)) { + // remainder=remainder+denominator_temp + AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } + RightShift128("ient[1], "ient[0], quotient[1], quotient[0], + (uint_128_bit - numerator_bits_const)); +} + +/// @brief Returns low of dividing numerator by denominator +/// @param[in] numerator_hi Stores high 64 bit of numerator +/// @param[in] numerator_lo Stores low 64 bit of numerator +/// @param[in] denominator Stores denominator +inline uint64_t DivideUInt128UInt64Lo(const uint64_t numerator_hi, + const uint64_t numerator_lo, + const uint64_t denominator) { + uint64_t numerator[2]{numerator_lo, numerator_hi}; + uint64_t quotient[2]{0, 0}; + + DivideUInt128UInt64(quotient, numerator, denominator); + return quotient[0]; +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + unsigned long index{0}; // NOLINT(runtime/int) + _BitScanReverse64(&index, input); + return index; +} + +#define HEXL_LOOP_UNROLL_4 \ + {} +#define HEXL_LOOP_UNROLL_8 
\ + {} + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/include/hexl/util/types.hpp b/hexl_v1/include/hexl/util/types.hpp new file mode 100644 index 00000000..2d2d8551 --- /dev/null +++ b/hexl_v1/include/hexl/util/types.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/defines.hpp" + +#if defined(HEXL_USE_GNU) || defined(HEXL_USE_CLANG) +__extension__ typedef __int128 int128_t; +__extension__ typedef unsigned __int128 uint128_t; +#endif diff --git a/hexl_v1/include/hexl/util/util.hpp b/hexl_v1/include/hexl/util/util.hpp new file mode 100644 index 00000000..bf878a98 --- /dev/null +++ b/hexl_v1/include/hexl/util/util.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +#undef TRUE // MSVC defines TRUE +#undef FALSE // MSVC defines FALSE + +/// @enum CMPINT +/// @brief Represents binary operations between two boolean values +enum class CMPINT { + EQ = 0, ///< Equal + LT = 1, ///< Less than + LE = 2, ///< Less than or equal + FALSE = 3, ///< False + NE = 4, ///< Not equal + NLT = 5, ///< Not less than + NLE = 6, ///< Not less than or equal + TRUE = 7 ///< True +}; + +/// @brief Returns the logical negation of a binary operation +/// @param[in] cmp The binary operation to negate +inline CMPINT Not(CMPINT cmp) { + switch (cmp) { + case CMPINT::EQ: + return CMPINT::NE; + case CMPINT::LT: + return CMPINT::NLT; + case CMPINT::LE: + return CMPINT::NLE; + case CMPINT::FALSE: + return CMPINT::TRUE; + case CMPINT::NE: + return CMPINT::EQ; + case CMPINT::NLT: + return CMPINT::LT; + case CMPINT::NLE: + return CMPINT::LE; + case CMPINT::TRUE: + return CMPINT::FALSE; + default: + return CMPINT::FALSE; + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_v1/lib/cmake/hexl-1.2.5/HEXLConfig.cmake b/hexl_v1/lib/cmake/hexl-1.2.5/HEXLConfig.cmake new file mode 100644 index 00000000..87f67e31 --- /dev/null +++ b/hexl_v1/lib/cmake/hexl-1.2.5/HEXLConfig.cmake @@ -0,0 +1,59 @@ +# Copyright (C) 2020 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# This will define the following variables: +# +# HEXL_FOUND - True if the system has the Intel HEXL library +# HEXL_VERSION - The full major.minor.patch version number +# HEXL_VERSION_MAJOR - The major version number +# HEXL_VERSION_MINOR - The minor version number +# HEXL_VERSION_PATCH - The patch version number + + +####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() ####### +####### Any changes to this file will be overwritten by the next CMake run #### +####### The input file was HEXLConfig.cmake.in ######## + +get_filename_component(PACKAGE_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) + +macro(set_and_check _var _file) + set(${_var} "${_file}") + if(NOT EXISTS "${_file}") + message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !") + endif() +endmacro() + +macro(check_required_components _NAME) + foreach(comp ${${_NAME}_FIND_COMPONENTS}) + if(NOT ${_NAME}_${comp}_FOUND) + if(${_NAME}_FIND_REQUIRED_${comp}) + set(${_NAME}_FOUND FALSE) + endif() + endif() + endforeach() +endmacro() + +#################################################################################### + +include(CMakeFindDependencyMacro) +find_package(CpuFeatures CONFIG) +if(NOT CpuFeatures_FOUND) + message(WARNING "Could not find pre-installed CpuFeatures; 
using CpuFeatures packaged with HEXL") +endif() + +include(${CMAKE_CURRENT_LIST_DIR}/HEXLTargets.cmake) + +# Defines HEXL_FOUND: If Intel HEXL library was found +if(TARGET HEXL::hexl) + set(HEXL_FOUND TRUE) + message(STATUS "Intel HEXL found") +else() + message(STATUS "Intel HEXL not found") +endif() + +set(HEXL_VERSION "1.2.5") +set(HEXL_VERSION_MAJOR "1") +set(HEXL_VERSION_MINOR "2") +set(HEXL_VERSION_PATCH "5") + +set(HEXL_DEBUG "ON") diff --git a/hexl_v1/lib/cmake/hexl-1.2.5/HEXLConfigVersion.cmake b/hexl_v1/lib/cmake/hexl-1.2.5/HEXLConfigVersion.cmake new file mode 100644 index 00000000..98b46110 --- /dev/null +++ b/hexl_v1/lib/cmake/hexl-1.2.5/HEXLConfigVersion.cmake @@ -0,0 +1,88 @@ +# This is a basic version file for the Config-mode of find_package(). +# It is used by write_basic_package_version_file() as input file for configure_file() +# to create a version-file which can be installed along a config.cmake file. +# +# The created file sets PACKAGE_VERSION_EXACT if the current version string and +# the requested version string are exactly the same and it sets +# PACKAGE_VERSION_COMPATIBLE if the current version is equal to the requested version. +# The tweak version component is ignored. +# The variable CVF_VERSION must be set before calling configure_file(). + + +if (PACKAGE_FIND_VERSION_RANGE) + message(AUTHOR_WARNING + "`find_package()` specify a version range but the version strategy " + "(ExactVersion) of the module `${PACKAGE_FIND_NAME}` is incompatible " + "with this request. Only the lower endpoint of the range will be used.") +endif() + +set(PACKAGE_VERSION "1.2.5") + +if("1.2.5" MATCHES "^([0-9]+)\\.([0-9]+)\\.([0-9]+)") # strip the tweak version + set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(CVF_VERSION_MINOR "${CMAKE_MATCH_2}") + set(CVF_VERSION_PATCH "${CMAKE_MATCH_3}") + + if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}") + endif() + if(NOT CVF_VERSION_MINOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" CVF_VERSION_MINOR "${CVF_VERSION_MINOR}") + endif() + if(NOT CVF_VERSION_PATCH VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" CVF_VERSION_PATCH "${CVF_VERSION_PATCH}") + endif() + + set(CVF_VERSION_NO_TWEAK "${CVF_VERSION_MAJOR}.${CVF_VERSION_MINOR}.${CVF_VERSION_PATCH}") +else() + set(CVF_VERSION_NO_TWEAK "1.2.5") +endif() + +if(PACKAGE_FIND_VERSION MATCHES "^([0-9]+)\\.([0-9]+)\\.([0-9]+)") # strip the tweak version + set(REQUESTED_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(REQUESTED_VERSION_MINOR "${CMAKE_MATCH_2}") + set(REQUESTED_VERSION_PATCH "${CMAKE_MATCH_3}") + + if(NOT REQUESTED_VERSION_MAJOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" REQUESTED_VERSION_MAJOR "${REQUESTED_VERSION_MAJOR}") + endif() + if(NOT REQUESTED_VERSION_MINOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" REQUESTED_VERSION_MINOR "${REQUESTED_VERSION_MINOR}") + endif() + if(NOT REQUESTED_VERSION_PATCH VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" REQUESTED_VERSION_PATCH "${REQUESTED_VERSION_PATCH}") + endif() + + set(REQUESTED_VERSION_NO_TWEAK + "${REQUESTED_VERSION_MAJOR}.${REQUESTED_VERSION_MINOR}.${REQUESTED_VERSION_PATCH}") +else() + set(REQUESTED_VERSION_NO_TWEAK "${PACKAGE_FIND_VERSION}") +endif() + +if(REQUESTED_VERSION_NO_TWEAK STREQUAL CVF_VERSION_NO_TWEAK) + set(PACKAGE_VERSION_COMPATIBLE TRUE) +else() + set(PACKAGE_VERSION_COMPATIBLE FALSE) +endif() + +if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION) + set(PACKAGE_VERSION_EXACT TRUE) +endif() + + +# if the installed project requested no 
architecture check, don't perform the check +if("FALSE") + return() +endif() + +# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it: +if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "") + return() +endif() + +# check that the installed version has the same 32/64bit-ness as the one which is currently searching: +if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8") + math(EXPR installedBits "8 * 8") + set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)") + set(PACKAGE_VERSION_UNSUITABLE TRUE) +endif() diff --git a/hexl_v1/lib/cmake/hexl-1.2.5/HEXLTargets-debug.cmake b/hexl_v1/lib/cmake/hexl-1.2.5/HEXLTargets-debug.cmake new file mode 100644 index 00000000..1b485a76 --- /dev/null +++ b/hexl_v1/lib/cmake/hexl-1.2.5/HEXLTargets-debug.cmake @@ -0,0 +1,19 @@ +#---------------------------------------------------------------- +# Generated CMake target import file for configuration "Debug". +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Import target "HEXL::hexl" for configuration "Debug" +set_property(TARGET HEXL::hexl APPEND PROPERTY IMPORTED_CONFIGURATIONS DEBUG) +set_target_properties(HEXL::hexl PROPERTIES + IMPORTED_LOCATION_DEBUG "${_IMPORT_PREFIX}/lib/libhexl_debug.so.1.2.5" + IMPORTED_SONAME_DEBUG "libhexl_debug.so.1.2.5" + ) + +list(APPEND _IMPORT_CHECK_TARGETS HEXL::hexl ) +list(APPEND _IMPORT_CHECK_FILES_FOR_HEXL::hexl "${_IMPORT_PREFIX}/lib/libhexl_debug.so.1.2.5" ) + +# Commands beyond this point should not need to know the version. +set(CMAKE_IMPORT_FILE_VERSION) diff --git a/hexl_v1/lib/cmake/hexl-1.2.5/HEXLTargets.cmake b/hexl_v1/lib/cmake/hexl-1.2.5/HEXLTargets.cmake new file mode 100644 index 00000000..5b29b31e --- /dev/null +++ b/hexl_v1/lib/cmake/hexl-1.2.5/HEXLTargets.cmake @@ -0,0 +1,101 @@ +# Generated by CMake + +if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.6) + message(FATAL_ERROR "CMake >= 2.6.0 required") +endif() +cmake_policy(PUSH) +cmake_policy(VERSION 2.6...3.20) +#---------------------------------------------------------------- +# Generated CMake target import file. +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Protect against multiple inclusion, which would fail when already imported targets are added once more. +set(_targetsDefined) +set(_targetsNotDefined) +set(_expectedTargets) +foreach(_expectedTarget HEXL::hexl) + list(APPEND _expectedTargets ${_expectedTarget}) + if(NOT TARGET ${_expectedTarget}) + list(APPEND _targetsNotDefined ${_expectedTarget}) + endif() + if(TARGET ${_expectedTarget}) + list(APPEND _targetsDefined ${_expectedTarget}) + endif() +endforeach() +if("${_targetsDefined}" STREQUAL "${_expectedTargets}") + unset(_targetsDefined) + unset(_targetsNotDefined) + unset(_expectedTargets) + set(CMAKE_IMPORT_FILE_VERSION) + cmake_policy(POP) + return() +endif() +if(NOT "${_targetsDefined}" STREQUAL "") + message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_targetsDefined}\nTargets not yet defined: ${_targetsNotDefined}\n") +endif() +unset(_targetsDefined) +unset(_targetsNotDefined) +unset(_expectedTargets) + + +# Compute the installation prefix relative to this file. 
+get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +if(_IMPORT_PREFIX STREQUAL "/") + set(_IMPORT_PREFIX "") +endif() + +# Create imported target HEXL::hexl +add_library(HEXL::hexl SHARED IMPORTED) + +set_target_properties(HEXL::hexl PROPERTIES + INTERFACE_COMPILE_OPTIONS "-fsanitize=address;-Wno-unknown-warning;-Wno-unknown-warning-option" + INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include" + INTERFACE_LINK_LIBRARIES "OpenMP::OpenMP_CXX;easyloggingpp" + INTERFACE_LINK_OPTIONS "-fsanitize=address" +) + +if(CMAKE_VERSION VERSION_LESS 2.8.12) + message(FATAL_ERROR "This file relies on consumers using CMake 2.8.12 or greater.") +endif() + +# Load information for each installed configuration. +get_filename_component(_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) +file(GLOB CONFIG_FILES "${_DIR}/HEXLTargets-*.cmake") +foreach(f ${CONFIG_FILES}) + include(${f}) +endforeach() + +# Cleanup temporary variables. +set(_IMPORT_PREFIX) + +# Loop over all imported files and verify that they actually exist +foreach(target ${_IMPORT_CHECK_TARGETS} ) + foreach(file ${_IMPORT_CHECK_FILES_FOR_${target}} ) + if(NOT EXISTS "${file}" ) + message(FATAL_ERROR "The imported target \"${target}\" references the file + \"${file}\" +but this file does not exist. Possible reasons include: +* The file was deleted, renamed, or moved to another location. +* An install or uninstall procedure did not complete successfully. +* The installation package was faulty and contained + \"${CMAKE_CURRENT_LIST_FILE}\" +but not all the files it references. +") + endif() + endforeach() + unset(_IMPORT_CHECK_FILES_FOR_${target}) +endforeach() +unset(_IMPORT_CHECK_TARGETS) + +# This file does not depend on other imported targets which have +# been exported from the same project but in a separate export set. + +# Commands beyond this point should not need to know the version. +set(CMAKE_IMPORT_FILE_VERSION) +cmake_policy(POP) diff --git a/hexl_v1/lib/libhexl.a b/hexl_v1/lib/libhexl.a new file mode 100644 index 00000000..cd0bb4e8 Binary files /dev/null and b/hexl_v1/lib/libhexl.a differ diff --git a/hexl_v1/lib/libhexl_debug.so b/hexl_v1/lib/libhexl_debug.so new file mode 120000 index 00000000..c54310f0 --- /dev/null +++ b/hexl_v1/lib/libhexl_debug.so @@ -0,0 +1 @@ +libhexl_debug.so.1.2.5 \ No newline at end of file diff --git a/hexl_v1/lib/libhexl_debug.so.1.2.5 b/hexl_v1/lib/libhexl_debug.so.1.2.5 new file mode 100644 index 00000000..f01e250f Binary files /dev/null and b/hexl_v1/lib/libhexl_debug.so.1.2.5 differ diff --git a/hexl_v1/lib/pkgconfig/hexl.pc b/hexl_v1/lib/pkgconfig/hexl.pc new file mode 100644 index 00000000..f9aa0c65 --- /dev/null +++ b/hexl_v1/lib/pkgconfig/hexl.pc @@ -0,0 +1,13 @@ +# Copyright (C) 2020 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +prefix=/home/eidf018/eidf018/s1820742psd/hexl/hexl_v1 +libdir=${prefix}/lib +includedir=${prefix}/include + +Name: Intel HEXL +Version: 1.2.5 +Description: Intel® HEXL is an open-source library which provides efficient implementations of integer arithmetic on Galois fields. 
+ +Libs: -L${libdir} -fsanitize=address -lhexl_debug +Cflags: -I${includedir} -fsanitize=address diff --git a/hexl_v2/include/easylogging++.h b/hexl_v2/include/easylogging++.h new file mode 100644 index 00000000..6e81edfb --- /dev/null +++ b/hexl_v2/include/easylogging++.h @@ -0,0 +1,4576 @@ +// +// Bismillah ar-Rahmaan ar-Raheem +// +// Easylogging++ v9.96.7 +// Single-header only, cross-platform logging library for C++ applications +// +// Copyright (c) 2012-2018 Amrayn Web Services +// Copyright (c) 2012-2018 @abumusamq +// +// This library is released under the MIT Licence. +// https://github.com/amrayn/easyloggingpp/blob/master/LICENSE +// +// https://amrayn.com +// http://muflihun.com +// + +#ifndef EASYLOGGINGPP_H +#define EASYLOGGINGPP_H +// Compilers and C++0x/C++11 Evaluation +#if __cplusplus >= 201103L +# define ELPP_CXX11 1 +#endif // __cplusplus >= 201103L +#if (defined(__GNUC__)) +# define ELPP_COMPILER_GCC 1 +#else +# define ELPP_COMPILER_GCC 0 +#endif +#if ELPP_COMPILER_GCC +# define ELPP_GCC_VERSION (__GNUC__ * 10000 \ ++ __GNUC_MINOR__ * 100 \ ++ __GNUC_PATCHLEVEL__) +# if defined(__GXX_EXPERIMENTAL_CXX0X__) +# define ELPP_CXX0X 1 +# endif +#endif +// Visual C++ +#if defined(_MSC_VER) +# define ELPP_COMPILER_MSVC 1 +#else +# define ELPP_COMPILER_MSVC 0 +#endif +#define ELPP_CRT_DBG_WARNINGS ELPP_COMPILER_MSVC +#if ELPP_COMPILER_MSVC +# if (_MSC_VER == 1600) +# define ELPP_CXX0X 1 +# elif(_MSC_VER >= 1700) +# define ELPP_CXX11 1 +# endif +#endif +// Clang++ +#if (defined(__clang__) && (__clang__ == 1)) +# define ELPP_COMPILER_CLANG 1 +#else +# define ELPP_COMPILER_CLANG 0 +#endif +#if ELPP_COMPILER_CLANG +# if __has_include() +# include // Make __GLIBCXX__ defined when using libstdc++ +# if !defined(__GLIBCXX__) || __GLIBCXX__ >= 20150426 +# define ELPP_CLANG_SUPPORTS_THREAD +# endif // !defined(__GLIBCXX__) || __GLIBCXX__ >= 20150426 +# endif // __has_include() +#endif +#if (defined(__MINGW32__) || defined(__MINGW64__)) +# define ELPP_MINGW 1 +#else +# define ELPP_MINGW 0 +#endif +#if (defined(__CYGWIN__) && (__CYGWIN__ == 1)) +# define ELPP_CYGWIN 1 +#else +# define ELPP_CYGWIN 0 +#endif +#if (defined(__INTEL_COMPILER)) +# define ELPP_COMPILER_INTEL 1 +#else +# define ELPP_COMPILER_INTEL 0 +#endif +// Operating System Evaluation +// Windows +#if (defined(_WIN32) || defined(_WIN64)) +# define ELPP_OS_WINDOWS 1 +#else +# define ELPP_OS_WINDOWS 0 +#endif +// Linux +#if (defined(__linux) || defined(__linux__)) +# define ELPP_OS_LINUX 1 +#else +# define ELPP_OS_LINUX 0 +#endif +#if (defined(__APPLE__)) +# define ELPP_OS_MAC 1 +#else +# define ELPP_OS_MAC 0 +#endif +#if (defined(__FreeBSD__) || defined(__FreeBSD_kernel__)) +# define ELPP_OS_FREEBSD 1 +#else +# define ELPP_OS_FREEBSD 0 +#endif +#if (defined(__sun)) +# define ELPP_OS_SOLARIS 1 +#else +# define ELPP_OS_SOLARIS 0 +#endif +#if (defined(_AIX)) +# define ELPP_OS_AIX 1 +#else +# define ELPP_OS_AIX 0 +#endif +#if (defined(__NetBSD__)) +# define ELPP_OS_NETBSD 1 +#else +# define ELPP_OS_NETBSD 0 +#endif +#if defined(__EMSCRIPTEN__) +# define ELPP_OS_EMSCRIPTEN 1 +#else +# define ELPP_OS_EMSCRIPTEN 0 +#endif +#if (defined(__QNX__) || defined(__QNXNTO__)) +# define ELPP_OS_QNX 1 +#else +# define ELPP_OS_QNX 0 +#endif +// Unix +#if ((ELPP_OS_LINUX || ELPP_OS_MAC || ELPP_OS_FREEBSD || ELPP_OS_NETBSD || ELPP_OS_SOLARIS || ELPP_OS_AIX || ELPP_OS_EMSCRIPTEN || ELPP_OS_QNX) && (!ELPP_OS_WINDOWS)) +# define ELPP_OS_UNIX 1 +#else +# define ELPP_OS_UNIX 0 +#endif +#if (defined(__ANDROID__)) +# define ELPP_OS_ANDROID 1 +#else +# 
define ELPP_OS_ANDROID 0 +#endif +// Evaluating Cygwin as *nix OS +#if !ELPP_OS_UNIX && !ELPP_OS_WINDOWS && ELPP_CYGWIN +# undef ELPP_OS_UNIX +# undef ELPP_OS_LINUX +# define ELPP_OS_UNIX 1 +# define ELPP_OS_LINUX 1 +#endif // !ELPP_OS_UNIX && !ELPP_OS_WINDOWS && ELPP_CYGWIN +#if !defined(ELPP_INTERNAL_DEBUGGING_OUT_INFO) +# define ELPP_INTERNAL_DEBUGGING_OUT_INFO std::cout +#endif // !defined(ELPP_INTERNAL_DEBUGGING_OUT) +#if !defined(ELPP_INTERNAL_DEBUGGING_OUT_ERROR) +# define ELPP_INTERNAL_DEBUGGING_OUT_ERROR std::cerr +#endif // !defined(ELPP_INTERNAL_DEBUGGING_OUT) +#if !defined(ELPP_INTERNAL_DEBUGGING_ENDL) +# define ELPP_INTERNAL_DEBUGGING_ENDL std::endl +#endif // !defined(ELPP_INTERNAL_DEBUGGING_OUT) +#if !defined(ELPP_INTERNAL_DEBUGGING_MSG) +# define ELPP_INTERNAL_DEBUGGING_MSG(msg) msg +#endif // !defined(ELPP_INTERNAL_DEBUGGING_OUT) +// Internal Assertions and errors +#if !defined(ELPP_DISABLE_ASSERT) +# if (defined(ELPP_DEBUG_ASSERT_FAILURE)) +# define ELPP_ASSERT(expr, msg) if (!(expr)) { \ +std::stringstream internalInfoStream; internalInfoStream << msg; \ +ELPP_INTERNAL_DEBUGGING_OUT_ERROR \ +<< "EASYLOGGING++ ASSERTION FAILED (LINE: " << __LINE__ << ") [" #expr << "] WITH MESSAGE \"" \ +<< ELPP_INTERNAL_DEBUGGING_MSG(internalInfoStream.str()) << "\"" << ELPP_INTERNAL_DEBUGGING_ENDL; base::utils::abort(1, \ +"ELPP Assertion failure, please define ELPP_DEBUG_ASSERT_FAILURE"); } +# else +# define ELPP_ASSERT(expr, msg) if (!(expr)) { \ +std::stringstream internalInfoStream; internalInfoStream << msg; \ +ELPP_INTERNAL_DEBUGGING_OUT_ERROR\ +<< "ASSERTION FAILURE FROM EASYLOGGING++ (LINE: " \ +<< __LINE__ << ") [" #expr << "] WITH MESSAGE \"" << ELPP_INTERNAL_DEBUGGING_MSG(internalInfoStream.str()) << "\"" \ +<< ELPP_INTERNAL_DEBUGGING_ENDL; } +# endif // (defined(ELPP_DEBUG_ASSERT_FAILURE)) +#else +# define ELPP_ASSERT(x, y) +#endif //(!defined(ELPP_DISABLE_ASSERT) +#if ELPP_COMPILER_MSVC +# define ELPP_INTERNAL_DEBUGGING_WRITE_PERROR \ +{ char buff[256]; strerror_s(buff, 256, errno); \ +ELPP_INTERNAL_DEBUGGING_OUT_ERROR << ": " << buff << " [" << errno << "]";} (void)0 +#else +# define ELPP_INTERNAL_DEBUGGING_WRITE_PERROR \ +ELPP_INTERNAL_DEBUGGING_OUT_ERROR << ": " << strerror(errno) << " [" << errno << "]"; (void)0 +#endif // ELPP_COMPILER_MSVC +#if defined(ELPP_DEBUG_ERRORS) +# if !defined(ELPP_INTERNAL_ERROR) +# define ELPP_INTERNAL_ERROR(msg, pe) { \ +std::stringstream internalInfoStream; internalInfoStream << " " << msg; \ +ELPP_INTERNAL_DEBUGGING_OUT_ERROR \ +<< "ERROR FROM EASYLOGGING++ (LINE: " << __LINE__ << ") " \ +<< ELPP_INTERNAL_DEBUGGING_MSG(internalInfoStream.str()) << ELPP_INTERNAL_DEBUGGING_ENDL; \ +if (pe) { ELPP_INTERNAL_DEBUGGING_OUT_ERROR << " "; ELPP_INTERNAL_DEBUGGING_WRITE_PERROR; }} (void)0 +# endif +#else +# undef ELPP_INTERNAL_INFO +# define ELPP_INTERNAL_ERROR(msg, pe) +#endif // defined(ELPP_DEBUG_ERRORS) +#if (defined(ELPP_DEBUG_INFO)) +# if !(defined(ELPP_INTERNAL_INFO_LEVEL)) +# define ELPP_INTERNAL_INFO_LEVEL 9 +# endif // !(defined(ELPP_INTERNAL_INFO_LEVEL)) +# if !defined(ELPP_INTERNAL_INFO) +# define ELPP_INTERNAL_INFO(lvl, msg) { if (lvl <= ELPP_INTERNAL_INFO_LEVEL) { \ +std::stringstream internalInfoStream; internalInfoStream << " " << msg; \ +ELPP_INTERNAL_DEBUGGING_OUT_INFO << ELPP_INTERNAL_DEBUGGING_MSG(internalInfoStream.str()) \ +<< ELPP_INTERNAL_DEBUGGING_ENDL; }} +# endif +#else +# undef ELPP_INTERNAL_INFO +# define ELPP_INTERNAL_INFO(lvl, msg) +#endif // (defined(ELPP_DEBUG_INFO)) +#if (defined(ELPP_FEATURE_ALL)) || 
(defined(ELPP_FEATURE_CRASH_LOG)) +# if (ELPP_COMPILER_GCC && !ELPP_MINGW && !ELPP_CYGWIN && !ELPP_OS_ANDROID && !ELPP_OS_EMSCRIPTEN && !ELPP_OS_QNX) +# define ELPP_STACKTRACE 1 +# else +# if ELPP_COMPILER_MSVC +# pragma message("Stack trace not available for this compiler") +# else +# warning "Stack trace not available for this compiler"; +# endif // ELPP_COMPILER_MSVC +# define ELPP_STACKTRACE 0 +# endif // ELPP_COMPILER_GCC +#else +# define ELPP_STACKTRACE 0 +#endif // (defined(ELPP_FEATURE_ALL)) || (defined(ELPP_FEATURE_CRASH_LOG)) +// Miscellaneous macros +#define ELPP_UNUSED(x) (void)x +#if ELPP_OS_UNIX +// Log file permissions for unix-based systems +# define ELPP_LOG_PERMS S_IRUSR | S_IWUSR | S_IXUSR | S_IWGRP | S_IRGRP | S_IXGRP | S_IWOTH | S_IXOTH +#endif // ELPP_OS_UNIX +#if defined(ELPP_AS_DLL) && ELPP_COMPILER_MSVC +# if defined(ELPP_EXPORT_SYMBOLS) +# define ELPP_EXPORT __declspec(dllexport) +# else +# define ELPP_EXPORT __declspec(dllimport) +# endif // defined(ELPP_EXPORT_SYMBOLS) +#else +# define ELPP_EXPORT +#endif // defined(ELPP_AS_DLL) && ELPP_COMPILER_MSVC +// Some special functions that are VC++ specific +#undef STRTOK +#undef STRERROR +#undef STRCAT +#undef STRCPY +#if ELPP_CRT_DBG_WARNINGS +# define STRTOK(a, b, c) strtok_s(a, b, c) +# define STRERROR(a, b, c) strerror_s(a, b, c) +# define STRCAT(a, b, len) strcat_s(a, len, b) +# define STRCPY(a, b, len) strcpy_s(a, len, b) +#else +# define STRTOK(a, b, c) strtok(a, b) +# define STRERROR(a, b, c) strerror(c) +# define STRCAT(a, b, len) strcat(a, b) +# define STRCPY(a, b, len) strcpy(a, b) +#endif +// Compiler specific support evaluations +#if (ELPP_MINGW && !defined(ELPP_FORCE_USE_STD_THREAD)) +# define ELPP_USE_STD_THREADING 0 +#else +# if ((ELPP_COMPILER_CLANG && defined(ELPP_CLANG_SUPPORTS_THREAD)) || \ + (!ELPP_COMPILER_CLANG && defined(ELPP_CXX11)) || \ + defined(ELPP_FORCE_USE_STD_THREAD)) +# define ELPP_USE_STD_THREADING 1 +# else +# define ELPP_USE_STD_THREADING 0 +# endif +#endif +#undef ELPP_FINAL +#if ELPP_COMPILER_INTEL || (ELPP_GCC_VERSION < 40702) +# define ELPP_FINAL +#else +# define ELPP_FINAL final +#endif // ELPP_COMPILER_INTEL || (ELPP_GCC_VERSION < 40702) +#if defined(ELPP_EXPERIMENTAL_ASYNC) +# define ELPP_ASYNC_LOGGING 1 +#else +# define ELPP_ASYNC_LOGGING 0 +#endif // defined(ELPP_EXPERIMENTAL_ASYNC) +#if defined(ELPP_THREAD_SAFE) || ELPP_ASYNC_LOGGING +# define ELPP_THREADING_ENABLED 1 +#else +# define ELPP_THREADING_ENABLED 0 +#endif // defined(ELPP_THREAD_SAFE) || ELPP_ASYNC_LOGGING +// Function macro ELPP_FUNC +#undef ELPP_FUNC +#if ELPP_COMPILER_MSVC // Visual C++ +# define ELPP_FUNC __FUNCSIG__ +#elif ELPP_COMPILER_GCC // GCC +# define ELPP_FUNC __PRETTY_FUNCTION__ +#elif ELPP_COMPILER_INTEL // Intel C++ +# define ELPP_FUNC __PRETTY_FUNCTION__ +#elif ELPP_COMPILER_CLANG // Clang++ +# define ELPP_FUNC __PRETTY_FUNCTION__ +#else +# if defined(__func__) +# define ELPP_FUNC __func__ +# else +# define ELPP_FUNC "" +# endif // defined(__func__) +#endif // defined(_MSC_VER) +#undef ELPP_VARIADIC_TEMPLATES_SUPPORTED +// Keep following line commented until features are fixed +#define ELPP_VARIADIC_TEMPLATES_SUPPORTED \ +(ELPP_COMPILER_GCC || ELPP_COMPILER_CLANG || ELPP_COMPILER_INTEL || (ELPP_COMPILER_MSVC && _MSC_VER >= 1800)) +// Logging Enable/Disable macros +#if defined(ELPP_DISABLE_LOGS) +#define ELPP_LOGGING_ENABLED 0 +#else +#define ELPP_LOGGING_ENABLED 1 +#endif +#if (!defined(ELPP_DISABLE_DEBUG_LOGS) && (ELPP_LOGGING_ENABLED)) +# define ELPP_DEBUG_LOG 1 +#else +# define ELPP_DEBUG_LOG 0 
+#endif // (!defined(ELPP_DISABLE_DEBUG_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!defined(ELPP_DISABLE_INFO_LOGS) && (ELPP_LOGGING_ENABLED)) +# define ELPP_INFO_LOG 1 +#else +# define ELPP_INFO_LOG 0 +#endif // (!defined(ELPP_DISABLE_INFO_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!defined(ELPP_DISABLE_WARNING_LOGS) && (ELPP_LOGGING_ENABLED)) +# define ELPP_WARNING_LOG 1 +#else +# define ELPP_WARNING_LOG 0 +#endif // (!defined(ELPP_DISABLE_WARNING_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!defined(ELPP_DISABLE_ERROR_LOGS) && (ELPP_LOGGING_ENABLED)) +# define ELPP_ERROR_LOG 1 +#else +# define ELPP_ERROR_LOG 0 +#endif // (!defined(ELPP_DISABLE_ERROR_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!defined(ELPP_DISABLE_FATAL_LOGS) && (ELPP_LOGGING_ENABLED)) +# define ELPP_FATAL_LOG 1 +#else +# define ELPP_FATAL_LOG 0 +#endif // (!defined(ELPP_DISABLE_FATAL_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!defined(ELPP_DISABLE_TRACE_LOGS) && (ELPP_LOGGING_ENABLED)) +# define ELPP_TRACE_LOG 1 +#else +# define ELPP_TRACE_LOG 0 +#endif // (!defined(ELPP_DISABLE_TRACE_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!defined(ELPP_DISABLE_VERBOSE_LOGS) && (ELPP_LOGGING_ENABLED)) +# define ELPP_VERBOSE_LOG 1 +#else +# define ELPP_VERBOSE_LOG 0 +#endif // (!defined(ELPP_DISABLE_VERBOSE_LOGS) && (ELPP_LOGGING_ENABLED)) +#if (!(ELPP_CXX0X || ELPP_CXX11)) +# error "C++0x (or higher) support not detected! (Is `-std=c++11' missing?)" +#endif // (!(ELPP_CXX0X || ELPP_CXX11)) +// Headers +#if defined(ELPP_SYSLOG) +# include +#endif // defined(ELPP_SYSLOG) +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(ELPP_UNICODE) +# include +# if ELPP_OS_WINDOWS +# include +# endif // ELPP_OS_WINDOWS +#endif // defined(ELPP_UNICODE) +#ifdef HAVE_EXECINFO +# include +# include +#endif // ENABLE_EXECINFO +#if ELPP_OS_ANDROID +# include +#endif // ELPP_OS_ANDROID +#if ELPP_OS_UNIX +# include +# include +#elif ELPP_OS_WINDOWS +# include +# include +# if defined(WIN32_LEAN_AND_MEAN) +# if defined(ELPP_WINSOCK2) +# include +# else +# include +# endif // defined(ELPP_WINSOCK2) +# endif // defined(WIN32_LEAN_AND_MEAN) +#endif // ELPP_OS_UNIX +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if ELPP_THREADING_ENABLED +# if ELPP_USE_STD_THREADING +# include +# include +# else +# if ELPP_OS_UNIX +# include +# endif // ELPP_OS_UNIX +# endif // ELPP_USE_STD_THREADING +#endif // ELPP_THREADING_ENABLED +#if ELPP_ASYNC_LOGGING +# if defined(ELPP_NO_SLEEP_FOR) +# include +# endif // defined(ELPP_NO_SLEEP_FOR) +# include +# include +# include +#endif // ELPP_ASYNC_LOGGING +#if defined(ELPP_STL_LOGGING) +// For logging STL based templates +# include +# include +# include +# include +# include +# include +# if defined(ELPP_LOG_STD_ARRAY) +# include +# endif // defined(ELPP_LOG_STD_ARRAY) +# if defined(ELPP_LOG_UNORDERED_SET) +# include +# endif // defined(ELPP_UNORDERED_SET) +#endif // defined(ELPP_STL_LOGGING) +#if defined(ELPP_QT_LOGGING) +// For logging Qt based classes & templates +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +#endif // defined(ELPP_QT_LOGGING) +#if defined(ELPP_BOOST_LOGGING) +// For logging boost based classes & templates +# include +# include +# include +# include +# include +# include +# include +# include +#endif // defined(ELPP_BOOST_LOGGING) +#if defined(ELPP_WXWIDGETS_LOGGING) +// For logging wxWidgets based classes & templates +# include +#endif // 
defined(ELPP_WXWIDGETS_LOGGING) +#if defined(ELPP_UTC_DATETIME) +# define elpptime_r gmtime_r +# define elpptime_s gmtime_s +# define elpptime gmtime +#else +# define elpptime_r localtime_r +# define elpptime_s localtime_s +# define elpptime localtime +#endif // defined(ELPP_UTC_DATETIME) +// Forward declarations +namespace el { +class Logger; +class LogMessage; +class PerformanceTrackingData; +class Loggers; +class Helpers; +template class Callback; +class LogDispatchCallback; +class PerformanceTrackingCallback; +class LoggerRegistrationCallback; +class LogDispatchData; +namespace base { +class Storage; +class RegisteredLoggers; +class PerformanceTracker; +class MessageBuilder; +class Writer; +class PErrorWriter; +class LogDispatcher; +class DefaultLogBuilder; +class DefaultLogDispatchCallback; +#if ELPP_ASYNC_LOGGING +class AsyncLogDispatchCallback; +class AsyncDispatchWorker; +#endif // ELPP_ASYNC_LOGGING +class DefaultPerformanceTrackingCallback; +} // namespace base +} // namespace el +/// @brief Easylogging++ entry namespace +namespace el { +/// @brief Namespace containing base/internal functionality used by Easylogging++ +namespace base { +/// @brief Data types used by Easylogging++ +namespace type { +#undef ELPP_LITERAL +#undef ELPP_STRLEN +#undef ELPP_COUT +#if defined(ELPP_UNICODE) +# define ELPP_LITERAL(txt) L##txt +# define ELPP_STRLEN wcslen +# if defined ELPP_CUSTOM_COUT +# define ELPP_COUT ELPP_CUSTOM_COUT +# else +# define ELPP_COUT std::wcout +# endif // defined ELPP_CUSTOM_COUT +typedef wchar_t char_t; +typedef std::wstring string_t; +typedef std::wstringstream stringstream_t; +typedef std::wfstream fstream_t; +typedef std::wostream ostream_t; +#else +# define ELPP_LITERAL(txt) txt +# define ELPP_STRLEN strlen +# if defined ELPP_CUSTOM_COUT +# define ELPP_COUT ELPP_CUSTOM_COUT +# else +# define ELPP_COUT std::cout +# endif // defined ELPP_CUSTOM_COUT +typedef char char_t; +typedef std::string string_t; +typedef std::stringstream stringstream_t; +typedef std::fstream fstream_t; +typedef std::ostream ostream_t; +#endif // defined(ELPP_UNICODE) +#if defined(ELPP_CUSTOM_COUT_LINE) +# define ELPP_COUT_LINE(logLine) ELPP_CUSTOM_COUT_LINE(logLine) +#else +# define ELPP_COUT_LINE(logLine) logLine << std::flush +#endif // defined(ELPP_CUSTOM_COUT_LINE) +typedef unsigned int EnumType; +typedef unsigned short VerboseLevel; +typedef unsigned long int LineNumber; +typedef std::shared_ptr StoragePointer; +typedef std::shared_ptr LogDispatchCallbackPtr; +typedef std::shared_ptr PerformanceTrackingCallbackPtr; +typedef std::shared_ptr LoggerRegistrationCallbackPtr; +typedef std::unique_ptr PerformanceTrackerPtr; +} // namespace type +/// @brief Internal helper class that prevent copy constructor for class +/// +/// @detail When using this class simply inherit it privately +class NoCopy { + protected: + NoCopy(void) {} + private: + NoCopy(const NoCopy&); + NoCopy& operator=(const NoCopy&); +}; +/// @brief Internal helper class that makes all default constructors private. +/// +/// @detail This prevents initializing class making it static unless an explicit constructor is declared. 
+/// When using this class simply inherit it privately
+class StaticClass {
+ private:
+ StaticClass(void);
+ StaticClass(const StaticClass&);
+ StaticClass& operator=(const StaticClass&);
+};
+} // namespace base
+/// @brief Represents enumeration for severity level used to determine level of logging
+///
+/// @detail With Easylogging++, developers may disable or enable any level regardless of
+/// what the severity is. Or they can choose to log using hierarchical logging flag
+enum class Level : base::type::EnumType {
+ /// @brief Generic level that represents all the levels. Useful when setting global configuration for all levels
+ Global = 1,
+ /// @brief Information that can be useful to back-trace certain events - mostly useful than debug logs.
+ Trace = 2,
+ /// @brief Informational events most useful for developers to debug application
+ Debug = 4,
+ /// @brief Severe error information that will presumably abort application
+ Fatal = 8,
+ /// @brief Information representing errors in application but application will keep running
+ Error = 16,
+ /// @brief Useful when application has potentially harmful situations
+ Warning = 32,
+ /// @brief Information that can be highly useful and vary with verbose logging level.
+ Verbose = 64,
+ /// @brief Mainly useful to represent current progress of application
+ Info = 128,
+ /// @brief Represents unknown level
+ Unknown = 1010
+};
+} // namespace el
+namespace std {
+template<> struct hash<el::Level> {
+ public:
+ std::size_t operator()(const el::Level& l) const {
+ return hash<el::base::type::EnumType> {}(static_cast<el::base::type::EnumType>(l));
+ }
+};
+}
+namespace el {
+/// @brief Static class that contains helper functions for el::Level
+class LevelHelper : base::StaticClass {
+ public:
+ /// @brief Represents minimum valid level. Useful when iterating through enum.
+ static const base::type::EnumType kMinValid = static_cast<base::type::EnumType>(Level::Trace);
+ /// @brief Represents maximum valid level. This is used internally and you should not need it.
+ static const base::type::EnumType kMaxValid = static_cast<base::type::EnumType>(Level::Info);
+ /// @brief Casts level to int, useful for iterating through enum.
+ static base::type::EnumType castToInt(Level level) {
+ return static_cast<base::type::EnumType>(level);
+ }
+ /// @brief Casts int(ushort) to level, useful for iterating through enum.
+ static Level castFromInt(base::type::EnumType l) {
+ return static_cast<Level>(l);
+ }
+ /// @brief Converts level to associated const char*
+ /// @return Upper case string based level.
+ static const char* convertToString(Level level);
+ /// @brief Converts from levelStr to Level
+ /// @param levelStr Upper case string based level.
+ /// Lower case is also valid but providing upper case is recommended.
+ static Level convertFromString(const char* levelStr);
+ /// @brief Applies specified function to each level starting from startIndex
+ /// @param startIndex initial value to start the iteration from. This is passed as pointer and
+ /// is left-shifted so this can be used inside function (fn) to represent current level.
+ /// @param fn function to apply with each level. This bool represent whether or not to stop iterating through levels.
+ static void forEachLevel(base::type::EnumType* startIndex, const std::function<bool(void)>& fn);
+};
+/// @brief Represents enumeration of ConfigurationType used to configure or access certain aspect
+/// of logging
+enum class ConfigurationType : base::type::EnumType {
+ /// @brief Determines whether or not corresponding level and logger of logging is enabled
+ /// You may disable all logs by using el::Level::Global
+ Enabled = 1,
+ /// @brief Whether or not to write corresponding log to log file
+ ToFile = 2,
+ /// @brief Whether or not to write corresponding level and logger log to standard output.
+ /// By standard output meaning terminal, command prompt etc
+ ToStandardOutput = 4,
+ /// @brief Determines format of logging corresponding level and logger.
+ Format = 8,
+ /// @brief Determines log file (full path) to write logs to for corresponding level and logger
+ Filename = 16,
+ /// @brief Specifies precision of the subsecond part. It should be within range (1-6).
+ SubsecondPrecision = 32,
+ /// @brief Alias of SubsecondPrecision (for backward compatibility)
+ MillisecondsWidth = SubsecondPrecision,
+ /// @brief Determines whether or not performance tracking is enabled.
+ ///
+ /// @detail This does not depend on logger or level. Performance tracking always uses 'performance' logger
+ PerformanceTracking = 64,
+ /// @brief Specifies log file max size.
+ ///
+ /// @detail If file size of corresponding log file (for corresponding level) is >= specified size, log file will
+ /// be truncated and re-initiated.
+ MaxLogFileSize = 128,
+ /// @brief Specifies number of log entries to hold until we flush pending log data
+ LogFlushThreshold = 256,
+ /// @brief Represents unknown configuration
+ Unknown = 1010
+};
+/// @brief Static class that contains helper functions for el::ConfigurationType
+class ConfigurationTypeHelper : base::StaticClass {
+ public:
+ /// @brief Represents minimum valid configuration type. Useful when iterating through enum.
+ static const base::type::EnumType kMinValid = static_cast<base::type::EnumType>(ConfigurationType::Enabled);
+ /// @brief Represents maximum valid configuration type. This is used internally and you should not need it.
+ static const base::type::EnumType kMaxValid = static_cast<base::type::EnumType>(ConfigurationType::MaxLogFileSize);
+ /// @brief Casts configuration type to int, useful for iterating through enum.
+ static base::type::EnumType castToInt(ConfigurationType configurationType) {
+ return static_cast<base::type::EnumType>(configurationType);
+ }
+ /// @brief Casts int(ushort) to configuration type, useful for iterating through enum.
+ static ConfigurationType castFromInt(base::type::EnumType c) {
+ return static_cast<ConfigurationType>(c);
+ }
+ /// @brief Converts configuration type to associated const char*
+ /// @returns Upper case string based configuration type.
+ static const char* convertToString(ConfigurationType configurationType);
+ /// @brief Converts from configStr to ConfigurationType
+ /// @param configStr Upper case string based configuration type.
+ /// Lower case is also valid but providing upper case is recommended.
+ static ConfigurationType convertFromString(const char* configStr);
+ /// @brief Applies specified function to each configuration type starting from startIndex
+ /// @param startIndex initial value to start the iteration from. This is passed by pointer and is left-shifted
+ /// so this can be used inside function (fn) to represent current configuration type.
+ /// @param fn function to apply with each configuration type.
+ /// This bool represent whether or not to stop iterating through configurations.
+ static inline void forEachConfigType(base::type::EnumType* startIndex, const std::function<bool(void)>& fn);
+};
+/// @brief Flags used while writing logs. These flags are set by the user
+enum class LoggingFlag : base::type::EnumType {
+ /// @brief Makes sure we have new line for each container log entry
+ NewLineForContainer = 1,
+ /// @brief Makes sure if -vmodule is used and does not specify a module, then verbose
+ /// logging is allowed via that module.
+ AllowVerboseIfModuleNotSpecified = 2,
+ /// @brief When handling crashes by default, detailed crash reason will be logged as well
+ LogDetailedCrashReason = 4,
+ /// @brief Allows to disable application abortion when logged using FATAL level
+ DisableApplicationAbortOnFatalLog = 8,
+ /// @brief Flushes log with every log-entry (performance sensitive) - Disabled by default
+ ImmediateFlush = 16,
+ /// @brief Enables strict file rolling
+ StrictLogFileSizeCheck = 32,
+ /// @brief Make terminal output colorful for supported terminals
+ ColoredTerminalOutput = 64,
+ /// @brief Supports use of multiple logging in same macro, e.g, CLOG(INFO, "default", "network")
+ MultiLoggerSupport = 128,
+ /// @brief Disables comparing performance tracker's checkpoints
+ DisablePerformanceTrackingCheckpointComparison = 256,
+ /// @brief Disable VModules
+ DisableVModules = 512,
+ /// @brief Disable VModules extensions
+ DisableVModulesExtensions = 1024,
+ /// @brief Enables hierarchical logging
+ HierarchicalLogging = 2048,
+ /// @brief Creates logger automatically when not available
+ CreateLoggerAutomatically = 4096,
+ /// @brief Adds spaces b/w logs that are separated by left-shift operator
+ AutoSpacing = 8192,
+ /// @brief Preserves time format and does not convert it to sec, hour etc (performance tracking only)
+ FixedTimeFormat = 16384,
+ // @brief Ignore SIGINT or crash
+ IgnoreSigInt = 32768,
+};
+namespace base {
+/// @brief Namespace containing constants used internally.
+namespace consts { +static const char kFormatSpecifierCharValue = 'v'; +static const char kFormatSpecifierChar = '%'; +static const unsigned int kMaxLogPerCounter = 100000; +static const unsigned int kMaxLogPerContainer = 100; +static const unsigned int kDefaultSubsecondPrecision = 3; + +#ifdef ELPP_DEFAULT_LOGGER +static const char* kDefaultLoggerId = ELPP_DEFAULT_LOGGER; +#else +static const char* kDefaultLoggerId = "default"; +#endif + +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) +#ifdef ELPP_DEFAULT_PERFORMANCE_LOGGER +static const char* kPerformanceLoggerId = ELPP_DEFAULT_PERFORMANCE_LOGGER; +#else +static const char* kPerformanceLoggerId = "performance"; +#endif // ELPP_DEFAULT_PERFORMANCE_LOGGER +#endif + +#if defined(ELPP_SYSLOG) +static const char* kSysLogLoggerId = "syslog"; +#endif // defined(ELPP_SYSLOG) + +#if ELPP_OS_WINDOWS +static const char* kFilePathSeparator = "\\"; +#else +static const char* kFilePathSeparator = "/"; +#endif // ELPP_OS_WINDOWS + +static const std::size_t kSourceFilenameMaxLength = 100; +static const std::size_t kSourceLineMaxLength = 10; +static const Level kPerformanceTrackerDefaultLevel = Level::Info; +const struct { + double value; + const base::type::char_t* unit; +} kTimeFormats[] = { + { 1000.0f, ELPP_LITERAL("us") }, + { 1000.0f, ELPP_LITERAL("ms") }, + { 60.0f, ELPP_LITERAL("seconds") }, + { 60.0f, ELPP_LITERAL("minutes") }, + { 24.0f, ELPP_LITERAL("hours") }, + { 7.0f, ELPP_LITERAL("days") } +}; +static const int kTimeFormatsCount = sizeof(kTimeFormats) / sizeof(kTimeFormats[0]); +const struct { + int numb; + const char* name; + const char* brief; + const char* detail; +} kCrashSignals[] = { + // NOTE: Do not re-order, if you do please check CrashHandler(bool) constructor and CrashHandler::setHandler(..) + { + SIGABRT, "SIGABRT", "Abnormal termination", + "Program was abnormally terminated." + }, + { + SIGFPE, "SIGFPE", "Erroneous arithmetic operation", + "Arithmetic operation issue such as division by zero or operation resulting in overflow." + }, + { + SIGILL, "SIGILL", "Illegal instruction", + "Generally due to a corruption in the code or to an attempt to execute data." + }, + { + SIGSEGV, "SIGSEGV", "Invalid access to memory", + "Program is trying to read an invalid (unallocated, deleted or corrupted) or inaccessible memory." + }, + { + SIGINT, "SIGINT", "Interactive attention signal", + "Interruption generated (generally) by user or operating system." + }, +}; +static const int kCrashSignalsCount = sizeof(kCrashSignals) / sizeof(kCrashSignals[0]); +} // namespace consts +} // namespace base +typedef std::function PreRollOutCallback; +namespace base { +static inline void defaultPreRollOutCallback(const char*, std::size_t) {} +/// @brief Enum to represent timestamp unit +enum class TimestampUnit : base::type::EnumType { + Microsecond = 0, Millisecond = 1, Second = 2, Minute = 3, Hour = 4, Day = 5 +}; +/// @brief Format flags used to determine specifiers that are active for performance improvements. 
+enum class FormatFlags : base::type::EnumType { + DateTime = 1 << 1, + LoggerId = 1 << 2, + File = 1 << 3, + Line = 1 << 4, + Location = 1 << 5, + Function = 1 << 6, + User = 1 << 7, + Host = 1 << 8, + LogMessage = 1 << 9, + VerboseLevel = 1 << 10, + AppName = 1 << 11, + ThreadId = 1 << 12, + Level = 1 << 13, + FileBase = 1 << 14, + LevelShort = 1 << 15 +}; +/// @brief A subsecond precision class containing actual width and offset of the subsecond part +class SubsecondPrecision { + public: + SubsecondPrecision(void) { + init(base::consts::kDefaultSubsecondPrecision); + } + explicit SubsecondPrecision(int width) { + init(width); + } + bool operator==(const SubsecondPrecision& ssPrec) { + return m_width == ssPrec.m_width && m_offset == ssPrec.m_offset; + } + int m_width; + unsigned int m_offset; + private: + void init(int width); +}; +/// @brief Type alias of SubsecondPrecision +typedef SubsecondPrecision MillisecondsWidth; +/// @brief Namespace containing utility functions/static classes used internally +namespace utils { +/// @brief Deletes memory safely and points to null +template +static +typename std::enable_if::value, void>::type +safeDelete(T*& pointer) { + if (pointer == nullptr) + return; + delete pointer; + pointer = nullptr; +} +/// @brief Bitwise operations for C++11 strong enum class. This casts e into Flag_T and returns value after bitwise operation +/// Use these function as
flag = bitwise::Or<MyEnum>(MyEnum::val1, flag);
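// [Editorial sketch, not part of the vendored Easylogging++ header] The enum-flag helpers
// documented above (bitwise::Or/And/Not and the addFlag/removeFlag/hasFlag wrappers defined
// in the namespace that follows) operate on a raw EnumType flag word, e.g. with
// el::LoggingFlag. A minimal usage sketch, assuming this header is available as
// "easylogging++.h" (path is an assumption):
//
//   #include "easylogging++.h"
//
//   inline bool coloredOutputRequested() {
//     el::base::type::EnumType flags = 0x0;
//     el::base::utils::addFlag(el::LoggingFlag::ColoredTerminalOutput, &flags);
//     el::base::utils::addFlag(el::LoggingFlag::ImmediateFlush, &flags);
//     el::base::utils::removeFlag(el::LoggingFlag::ImmediateFlush, &flags);
//     return el::base::utils::hasFlag(el::LoggingFlag::ColoredTerminalOutput, flags);  // true
//   }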
+namespace bitwise { +template +static inline base::type::EnumType And(Enum e, base::type::EnumType flag) { + return static_cast(flag) & static_cast(e); +} +template +static inline base::type::EnumType Not(Enum e, base::type::EnumType flag) { + return static_cast(flag) & ~(static_cast(e)); +} +template +static inline base::type::EnumType Or(Enum e, base::type::EnumType flag) { + return static_cast(flag) | static_cast(e); +} +} // namespace bitwise +template +static inline void addFlag(Enum e, base::type::EnumType* flag) { + *flag = base::utils::bitwise::Or(e, *flag); +} +template +static inline void removeFlag(Enum e, base::type::EnumType* flag) { + *flag = base::utils::bitwise::Not(e, *flag); +} +template +static inline bool hasFlag(Enum e, base::type::EnumType flag) { + return base::utils::bitwise::And(e, flag) > 0x0; +} +} // namespace utils +namespace threading { +#if ELPP_THREADING_ENABLED +# if !ELPP_USE_STD_THREADING +namespace internal { +/// @brief A mutex wrapper for compiler that dont yet support std::recursive_mutex +class Mutex : base::NoCopy { + public: + Mutex(void) { +# if ELPP_OS_UNIX + pthread_mutexattr_t attr; + pthread_mutexattr_init(&attr); + pthread_mutexattr_settype(&attr, PTHREAD_MUTEX_RECURSIVE); + pthread_mutex_init(&m_underlyingMutex, &attr); + pthread_mutexattr_destroy(&attr); +# elif ELPP_OS_WINDOWS + InitializeCriticalSection(&m_underlyingMutex); +# endif // ELPP_OS_UNIX + } + + virtual ~Mutex(void) { +# if ELPP_OS_UNIX + pthread_mutex_destroy(&m_underlyingMutex); +# elif ELPP_OS_WINDOWS + DeleteCriticalSection(&m_underlyingMutex); +# endif // ELPP_OS_UNIX + } + + inline void lock(void) { +# if ELPP_OS_UNIX + pthread_mutex_lock(&m_underlyingMutex); +# elif ELPP_OS_WINDOWS + EnterCriticalSection(&m_underlyingMutex); +# endif // ELPP_OS_UNIX + } + + inline bool try_lock(void) { +# if ELPP_OS_UNIX + return (pthread_mutex_trylock(&m_underlyingMutex) == 0); +# elif ELPP_OS_WINDOWS + return TryEnterCriticalSection(&m_underlyingMutex); +# endif // ELPP_OS_UNIX + } + + inline void unlock(void) { +# if ELPP_OS_UNIX + pthread_mutex_unlock(&m_underlyingMutex); +# elif ELPP_OS_WINDOWS + LeaveCriticalSection(&m_underlyingMutex); +# endif // ELPP_OS_UNIX + } + + private: +# if ELPP_OS_UNIX + pthread_mutex_t m_underlyingMutex; +# elif ELPP_OS_WINDOWS + CRITICAL_SECTION m_underlyingMutex; +# endif // ELPP_OS_UNIX +}; +/// @brief Scoped lock for compiler that dont yet support std::lock_guard +template +class ScopedLock : base::NoCopy { + public: + explicit ScopedLock(M& mutex) { + m_mutex = &mutex; + m_mutex->lock(); + } + + virtual ~ScopedLock(void) { + m_mutex->unlock(); + } + private: + M* m_mutex; + ScopedLock(void); +}; +} // namespace internal +typedef base::threading::internal::Mutex Mutex; +typedef base::threading::internal::ScopedLock ScopedLock; +# else +typedef std::recursive_mutex Mutex; +typedef std::lock_guard ScopedLock; +# endif // !ELPP_USE_STD_THREADING +#else +namespace internal { +/// @brief Mutex wrapper used when multi-threading is disabled. +class NoMutex : base::NoCopy { + public: + NoMutex(void) {} + inline void lock(void) {} + inline bool try_lock(void) { + return true; + } + inline void unlock(void) {} +}; +/// @brief Lock guard wrapper used when multi-threading is disabled. 
+template +class NoScopedLock : base::NoCopy { + public: + explicit NoScopedLock(Mutex&) { + } + virtual ~NoScopedLock(void) { + } + private: + NoScopedLock(void); +}; +} // namespace internal +typedef base::threading::internal::NoMutex Mutex; +typedef base::threading::internal::NoScopedLock ScopedLock; +#endif // ELPP_THREADING_ENABLED +/// @brief Base of thread safe class, this class is inheritable-only +class ThreadSafe { + public: + virtual inline void acquireLock(void) ELPP_FINAL { m_mutex.lock(); } + virtual inline void releaseLock(void) ELPP_FINAL { m_mutex.unlock(); } + virtual inline base::threading::Mutex& lock(void) ELPP_FINAL { return m_mutex; } + protected: + ThreadSafe(void) {} + virtual ~ThreadSafe(void) {} + private: + base::threading::Mutex m_mutex; +}; + +#if ELPP_THREADING_ENABLED +# if !ELPP_USE_STD_THREADING +/// @brief Gets ID of currently running threading in windows systems. On unix, nothing is returned. +static std::string getCurrentThreadId(void) { + std::stringstream ss; +# if (ELPP_OS_WINDOWS) + ss << GetCurrentThreadId(); +# endif // (ELPP_OS_WINDOWS) + return ss.str(); +} +# else +/// @brief Gets ID of currently running threading using std::this_thread::get_id() +static std::string getCurrentThreadId(void) { + std::stringstream ss; + ss << std::this_thread::get_id(); + return ss.str(); +} +# endif // !ELPP_USE_STD_THREADING +#else +static inline std::string getCurrentThreadId(void) { + return std::string(); +} +#endif // ELPP_THREADING_ENABLED +} // namespace threading +namespace utils { +class File : base::StaticClass { + public: + /// @brief Creates new out file stream for specified filename. + /// @return Pointer to newly created fstream or nullptr + static base::type::fstream_t* newFileStream(const std::string& filename); + + /// @brief Gets size of file provided in stream + static std::size_t getSizeOfFile(base::type::fstream_t* fs); + + /// @brief Determines whether or not provided path exist in current file system + static bool pathExists(const char* path, bool considerFile = false); + + /// @brief Creates specified path on file system + /// @param path Path to create. + static bool createPath(const std::string& path); + /// @brief Extracts path of filename with leading slash + static std::string extractPathFromFilename(const std::string& fullPath, + const char* separator = base::consts::kFilePathSeparator); + /// @brief builds stripped filename and puts it in buff + static void buildStrippedFilename(const char* filename, char buff[], + std::size_t limit = base::consts::kSourceFilenameMaxLength); + /// @brief builds base filename and puts it in buff + static void buildBaseFilename(const std::string& fullPath, char buff[], + std::size_t limit = base::consts::kSourceFilenameMaxLength, + const char* separator = base::consts::kFilePathSeparator); +}; +/// @brief String utilities helper class used internally. You should not use it. +class Str : base::StaticClass { + public: + /// @brief Checks if character is digit. Dont use libc implementation of it to prevent locale issues. + static inline bool isDigit(char c) { + return c >= '0' && c <= '9'; + } + + /// @brief Matches wildcards, '*' and '?' only supported. 
+ static bool wildCardMatch(const char* str, const char* pattern); + + static std::string& ltrim(std::string& str); + static std::string& rtrim(std::string& str); + static std::string& trim(std::string& str); + + /// @brief Determines whether or not str starts with specified string + /// @param str String to check + /// @param start String to check against + /// @return Returns true if starts with specified string, false otherwise + static bool startsWith(const std::string& str, const std::string& start); + + /// @brief Determines whether or not str ends with specified string + /// @param str String to check + /// @param end String to check against + /// @return Returns true if ends with specified string, false otherwise + static bool endsWith(const std::string& str, const std::string& end); + + /// @brief Replaces all instances of replaceWhat with 'replaceWith'. Original variable is changed for performance. + /// @param [in,out] str String to replace from + /// @param replaceWhat Character to replace + /// @param replaceWith Character to replace with + /// @return Modified version of str + static std::string& replaceAll(std::string& str, char replaceWhat, char replaceWith); + + /// @brief Replaces all instances of 'replaceWhat' with 'replaceWith'. (String version) Replaces in place + /// @param str String to replace from + /// @param replaceWhat Character to replace + /// @param replaceWith Character to replace with + /// @return Modified (original) str + static std::string& replaceAll(std::string& str, const std::string& replaceWhat, + const std::string& replaceWith); + + static void replaceFirstWithEscape(base::type::string_t& str, const base::type::string_t& replaceWhat, + const base::type::string_t& replaceWith); +#if defined(ELPP_UNICODE) + static void replaceFirstWithEscape(base::type::string_t& str, const base::type::string_t& replaceWhat, + const std::string& replaceWith); +#endif // defined(ELPP_UNICODE) + /// @brief Converts string to uppercase + /// @param str String to convert + /// @return Uppercase string + static std::string& toUpper(std::string& str); + + /// @brief Compares cstring equality - uses strcmp + static bool cStringEq(const char* s1, const char* s2); + + /// @brief Compares cstring equality (case-insensitive) - uses toupper(char) + /// Dont use strcasecmp because of CRT (VC++) + static bool cStringCaseEq(const char* s1, const char* s2); + + /// @brief Returns true if c exist in str + static bool contains(const char* str, char c); + + static char* convertAndAddToBuff(std::size_t n, int len, char* buf, const char* bufLim, bool zeroPadded = true); + static char* addToBuff(const char* str, char* buf, const char* bufLim); + static char* clearBuff(char buff[], std::size_t lim); + + /// @brief Converts wchar* to char* + /// NOTE: Need to free return value after use! + static char* wcharPtrToCharPtr(const wchar_t* line); +}; +/// @brief Operating System helper static class used internally. You should not use it. +class OS : base::StaticClass { + public: +#if ELPP_OS_WINDOWS + /// @brief Gets environment variables for Windows based OS. 
+ /// We are not using getenv(const char*) because of CRT deprecation + /// @param varname Variable name to get environment variable value for + /// @return If variable exist the value of it otherwise nullptr + static const char* getWindowsEnvironmentVariable(const char* varname); +#endif // ELPP_OS_WINDOWS +#if ELPP_OS_ANDROID + /// @brief Reads android property value + static std::string getProperty(const char* prop); + + /// @brief Reads android device name + static std::string getDeviceName(void); +#endif // ELPP_OS_ANDROID + + /// @brief Runs command on terminal and returns the output. + /// + /// @detail This is applicable only on unix based systems, for all other OS, an empty string is returned. + /// @param command Bash command + /// @return Result of bash output or empty string if no result found. + static const std::string getBashOutput(const char* command); + + /// @brief Gets environment variable. This is cross-platform and CRT safe (for VC++) + /// @param variableName Environment variable name + /// @param defaultVal If no environment variable or value found the value to return by default + /// @param alternativeBashCommand If environment variable not found what would be alternative bash command + /// in order to look for value user is looking for. E.g, for 'user' alternative command will 'whoami' + static std::string getEnvironmentVariable(const char* variableName, const char* defaultVal, + const char* alternativeBashCommand = nullptr); + /// @brief Gets current username. + static std::string currentUser(void); + + /// @brief Gets current host name or computer name. + /// + /// @detail For android systems this is device name with its manufacturer and model separated by hyphen + static std::string currentHost(void); + /// @brief Whether or not terminal supports colors + static bool termSupportsColor(void); +}; +/// @brief Contains utilities for cross-platform date/time. This class make use of el::base::utils::Str +class DateTime : base::StaticClass { + public: + /// @brief Cross platform gettimeofday for Windows and unix platform. This can be used to determine current microsecond. + /// + /// @detail For unix system it uses gettimeofday(timeval*, timezone*) and for Windows, a separate implementation is provided + /// @param [in,out] tv Pointer that gets updated + static void gettimeofday(struct timeval* tv); + + /// @brief Gets current date and time with a subsecond part. + /// @param format User provided date/time format + /// @param ssPrec A pointer to base::SubsecondPrecision from configuration (non-null) + /// @returns string based date time in specified format. 
+ static std::string getDateTime(const char* format, const base::SubsecondPrecision* ssPrec); + + /// @brief Converts timeval (struct from ctime) to string using specified format and subsecond precision + static std::string timevalToString(struct timeval tval, const char* format, + const el::base::SubsecondPrecision* ssPrec); + + /// @brief Formats time to get unit accordingly, units like second if > 1000 or minutes if > 60000 etc + static base::type::string_t formatTime(unsigned long long time, base::TimestampUnit timestampUnit); + + /// @brief Gets time difference in milli/micro second depending on timestampUnit + static unsigned long long getTimeDifference(const struct timeval& endTime, const struct timeval& startTime, + base::TimestampUnit timestampUnit); + + + static struct ::tm* buildTimeInfo(struct timeval* currTime, struct ::tm* timeInfo); + private: + static char* parseFormat(char* buf, std::size_t bufSz, const char* format, const struct tm* tInfo, + std::size_t msec, const base::SubsecondPrecision* ssPrec); +}; +/// @brief Command line arguments for application if specified using el::Helpers::setArgs(..) or START_EASYLOGGINGPP(..) +class CommandLineArgs { + public: + CommandLineArgs(void) { + setArgs(0, static_cast(nullptr)); + } + CommandLineArgs(int argc, const char** argv) { + setArgs(argc, argv); + } + CommandLineArgs(int argc, char** argv) { + setArgs(argc, argv); + } + virtual ~CommandLineArgs(void) {} + /// @brief Sets arguments and parses them + inline void setArgs(int argc, const char** argv) { + setArgs(argc, const_cast(argv)); + } + /// @brief Sets arguments and parses them + void setArgs(int argc, char** argv); + /// @brief Returns true if arguments contain paramKey with a value (separated by '=') + bool hasParamWithValue(const char* paramKey) const; + /// @brief Returns value of arguments + /// @see hasParamWithValue(const char*) + const char* getParamValue(const char* paramKey) const; + /// @brief Return true if arguments has a param (not having a value) i,e without '=' + bool hasParam(const char* paramKey) const; + /// @brief Returns true if no params available. This exclude argv[0] + bool empty(void) const; + /// @brief Returns total number of arguments. This exclude argv[0] + std::size_t size(void) const; + friend base::type::ostream_t& operator<<(base::type::ostream_t& os, const CommandLineArgs& c); + + private: + int m_argc; + char** m_argv; + std::unordered_map m_paramsWithValue; + std::vector m_params; +}; +/// @brief Abstract registry (aka repository) that provides basic interface for pointer repository specified by T_Ptr type. +/// +/// @detail Most of the functions are virtual final methods but anything implementing this abstract class should implement +/// unregisterAll() and deepCopy(const AbstractRegistry&) and write registerNew() method according to container +/// and few more methods; get() to find element, unregister() to unregister single entry. +/// Please note that this is thread-unsafe and should also implement thread-safety mechanisms in implementation. 
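// [Editorial sketch, not part of the vendored Easylogging++ header] CommandLineArgs (declared
// above) parses "key=value" style arguments; upstream Easylogging++ normally feeds it through
// START_EASYLOGGINGPP(argc, argv) or el::Helpers::setArgs(), as its doc comment notes. The
// argument names below are hypothetical, and the header path is an assumption:
//
//   #include "easylogging++.h"
//
//   int main(int argc, char** argv) {
//     el::base::utils::CommandLineArgs args(argc, argv);
//     if (args.hasParamWithValue("--threads")) {
//       const char* n = args.getParamValue("--threads");  // "--threads=4" yields "4"
//       (void)n;
//     }
//     bool verbose = args.hasParam("--verbose");  // plain flag, no '=' part
//     (void)verbose;
//     return 0;
//   }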
+template +class AbstractRegistry : public base::threading::ThreadSafe { + public: + typedef typename Container::iterator iterator; + typedef typename Container::const_iterator const_iterator; + + /// @brief Default constructor + AbstractRegistry(void) {} + + /// @brief Move constructor that is useful for base classes + AbstractRegistry(AbstractRegistry&& sr) { + if (this == &sr) { + return; + } + unregisterAll(); + m_list = std::move(sr.m_list); + } + + bool operator==(const AbstractRegistry& other) { + if (size() != other.size()) { + return false; + } + for (std::size_t i = 0; i < m_list.size(); ++i) { + if (m_list.at(i) != other.m_list.at(i)) { + return false; + } + } + return true; + } + + bool operator!=(const AbstractRegistry& other) { + if (size() != other.size()) { + return true; + } + for (std::size_t i = 0; i < m_list.size(); ++i) { + if (m_list.at(i) != other.m_list.at(i)) { + return true; + } + } + return false; + } + + /// @brief Assignment move operator + AbstractRegistry& operator=(AbstractRegistry&& sr) { + if (this == &sr) { + return *this; + } + unregisterAll(); + m_list = std::move(sr.m_list); + return *this; + } + + virtual ~AbstractRegistry(void) { + } + + /// @return Iterator pointer from start of repository + virtual inline iterator begin(void) ELPP_FINAL { + return m_list.begin(); + } + + /// @return Iterator pointer from end of repository + virtual inline iterator end(void) ELPP_FINAL { + return m_list.end(); + } + + + /// @return Constant iterator pointer from start of repository + virtual inline const_iterator cbegin(void) const ELPP_FINAL { + return m_list.cbegin(); + } + + /// @return End of repository + virtual inline const_iterator cend(void) const ELPP_FINAL { + return m_list.cend(); + } + + /// @return Whether or not repository is empty + virtual inline bool empty(void) const ELPP_FINAL { + return m_list.empty(); + } + + /// @return Size of repository + virtual inline std::size_t size(void) const ELPP_FINAL { + return m_list.size(); + } + + /// @brief Returns underlying container by reference + virtual inline Container& list(void) ELPP_FINAL { + return m_list; + } + + /// @brief Returns underlying container by constant reference. + virtual inline const Container& list(void) const ELPP_FINAL { + return m_list; + } + + /// @brief Unregisters all the pointers from current repository. + virtual void unregisterAll(void) = 0; + + protected: + virtual void deepCopy(const AbstractRegistry&) = 0; + void reinitDeepCopy(const AbstractRegistry& sr) { + unregisterAll(); + deepCopy(sr); + } + + private: + Container m_list; +}; + +/// @brief A pointer registry mechanism to manage memory and provide search functionalities. (non-predicate version) +/// +/// @detail NOTE: This is thread-unsafe implementation (although it contains lock function, it does not use these functions) +/// of AbstractRegistry. Any implementation of this class should be +/// explicitly (by using lock functions) +template +class Registry : public AbstractRegistry> { + public: + typedef typename Registry::iterator iterator; + typedef typename Registry::const_iterator const_iterator; + + Registry(void) {} + + /// @brief Copy constructor that is useful for base classes. Try to avoid this constructor, use move constructor. 
+ Registry(const Registry& sr) : AbstractRegistry>() { + if (this == &sr) { + return; + } + this->reinitDeepCopy(sr); + } + + /// @brief Assignment operator that unregisters all the existing registries and deeply copies each of repo element + /// @see unregisterAll() + /// @see deepCopy(const AbstractRegistry&) + Registry& operator=(const Registry& sr) { + if (this == &sr) { + return *this; + } + this->reinitDeepCopy(sr); + return *this; + } + + virtual ~Registry(void) { + unregisterAll(); + } + + protected: + virtual void unregisterAll(void) ELPP_FINAL { + if (!this->empty()) { + for (auto&& curr : this->list()) { + base::utils::safeDelete(curr.second); + } + this->list().clear(); + } + } + +/// @brief Registers new registry to repository. + virtual void registerNew(const T_Key& uniqKey, T_Ptr* ptr) ELPP_FINAL { + unregister(uniqKey); + this->list().insert(std::make_pair(uniqKey, ptr)); + } + +/// @brief Unregisters single entry mapped to specified unique key + void unregister(const T_Key& uniqKey) { + T_Ptr* existing = get(uniqKey); + if (existing != nullptr) { + this->list().erase(uniqKey); + base::utils::safeDelete(existing); + } + } + +/// @brief Gets pointer from repository. If none found, nullptr is returned. + T_Ptr* get(const T_Key& uniqKey) { + iterator it = this->list().find(uniqKey); + return it == this->list().end() + ? nullptr + : it->second; + } + + private: + virtual void deepCopy(const AbstractRegistry>& sr) ELPP_FINAL { + for (const_iterator it = sr.cbegin(); it != sr.cend(); ++it) { + registerNew(it->first, new T_Ptr(*it->second)); + } + } +}; + +/// @brief A pointer registry mechanism to manage memory and provide search functionalities. (predicate version) +/// +/// @detail NOTE: This is thread-unsafe implementation of AbstractRegistry. Any implementation of this class +/// should be made thread-safe explicitly +template +class RegistryWithPred : public AbstractRegistry> { + public: + typedef typename RegistryWithPred::iterator iterator; + typedef typename RegistryWithPred::const_iterator const_iterator; + + RegistryWithPred(void) { + } + + virtual ~RegistryWithPred(void) { + unregisterAll(); + } + + /// @brief Copy constructor that is useful for base classes. Try to avoid this constructor, use move constructor. 
+ RegistryWithPred(const RegistryWithPred& sr) : AbstractRegistry>() { + if (this == &sr) { + return; + } + this->reinitDeepCopy(sr); + } + + /// @brief Assignment operator that unregisters all the existing registries and deeply copies each of repo element + /// @see unregisterAll() + /// @see deepCopy(const AbstractRegistry&) + RegistryWithPred& operator=(const RegistryWithPred& sr) { + if (this == &sr) { + return *this; + } + this->reinitDeepCopy(sr); + return *this; + } + + friend base::type::ostream_t& operator<<(base::type::ostream_t& os, const RegistryWithPred& sr) { + for (const_iterator it = sr.list().begin(); it != sr.list().end(); ++it) { + os << ELPP_LITERAL(" ") << **it << ELPP_LITERAL("\n"); + } + return os; + } + + protected: + virtual void unregisterAll(void) ELPP_FINAL { + if (!this->empty()) { + for (auto&& curr : this->list()) { + base::utils::safeDelete(curr); + } + this->list().clear(); + } + } + + virtual void unregister(T_Ptr*& ptr) ELPP_FINAL { + if (ptr) { + iterator iter = this->begin(); + for (; iter != this->end(); ++iter) { + if (ptr == *iter) { + break; + } + } + if (iter != this->end() && *iter != nullptr) { + this->list().erase(iter); + base::utils::safeDelete(*iter); + } + } + } + + virtual inline void registerNew(T_Ptr* ptr) ELPP_FINAL { + this->list().push_back(ptr); + } + +/// @brief Gets pointer from repository with specified arguments. Arguments are passed to predicate +/// in order to validate pointer. + template + T_Ptr* get(const T& arg1, const T2 arg2) { + iterator iter = std::find_if(this->list().begin(), this->list().end(), Pred(arg1, arg2)); + if (iter != this->list().end() && *iter != nullptr) { + return *iter; + } + return nullptr; + } + + private: + virtual void deepCopy(const AbstractRegistry>& sr) { + for (const_iterator it = sr.list().begin(); it != sr.list().end(); ++it) { + registerNew(new T_Ptr(**it)); + } + } +}; +class Utils { + public: + template + static bool installCallback(const std::string& id, std::unordered_map* mapT) { + if (mapT->find(id) == mapT->end()) { + mapT->insert(std::make_pair(id, TPtr(new T()))); + return true; + } + return false; + } + + template + static void uninstallCallback(const std::string& id, std::unordered_map* mapT) { + if (mapT->find(id) != mapT->end()) { + mapT->erase(id); + } + } + + template + static T* callback(const std::string& id, std::unordered_map* mapT) { + typename std::unordered_map::iterator iter = mapT->find(id); + if (iter != mapT->end()) { + return static_cast(iter->second.get()); + } + return nullptr; + } +}; +} // namespace utils +} // namespace base +/// @brief Base of Easylogging++ friendly class +/// +/// @detail After inheriting this class publicly, implement pure-virtual function `void log(std::ostream&) const` +class Loggable { + public: + virtual ~Loggable(void) {} + virtual void log(el::base::type::ostream_t&) const = 0; + private: + friend inline el::base::type::ostream_t& operator<<(el::base::type::ostream_t& os, const Loggable& loggable) { + loggable.log(os); + return os; + } +}; +namespace base { +/// @brief Represents log format containing flags and date format. 
This is used internally to start initial log +class LogFormat : public Loggable { + public: + LogFormat(void); + LogFormat(Level level, const base::type::string_t& format); + LogFormat(const LogFormat& logFormat); + LogFormat(LogFormat&& logFormat); + LogFormat& operator=(const LogFormat& logFormat); + virtual ~LogFormat(void) {} + bool operator==(const LogFormat& other); + + /// @brief Updates format to be used while logging. + /// @param userFormat User provided format + void parseFromFormat(const base::type::string_t& userFormat); + + inline Level level(void) const { + return m_level; + } + + inline const base::type::string_t& userFormat(void) const { + return m_userFormat; + } + + inline const base::type::string_t& format(void) const { + return m_format; + } + + inline const std::string& dateTimeFormat(void) const { + return m_dateTimeFormat; + } + + inline base::type::EnumType flags(void) const { + return m_flags; + } + + inline bool hasFlag(base::FormatFlags flag) const { + return base::utils::hasFlag(flag, m_flags); + } + + virtual void log(el::base::type::ostream_t& os) const { + os << m_format; + } + + protected: + /// @brief Updates date time format if available in currFormat. + /// @param index Index where %datetime, %date or %time was found + /// @param [in,out] currFormat current format that is being used to format + virtual void updateDateFormat(std::size_t index, base::type::string_t& currFormat) ELPP_FINAL; + + /// @brief Updates %level from format. This is so that we dont have to do it at log-writing-time. It uses m_format and m_level + virtual void updateFormatSpec(void) ELPP_FINAL; + + inline void addFlag(base::FormatFlags flag) { + base::utils::addFlag(flag, &m_flags); + } + + private: + Level m_level; + base::type::string_t m_userFormat; + base::type::string_t m_format; + std::string m_dateTimeFormat; + base::type::EnumType m_flags; + std::string m_currentUser; + std::string m_currentHost; + friend class el::Logger; // To resolve loggerId format specifier easily +}; +} // namespace base +/// @brief Resolving function for format specifier +typedef std::function FormatSpecifierValueResolver; +/// @brief User-provided custom format specifier +/// @see el::Helpers::installCustomFormatSpecifier +/// @see FormatSpecifierValueResolver +class CustomFormatSpecifier { + public: + CustomFormatSpecifier(const char* formatSpecifier, const FormatSpecifierValueResolver& resolver) : + m_formatSpecifier(formatSpecifier), m_resolver(resolver) {} + inline const char* formatSpecifier(void) const { + return m_formatSpecifier; + } + inline const FormatSpecifierValueResolver& resolver(void) const { + return m_resolver; + } + inline bool operator==(const char* formatSpecifier) { + return strcmp(m_formatSpecifier, formatSpecifier) == 0; + } + + private: + const char* m_formatSpecifier; + FormatSpecifierValueResolver m_resolver; +}; +/// @brief Represents single configuration that has representing level, configuration type and a string based value. +/// +/// @detail String based value means any value either its boolean, integer or string itself, it will be embedded inside quotes +/// and will be parsed later. 
+/// +/// Consider some examples below: +/// * el::Configuration confEnabledInfo(el::Level::Info, el::ConfigurationType::Enabled, "true"); +/// * el::Configuration confMaxLogFileSizeInfo(el::Level::Info, el::ConfigurationType::MaxLogFileSize, "2048"); +/// * el::Configuration confFilenameInfo(el::Level::Info, el::ConfigurationType::Filename, "/var/log/my.log"); +class Configuration : public Loggable { + public: + Configuration(const Configuration& c); + Configuration& operator=(const Configuration& c); + + virtual ~Configuration(void) { + } + + /// @brief Full constructor used to sets value of configuration + Configuration(Level level, ConfigurationType configurationType, const std::string& value); + + /// @brief Gets level of current configuration + inline Level level(void) const { + return m_level; + } + + /// @brief Gets configuration type of current configuration + inline ConfigurationType configurationType(void) const { + return m_configurationType; + } + + /// @brief Gets string based configuration value + inline const std::string& value(void) const { + return m_value; + } + + /// @brief Set string based configuration value + /// @param value Value to set. Values have to be std::string; For boolean values use "true", "false", for any integral values + /// use them in quotes. They will be parsed when configuring + inline void setValue(const std::string& value) { + m_value = value; + } + + virtual void log(el::base::type::ostream_t& os) const; + + /// @brief Used to find configuration from configuration (pointers) repository. Avoid using it. + class Predicate { + public: + Predicate(Level level, ConfigurationType configurationType); + + bool operator()(const Configuration* conf) const; + + private: + Level m_level; + ConfigurationType m_configurationType; + }; + + private: + Level m_level; + ConfigurationType m_configurationType; + std::string m_value; +}; + +/// @brief Thread-safe Configuration repository +/// +/// @detail This repository represents configurations for all the levels and configuration type mapped to a value. +class Configurations : public base::utils::RegistryWithPred { + public: + /// @brief Default constructor with empty repository + Configurations(void); + + /// @brief Constructor used to set configurations using configuration file. + /// @param configurationFile Full path to configuration file + /// @param useDefaultsForRemaining Lets you set the remaining configurations to default. + /// @param base If provided, this configuration will be based off existing repository that this argument is pointing to. + /// @see parseFromFile(const std::string&, Configurations* base) + /// @see setRemainingToDefault() + Configurations(const std::string& configurationFile, bool useDefaultsForRemaining = true, + Configurations* base = nullptr); + + virtual ~Configurations(void) { + } + + /// @brief Parses configuration from file. + /// @param configurationFile Full path to configuration file + /// @param base Configurations to base new configuration repository off. This value is used when you want to use + /// existing Configurations to base all the values and then set rest of configuration via configuration file. + /// @return True if successfully parsed, false otherwise. You may define 'ELPP_DEBUG_ASSERT_FAILURE' to make sure you + /// do not proceed without successful parse. + bool parseFromFile(const std::string& configurationFile, Configurations* base = nullptr); + + /// @brief Parse configurations from configuration string. 
+ /// + /// @detail This configuration string has same syntax as configuration file contents. Make sure all the necessary + /// new line characters are provided. + /// @param base Configurations to base new configuration repository off. This value is used when you want to use + /// existing Configurations to base all the values and then set rest of configuration via configuration text. + /// @return True if successfully parsed, false otherwise. You may define 'ELPP_DEBUG_ASSERT_FAILURE' to make sure you + /// do not proceed without successful parse. + bool parseFromText(const std::string& configurationsString, Configurations* base = nullptr); + + /// @brief Sets configuration based-off an existing configurations. + /// @param base Pointer to existing configurations. + void setFromBase(Configurations* base); + + /// @brief Determines whether or not specified configuration type exists in the repository. + /// + /// @detail Returns as soon as first level is found. + /// @param configurationType Type of configuration to check existence for. + bool hasConfiguration(ConfigurationType configurationType); + + /// @brief Determines whether or not specified configuration type exists for specified level + /// @param level Level to check + /// @param configurationType Type of configuration to check existence for. + bool hasConfiguration(Level level, ConfigurationType configurationType); + + /// @brief Sets value of configuration for specified level. + /// + /// @detail Any existing configuration for specified level will be replaced. Also note that configuration types + /// ConfigurationType::SubsecondPrecision and ConfigurationType::PerformanceTracking will be ignored if not set for + /// Level::Global because these configurations are not dependant on level. + /// @param level Level to set configuration for (el::Level). + /// @param configurationType Type of configuration (el::ConfigurationType) + /// @param value A string based value. Regardless of what the data type of configuration is, it will always be string + /// from users' point of view. This is then parsed later to be used internally. + /// @see Configuration::setValue(const std::string& value) + /// @see el::Level + /// @see el::ConfigurationType + void set(Level level, ConfigurationType configurationType, const std::string& value); + + /// @brief Sets single configuration based on other single configuration. + /// @see set(Level level, ConfigurationType configurationType, const std::string& value) + void set(Configuration* conf); + + inline Configuration* get(Level level, ConfigurationType configurationType) { + base::threading::ScopedLock scopedLock(lock()); + return RegistryWithPred::get(level, configurationType); + } + + /// @brief Sets configuration for all levels. + /// @param configurationType Type of configuration + /// @param value String based value + /// @see Configurations::set(Level level, ConfigurationType configurationType, const std::string& value) + inline void setGlobally(ConfigurationType configurationType, const std::string& value) { + setGlobally(configurationType, value, false); + } + + /// @brief Clears repository so that all the configurations are unset + inline void clear(void) { + base::threading::ScopedLock scopedLock(lock()); + unregisterAll(); + } + + /// @brief Gets configuration file used in parsing this configurations. + /// + /// @detail If this repository was set manually or by text this returns empty string. 
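// [Editorial sketch, not part of the vendored Easylogging++ header] Typical use of the
// Configurations repository described above: set string-based values per level (or globally)
// and hand the repository to a logger. The log path and format string are illustrative;
// el::Loggers::reconfigureLogger() belongs to the Loggers helper forward-declared near the
// top of this header and is shown here as an assumption based on upstream Easylogging++.
// The header path is also an assumption:
//
//   #include "easylogging++.h"
//
//   void configureDefaultLogger() {
//     el::Configurations conf;
//     conf.set(el::Level::Info, el::ConfigurationType::ToFile, "true");
//     conf.set(el::Level::Info, el::ConfigurationType::Filename, "/tmp/hexl_omp.log");
//     conf.setGlobally(el::ConfigurationType::Format, "%datetime %level %msg");
//     el::Loggers::reconfigureLogger("default", conf);
//   }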
+ inline const std::string& configurationFile(void) const { + return m_configurationFile; + } + + /// @brief Sets configurations to "factory based" configurations. + void setToDefault(void); + + /// @brief Lets you set the remaining configurations to default. + /// + /// @detail By remaining, it means that the level/type a configuration does not exist for. + /// This function is useful when you want to minimize chances of failures, e.g, if you have a configuration file that sets + /// configuration for all the configurations except for Enabled or not, we use this so that ENABLED is set to default i.e, + /// true. If you dont do this explicitly (either by calling this function or by using second param in Constructor + /// and try to access a value, an error is thrown + void setRemainingToDefault(void); + + /// @brief Parser used internally to parse configurations from file or text. + /// + /// @detail This class makes use of base::utils::Str. + /// You should not need this unless you are working on some tool for Easylogging++ + class Parser : base::StaticClass { + public: + /// @brief Parses configuration from file. + /// @param configurationFile Full path to configuration file + /// @param sender Sender configurations pointer. Usually 'this' is used from calling class + /// @param base Configurations to base new configuration repository off. This value is used when you want to use + /// existing Configurations to base all the values and then set rest of configuration via configuration file. + /// @return True if successfully parsed, false otherwise. You may define '_STOP_ON_FIRSTELPP_ASSERTION' to make sure you + /// do not proceed without successful parse. + static bool parseFromFile(const std::string& configurationFile, Configurations* sender, + Configurations* base = nullptr); + + /// @brief Parse configurations from configuration string. + /// + /// @detail This configuration string has same syntax as configuration file contents. Make sure all the necessary + /// new line characters are provided. You may define '_STOP_ON_FIRSTELPP_ASSERTION' to make sure you + /// do not proceed without successful parse (This is recommended) + /// @param configurationsString the configuration in plain text format + /// @param sender Sender configurations pointer. Usually 'this' is used from calling class + /// @param base Configurations to base new configuration repository off. This value is used when you want to use + /// existing Configurations to base all the values and then set rest of configuration via configuration text. + /// @return True if successfully parsed, false otherwise. 
+ static bool parseFromText(const std::string& configurationsString, Configurations* sender, + Configurations* base = nullptr); + + private: + friend class el::Loggers; + static void ignoreComments(std::string* line); + static bool isLevel(const std::string& line); + static bool isComment(const std::string& line); + static inline bool isConfig(const std::string& line); + static bool parseLine(std::string* line, std::string* currConfigStr, std::string* currLevelStr, Level* currLevel, + Configurations* conf); + }; + + private: + std::string m_configurationFile; + bool m_isFromFile; + friend class el::Loggers; + + /// @brief Unsafely sets configuration if does not already exist + void unsafeSetIfNotExist(Level level, ConfigurationType configurationType, const std::string& value); + + /// @brief Thread unsafe set + void unsafeSet(Level level, ConfigurationType configurationType, const std::string& value); + + /// @brief Sets configurations for all levels including Level::Global if includeGlobalLevel is true + /// @see Configurations::setGlobally(ConfigurationType configurationType, const std::string& value) + void setGlobally(ConfigurationType configurationType, const std::string& value, bool includeGlobalLevel); + + /// @brief Sets configurations (Unsafely) for all levels including Level::Global if includeGlobalLevel is true + /// @see Configurations::setGlobally(ConfigurationType configurationType, const std::string& value) + void unsafeSetGlobally(ConfigurationType configurationType, const std::string& value, bool includeGlobalLevel); +}; + +namespace base { +typedef std::shared_ptr FileStreamPtr; +typedef std::unordered_map LogStreamsReferenceMap; +typedef std::shared_ptr LogStreamsReferenceMapPtr; +/// @brief Configurations with data types. +/// +/// @detail el::Configurations have string based values. This is whats used internally in order to read correct configurations. +/// This is to perform faster while writing logs using correct configurations. +/// +/// This is thread safe and final class containing non-virtual destructor (means nothing should inherit this class) +class TypedConfigurations : public base::threading::ThreadSafe { + public: + /// @brief Constructor to initialize (construct) the object off el::Configurations + /// @param configurations Configurations pointer/reference to base this typed configurations off. 
+ /// @param logStreamsReference Use ELPP->registeredLoggers()->logStreamsReference() + TypedConfigurations(Configurations* configurations, LogStreamsReferenceMapPtr logStreamsReference); + + TypedConfigurations(const TypedConfigurations& other); + + virtual ~TypedConfigurations(void) { + } + + const Configurations* configurations(void) const { + return m_configurations; + } + + bool enabled(Level level); + bool toFile(Level level); + const std::string& filename(Level level); + bool toStandardOutput(Level level); + const base::LogFormat& logFormat(Level level); + const base::SubsecondPrecision& subsecondPrecision(Level level = Level::Global); + const base::MillisecondsWidth& millisecondsWidth(Level level = Level::Global); + bool performanceTracking(Level level = Level::Global); + base::type::fstream_t* fileStream(Level level); + std::size_t maxLogFileSize(Level level); + std::size_t logFlushThreshold(Level level); + + private: + Configurations* m_configurations; + std::unordered_map m_enabledMap; + std::unordered_map m_toFileMap; + std::unordered_map m_filenameMap; + std::unordered_map m_toStandardOutputMap; + std::unordered_map m_logFormatMap; + std::unordered_map m_subsecondPrecisionMap; + std::unordered_map m_performanceTrackingMap; + std::unordered_map m_fileStreamMap; + std::unordered_map m_maxLogFileSizeMap; + std::unordered_map m_logFlushThresholdMap; + LogStreamsReferenceMapPtr m_logStreamsReference = nullptr; + + friend class el::Helpers; + friend class el::base::MessageBuilder; + friend class el::base::Writer; + friend class el::base::DefaultLogDispatchCallback; + friend class el::base::LogDispatcher; + + template + inline Conf_T getConfigByVal(Level level, const std::unordered_map* confMap, const char* confName) { + base::threading::ScopedLock scopedLock(lock()); + return unsafeGetConfigByVal(level, confMap, confName); // This is not unsafe anymore - mutex locked in scope + } + + template + inline Conf_T& getConfigByRef(Level level, std::unordered_map* confMap, const char* confName) { + base::threading::ScopedLock scopedLock(lock()); + return unsafeGetConfigByRef(level, confMap, confName); // This is not unsafe anymore - mutex locked in scope + } + + template + Conf_T unsafeGetConfigByVal(Level level, const std::unordered_map* confMap, const char* confName) { + ELPP_UNUSED(confName); + typename std::unordered_map::const_iterator it = confMap->find(level); + if (it == confMap->end()) { + try { + return confMap->at(Level::Global); + } catch (...) { + ELPP_INTERNAL_ERROR("Unable to get configuration [" << confName << "] for level [" + << LevelHelper::convertToString(level) << "]" + << std::endl << "Please ensure you have properly configured logger.", false); + return Conf_T(); + } + } + return it->second; + } + + template + Conf_T& unsafeGetConfigByRef(Level level, std::unordered_map* confMap, const char* confName) { + ELPP_UNUSED(confName); + typename std::unordered_map::iterator it = confMap->find(level); + if (it == confMap->end()) { + try { + return confMap->at(Level::Global); + } catch (...) { + ELPP_INTERNAL_ERROR("Unable to get configuration [" << confName << "] for level [" + << LevelHelper::convertToString(level) << "]" + << std::endl << "Please ensure you have properly configured logger.", false); + } + } + return it->second; + } + + template + void setValue(Level level, const Conf_T& value, std::unordered_map* confMap, + bool includeGlobalLevel = true) { + // If map is empty and we are allowed to add into generic level (Level::Global), do it! 
+ if (confMap->empty() && includeGlobalLevel) {
+ confMap->insert(std::make_pair(Level::Global, value));
+ return;
+ }
+ // If same value exist in generic level already, dont add it to explicit level
+ typename std::unordered_map<Level, Conf_T>::iterator it = confMap->find(Level::Global);
+ if (it != confMap->end() && it->second == value) {
+ return;
+ }
+ // Now make sure we dont double up values if we really need to add it to explicit level
+ it = confMap->find(level);
+ if (it == confMap->end()) {
+ // Value not found for level, add new
+ confMap->insert(std::make_pair(level, value));
+ } else {
+ // Value found, just update value
+ confMap->at(level) = value;
+ }
+ }
+
+ void build(Configurations* configurations);
+ unsigned long getULong(std::string confVal);
+ std::string resolveFilename(const std::string& filename);
+ void insertFile(Level level, const std::string& fullFilename);
+ bool unsafeValidateFileRolling(Level level, const PreRollOutCallback& preRollOutCallback);
+
+ inline bool validateFileRolling(Level level, const PreRollOutCallback& preRollOutCallback) {
+ base::threading::ScopedLock scopedLock(lock());
+ return unsafeValidateFileRolling(level, preRollOutCallback);
+ }
+};
+/// @brief Class that keeps record of current line hit for occasional logging
+class HitCounter {
+ public:
+ HitCounter(void) :
+ m_filename(""),
+ m_lineNumber(0),
+ m_hitCounts(0) {
+ }
+
+ HitCounter(const char* filename, base::type::LineNumber lineNumber) :
+ m_filename(filename),
+ m_lineNumber(lineNumber),
+ m_hitCounts(0) {
+ }
+
+ HitCounter(const HitCounter& hitCounter) :
+ m_filename(hitCounter.m_filename),
+ m_lineNumber(hitCounter.m_lineNumber),
+ m_hitCounts(hitCounter.m_hitCounts) {
+ }
+
+ HitCounter& operator=(const HitCounter& hitCounter) {
+ if (&hitCounter != this) {
+ m_filename = hitCounter.m_filename;
+ m_lineNumber = hitCounter.m_lineNumber;
+ m_hitCounts = hitCounter.m_hitCounts;
+ }
+ return *this;
+ }
+
+ virtual ~HitCounter(void) {
+ }
+
+ /// @brief Resets location of current hit counter
+ inline void resetLocation(const char* filename, base::type::LineNumber lineNumber) {
+ m_filename = filename;
+ m_lineNumber = lineNumber;
+ }
+
+ /// @brief Validates hit counts and resets it if necessary
+ inline void validateHitCounts(std::size_t n) {
+ if (m_hitCounts >= base::consts::kMaxLogPerCounter) {
+ m_hitCounts = (n >= 1 ?
base::consts::kMaxLogPerCounter % n : 0);
+ }
+ ++m_hitCounts;
+ }
+
+ inline const char* filename(void) const {
+ return m_filename;
+ }
+
+ inline base::type::LineNumber lineNumber(void) const {
+ return m_lineNumber;
+ }
+
+ inline std::size_t hitCounts(void) const {
+ return m_hitCounts;
+ }
+
+ inline void increment(void) {
+ ++m_hitCounts;
+ }
+
+ class Predicate {
+ public:
+ Predicate(const char* filename, base::type::LineNumber lineNumber)
+ : m_filename(filename),
+ m_lineNumber(lineNumber) {
+ }
+ inline bool operator()(const HitCounter* counter) {
+ return ((counter != nullptr) &&
+ (strcmp(counter->m_filename, m_filename) == 0) &&
+ (counter->m_lineNumber == m_lineNumber));
+ }
+
+ private:
+ const char* m_filename;
+ base::type::LineNumber m_lineNumber;
+ };
+
+ private:
+ const char* m_filename;
+ base::type::LineNumber m_lineNumber;
+ std::size_t m_hitCounts;
+};
+/// @brief Repository for hit counters used across the application
+class RegisteredHitCounters : public base::utils::RegistryWithPred<base::HitCounter, base::HitCounter::Predicate> {
+ public:
+ /// @brief Validates counter for every N, i.e, registers new if does not exist otherwise updates original one
+ /// @return True if validation resulted in triggering hit. Meaning logs should be written everytime true is returned
+ bool validateEveryN(const char* filename, base::type::LineNumber lineNumber, std::size_t n);
+
+ /// @brief Validates counter for hits >= N, i.e, registers new if does not exist otherwise updates original one
+ /// @return True if validation resulted in triggering hit. Meaning logs should be written everytime true is returned
+ bool validateAfterN(const char* filename, base::type::LineNumber lineNumber, std::size_t n);
+
+ /// @brief Validates counter for hits are <= n, i.e, registers new if does not exist otherwise updates original one
+ /// @return True if validation resulted in triggering hit.
Meaning logs should be written everytime true is returned + bool validateNTimes(const char* filename, base::type::LineNumber lineNumber, std::size_t n); + + /// @brief Gets hit counter registered at specified position + inline const base::HitCounter* getCounter(const char* filename, base::type::LineNumber lineNumber) { + base::threading::ScopedLock scopedLock(lock()); + return get(filename, lineNumber); + } +}; +/// @brief Action to be taken for dispatching +enum class DispatchAction : base::type::EnumType { + None = 1, NormalLog = 2, SysLog = 4 +}; +} // namespace base +template +class Callback : protected base::threading::ThreadSafe { + public: + Callback(void) : m_enabled(true) {} + inline bool enabled(void) const { + return m_enabled; + } + inline void setEnabled(bool enabled) { + base::threading::ScopedLock scopedLock(lock()); + m_enabled = enabled; + } + protected: + virtual void handle(const T* handlePtr) = 0; + private: + bool m_enabled; +}; +class LogDispatchData { + public: + LogDispatchData() : m_logMessage(nullptr), m_dispatchAction(base::DispatchAction::None) {} + inline const LogMessage* logMessage(void) const { + return m_logMessage; + } + inline base::DispatchAction dispatchAction(void) const { + return m_dispatchAction; + } + inline void setLogMessage(LogMessage* logMessage) { + m_logMessage = logMessage; + } + inline void setDispatchAction(base::DispatchAction dispatchAction) { + m_dispatchAction = dispatchAction; + } + private: + LogMessage* m_logMessage; + base::DispatchAction m_dispatchAction; + friend class base::LogDispatcher; + +}; +class LogDispatchCallback : public Callback { + protected: + virtual void handle(const LogDispatchData* data); + base::threading::Mutex& fileHandle(const LogDispatchData* data); + private: + friend class base::LogDispatcher; + std::unordered_map> m_fileLocks; + base::threading::Mutex m_fileLocksMapLock; +}; +class PerformanceTrackingCallback : public Callback { + private: + friend class base::PerformanceTracker; +}; +class LoggerRegistrationCallback : public Callback { + private: + friend class base::RegisteredLoggers; +}; +class LogBuilder : base::NoCopy { + public: + LogBuilder() : m_termSupportsColor(base::utils::OS::termSupportsColor()) {} + virtual ~LogBuilder(void) { + ELPP_INTERNAL_INFO(3, "Destroying log builder...") + } + virtual base::type::string_t build(const LogMessage* logMessage, bool appendNewLine) const = 0; + void convertToColoredOutput(base::type::string_t* logLine, Level level); + private: + bool m_termSupportsColor; + friend class el::base::DefaultLogDispatchCallback; +}; +typedef std::shared_ptr LogBuilderPtr; +/// @brief Represents a logger holding ID and configurations we need to write logs +/// +/// @detail This class does not write logs itself instead its used by writer to read configurations from. +class Logger : public base::threading::ThreadSafe, public Loggable { + public: + Logger(const std::string& id, base::LogStreamsReferenceMapPtr logStreamsReference); + Logger(const std::string& id, const Configurations& configurations, base::LogStreamsReferenceMapPtr logStreamsReference); + Logger(const Logger& logger); + Logger& operator=(const Logger& logger); + + virtual ~Logger(void) { + base::utils::safeDelete(m_typedConfigurations); + } + + virtual inline void log(el::base::type::ostream_t& os) const { + os << m_id.c_str(); + } + + /// @brief Configures the logger using specified configurations. 
+ void configure(const Configurations& configurations); + + /// @brief Reconfigures logger using existing configurations + void reconfigure(void); + + inline const std::string& id(void) const { + return m_id; + } + + inline const std::string& parentApplicationName(void) const { + return m_parentApplicationName; + } + + inline void setParentApplicationName(const std::string& parentApplicationName) { + m_parentApplicationName = parentApplicationName; + } + + inline Configurations* configurations(void) { + return &m_configurations; + } + + inline base::TypedConfigurations* typedConfigurations(void) { + return m_typedConfigurations; + } + + static bool isValidId(const std::string& id); + + /// @brief Flushes logger to sync all log files for all levels + void flush(void); + + void flush(Level level, base::type::fstream_t* fs); + + inline bool isFlushNeeded(Level level) { + return ++m_unflushedCount.find(level)->second >= m_typedConfigurations->logFlushThreshold(level); + } + + inline LogBuilder* logBuilder(void) const { + return m_logBuilder.get(); + } + + inline void setLogBuilder(const LogBuilderPtr& logBuilder) { + m_logBuilder = logBuilder; + } + + inline bool enabled(Level level) const { + return m_typedConfigurations->enabled(level); + } + +#if ELPP_VARIADIC_TEMPLATES_SUPPORTED +# define LOGGER_LEVEL_WRITERS_SIGNATURES(FUNCTION_NAME)\ +template \ +inline void FUNCTION_NAME(const char*, const T&, const Args&...);\ +template \ +inline void FUNCTION_NAME(const T&); + + template + inline void verbose(int, const char*, const T&, const Args&...); + + template + inline void verbose(int, const T&); + + LOGGER_LEVEL_WRITERS_SIGNATURES(info) + LOGGER_LEVEL_WRITERS_SIGNATURES(debug) + LOGGER_LEVEL_WRITERS_SIGNATURES(warn) + LOGGER_LEVEL_WRITERS_SIGNATURES(error) + LOGGER_LEVEL_WRITERS_SIGNATURES(fatal) + LOGGER_LEVEL_WRITERS_SIGNATURES(trace) +# undef LOGGER_LEVEL_WRITERS_SIGNATURES +#endif // ELPP_VARIADIC_TEMPLATES_SUPPORTED + private: + std::string m_id; + base::TypedConfigurations* m_typedConfigurations; + base::type::stringstream_t m_stream; + std::string m_parentApplicationName; + bool m_isConfigured; + Configurations m_configurations; + std::unordered_map m_unflushedCount; + base::LogStreamsReferenceMapPtr m_logStreamsReference = nullptr; + LogBuilderPtr m_logBuilder; + + friend class el::LogMessage; + friend class el::Loggers; + friend class el::Helpers; + friend class el::base::RegisteredLoggers; + friend class el::base::DefaultLogDispatchCallback; + friend class el::base::MessageBuilder; + friend class el::base::Writer; + friend class el::base::PErrorWriter; + friend class el::base::Storage; + friend class el::base::PerformanceTracker; + friend class el::base::LogDispatcher; + + Logger(void); + +#if ELPP_VARIADIC_TEMPLATES_SUPPORTED + template + void log_(Level, int, const char*, const T&, const Args&...); + + template + inline void log_(Level, int, const T&); + + template + void log(Level, const char*, const T&, const Args&...); + + template + inline void log(Level, const T&); +#endif // ELPP_VARIADIC_TEMPLATES_SUPPORTED + + void initUnflushedCount(void); + + inline base::type::stringstream_t& stream(void) { + return m_stream; + } + + void resolveLoggerFormatSpec(void) const; +}; +namespace base { +/// @brief Loggers repository +class RegisteredLoggers : public base::utils::Registry { + public: + explicit RegisteredLoggers(const LogBuilderPtr& defaultLogBuilder); + + virtual ~RegisteredLoggers(void) { + unsafeFlushAll(); + } + + inline void setDefaultConfigurations(const Configurations& 
configurations) { + base::threading::ScopedLock scopedLock(lock()); + m_defaultConfigurations.setFromBase(const_cast(&configurations)); + } + + inline Configurations* defaultConfigurations(void) { + return &m_defaultConfigurations; + } + + Logger* get(const std::string& id, bool forceCreation = true); + + template + inline bool installLoggerRegistrationCallback(const std::string& id) { + return base::utils::Utils::installCallback(id, + &m_loggerRegistrationCallbacks); + } + + template + inline void uninstallLoggerRegistrationCallback(const std::string& id) { + base::utils::Utils::uninstallCallback(id, &m_loggerRegistrationCallbacks); + } + + template + inline T* loggerRegistrationCallback(const std::string& id) { + return base::utils::Utils::callback(id, &m_loggerRegistrationCallbacks); + } + + bool remove(const std::string& id); + + inline bool has(const std::string& id) { + return get(id, false) != nullptr; + } + + inline void unregister(Logger*& logger) { + base::threading::ScopedLock scopedLock(lock()); + base::utils::Registry::unregister(logger->id()); + } + + inline LogStreamsReferenceMapPtr logStreamsReference(void) { + return m_logStreamsReference; + } + + inline void flushAll(void) { + base::threading::ScopedLock scopedLock(lock()); + unsafeFlushAll(); + } + + inline void setDefaultLogBuilder(LogBuilderPtr& logBuilderPtr) { + base::threading::ScopedLock scopedLock(lock()); + m_defaultLogBuilder = logBuilderPtr; + } + + private: + LogBuilderPtr m_defaultLogBuilder; + Configurations m_defaultConfigurations; + base::LogStreamsReferenceMapPtr m_logStreamsReference = nullptr; + std::unordered_map m_loggerRegistrationCallbacks; + friend class el::base::Storage; + + void unsafeFlushAll(void); +}; +/// @brief Represents registries for verbose logging +class VRegistry : base::NoCopy, public base::threading::ThreadSafe { + public: + explicit VRegistry(base::type::VerboseLevel level, base::type::EnumType* pFlags); + + /// @brief Sets verbose level. 
Accepted range is 0-9 + void setLevel(base::type::VerboseLevel level); + + inline base::type::VerboseLevel level(void) const { + return m_level; + } + + inline void clearModules(void) { + base::threading::ScopedLock scopedLock(lock()); + m_modules.clear(); + } + + void setModules(const char* modules); + + bool allowed(base::type::VerboseLevel vlevel, const char* file); + + inline const std::unordered_map& modules(void) const { + return m_modules; + } + + void setFromArgs(const base::utils::CommandLineArgs* commandLineArgs); + + /// @brief Whether or not vModules enabled + inline bool vModulesEnabled(void) { + return !base::utils::hasFlag(LoggingFlag::DisableVModules, *m_pFlags); + } + + private: + base::type::VerboseLevel m_level; + base::type::EnumType* m_pFlags; + std::unordered_map m_modules; +}; +} // namespace base +class LogMessage { + public: + LogMessage(Level level, const std::string& file, base::type::LineNumber line, const std::string& func, + base::type::VerboseLevel verboseLevel, Logger* logger) : + m_level(level), m_file(file), m_line(line), m_func(func), + m_verboseLevel(verboseLevel), m_logger(logger), m_message(logger->stream().str()) { + } + inline Level level(void) const { + return m_level; + } + inline const std::string& file(void) const { + return m_file; + } + inline base::type::LineNumber line(void) const { + return m_line; + } + inline const std::string& func(void) const { + return m_func; + } + inline base::type::VerboseLevel verboseLevel(void) const { + return m_verboseLevel; + } + inline Logger* logger(void) const { + return m_logger; + } + inline const base::type::string_t& message(void) const { + return m_message; + } + private: + Level m_level; + std::string m_file; + base::type::LineNumber m_line; + std::string m_func; + base::type::VerboseLevel m_verboseLevel; + Logger* m_logger; + base::type::string_t m_message; +}; +namespace base { +#if ELPP_ASYNC_LOGGING +class AsyncLogItem { + public: + explicit AsyncLogItem(const LogMessage& logMessage, const LogDispatchData& data, const base::type::string_t& logLine) + : m_logMessage(logMessage), m_dispatchData(data), m_logLine(logLine) {} + virtual ~AsyncLogItem() {} + inline LogMessage* logMessage(void) { + return &m_logMessage; + } + inline LogDispatchData* data(void) { + return &m_dispatchData; + } + inline base::type::string_t logLine(void) { + return m_logLine; + } + private: + LogMessage m_logMessage; + LogDispatchData m_dispatchData; + base::type::string_t m_logLine; +}; +class AsyncLogQueue : public base::threading::ThreadSafe { + public: + virtual ~AsyncLogQueue() { + ELPP_INTERNAL_INFO(6, "~AsyncLogQueue"); + } + + inline AsyncLogItem next(void) { + base::threading::ScopedLock scopedLock(lock()); + AsyncLogItem result = m_queue.front(); + m_queue.pop(); + return result; + } + + inline void push(const AsyncLogItem& item) { + base::threading::ScopedLock scopedLock(lock()); + m_queue.push(item); + } + inline void pop(void) { + base::threading::ScopedLock scopedLock(lock()); + m_queue.pop(); + } + inline AsyncLogItem front(void) { + base::threading::ScopedLock scopedLock(lock()); + return m_queue.front(); + } + inline bool empty(void) { + base::threading::ScopedLock scopedLock(lock()); + return m_queue.empty(); + } + private: + std::queue m_queue; +}; +class IWorker { + public: + virtual ~IWorker() {} + virtual void start() = 0; +}; +#endif // ELPP_ASYNC_LOGGING +/// @brief Easylogging++ management storage +class Storage : base::NoCopy, public base::threading::ThreadSafe { + public: +#if ELPP_ASYNC_LOGGING + 
Storage(const LogBuilderPtr& defaultLogBuilder, base::IWorker* asyncDispatchWorker); +#else + explicit Storage(const LogBuilderPtr& defaultLogBuilder); +#endif // ELPP_ASYNC_LOGGING + + virtual ~Storage(void); + + inline bool validateEveryNCounter(const char* filename, base::type::LineNumber lineNumber, std::size_t occasion) { + return hitCounters()->validateEveryN(filename, lineNumber, occasion); + } + + inline bool validateAfterNCounter(const char* filename, base::type::LineNumber lineNumber, std::size_t n) { + return hitCounters()->validateAfterN(filename, lineNumber, n); + } + + inline bool validateNTimesCounter(const char* filename, base::type::LineNumber lineNumber, std::size_t n) { + return hitCounters()->validateNTimes(filename, lineNumber, n); + } + + inline base::RegisteredHitCounters* hitCounters(void) const { + return m_registeredHitCounters; + } + + inline base::RegisteredLoggers* registeredLoggers(void) const { + return m_registeredLoggers; + } + + inline base::VRegistry* vRegistry(void) const { + return m_vRegistry; + } + +#if ELPP_ASYNC_LOGGING + inline base::AsyncLogQueue* asyncLogQueue(void) const { + return m_asyncLogQueue; + } +#endif // ELPP_ASYNC_LOGGING + + inline const base::utils::CommandLineArgs* commandLineArgs(void) const { + return &m_commandLineArgs; + } + + inline void addFlag(LoggingFlag flag) { + base::utils::addFlag(flag, &m_flags); + } + + inline void removeFlag(LoggingFlag flag) { + base::utils::removeFlag(flag, &m_flags); + } + + inline bool hasFlag(LoggingFlag flag) const { + return base::utils::hasFlag(flag, m_flags); + } + + inline base::type::EnumType flags(void) const { + return m_flags; + } + + inline void setFlags(base::type::EnumType flags) { + m_flags = flags; + } + + inline void setPreRollOutCallback(const PreRollOutCallback& callback) { + m_preRollOutCallback = callback; + } + + inline void unsetPreRollOutCallback(void) { + m_preRollOutCallback = base::defaultPreRollOutCallback; + } + + inline PreRollOutCallback& preRollOutCallback(void) { + return m_preRollOutCallback; + } + + bool hasCustomFormatSpecifier(const char* formatSpecifier); + void installCustomFormatSpecifier(const CustomFormatSpecifier& customFormatSpecifier); + bool uninstallCustomFormatSpecifier(const char* formatSpecifier); + + const std::vector* customFormatSpecifiers(void) const { + return &m_customFormatSpecifiers; + } + + base::threading::Mutex& customFormatSpecifiersLock() { + return m_customFormatSpecifiersLock; + } + + inline void setLoggingLevel(Level level) { + m_loggingLevel = level; + } + + template + inline bool installLogDispatchCallback(const std::string& id) { + return base::utils::Utils::installCallback(id, &m_logDispatchCallbacks); + } + + template + inline void uninstallLogDispatchCallback(const std::string& id) { + base::utils::Utils::uninstallCallback(id, &m_logDispatchCallbacks); + } + template + inline T* logDispatchCallback(const std::string& id) { + return base::utils::Utils::callback(id, &m_logDispatchCallbacks); + } + +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + template + inline bool installPerformanceTrackingCallback(const std::string& id) { + return base::utils::Utils::installCallback(id, + &m_performanceTrackingCallbacks); + } + + template + inline void uninstallPerformanceTrackingCallback(const std::string& id) { + base::utils::Utils::uninstallCallback(id, + &m_performanceTrackingCallbacks); + } + + template + inline T* performanceTrackingCallback(const std::string& id) { + return 
base::utils::Utils::callback(id, &m_performanceTrackingCallbacks); + } +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + + /// @brief Sets thread name for current thread. Requires std::thread + inline void setThreadName(const std::string& name) { + if (name.empty()) return; + base::threading::ScopedLock scopedLock(m_threadNamesLock); + m_threadNames[base::threading::getCurrentThreadId()] = name; + } + + inline std::string getThreadName(const std::string& threadId) { + base::threading::ScopedLock scopedLock(m_threadNamesLock); + std::unordered_map::const_iterator it = m_threadNames.find(threadId); + if (it == m_threadNames.end()) { + return threadId; + } + return it->second; + } + private: + base::RegisteredHitCounters* m_registeredHitCounters; + base::RegisteredLoggers* m_registeredLoggers; + base::type::EnumType m_flags; + base::VRegistry* m_vRegistry; +#if ELPP_ASYNC_LOGGING + base::AsyncLogQueue* m_asyncLogQueue; + base::IWorker* m_asyncDispatchWorker; +#endif // ELPP_ASYNC_LOGGING + base::utils::CommandLineArgs m_commandLineArgs; + PreRollOutCallback m_preRollOutCallback; + std::unordered_map m_logDispatchCallbacks; + std::unordered_map m_performanceTrackingCallbacks; + std::unordered_map m_threadNames; + std::vector m_customFormatSpecifiers; + base::threading::Mutex m_customFormatSpecifiersLock; + base::threading::Mutex m_threadNamesLock; + Level m_loggingLevel; + + friend class el::Helpers; + friend class el::base::DefaultLogDispatchCallback; + friend class el::LogBuilder; + friend class el::base::MessageBuilder; + friend class el::base::Writer; + friend class el::base::PerformanceTracker; + friend class el::base::LogDispatcher; + + void setApplicationArguments(int argc, char** argv); + + inline void setApplicationArguments(int argc, const char** argv) { + setApplicationArguments(argc, const_cast(argv)); + } +}; +extern ELPP_EXPORT base::type::StoragePointer elStorage; +#define ELPP el::base::elStorage +class DefaultLogDispatchCallback : public LogDispatchCallback { + protected: + void handle(const LogDispatchData* data); + private: + const LogDispatchData* m_data; + void dispatch(base::type::string_t&& logLine); +}; +#if ELPP_ASYNC_LOGGING +class AsyncLogDispatchCallback : public LogDispatchCallback { + protected: + void handle(const LogDispatchData* data); +}; +class AsyncDispatchWorker : public base::IWorker, public base::threading::ThreadSafe { + public: + AsyncDispatchWorker(); + virtual ~AsyncDispatchWorker(); + + bool clean(void); + void emptyQueue(void); + virtual void start(void); + void handle(AsyncLogItem* logItem); + void run(void); + + void setContinueRunning(bool value) { + base::threading::ScopedLock scopedLock(m_continueRunningLock); + m_continueRunning = value; + } + + bool continueRunning(void) const { + return m_continueRunning; + } + private: + std::condition_variable cv; + bool m_continueRunning; + base::threading::Mutex m_continueRunningLock; +}; +#endif // ELPP_ASYNC_LOGGING +} // namespace base +namespace base { +class DefaultLogBuilder : public LogBuilder { + public: + base::type::string_t build(const LogMessage* logMessage, bool appendNewLine) const; +}; +/// @brief Dispatches log messages +class LogDispatcher : base::NoCopy { + public: + LogDispatcher(bool proceed, LogMessage* logMessage, base::DispatchAction dispatchAction) : + m_proceed(proceed), + m_logMessage(logMessage), + m_dispatchAction(std::move(dispatchAction)) { + } + + void dispatch(void); + + private: + bool m_proceed; + LogMessage* m_logMessage; + 
base::DispatchAction m_dispatchAction; +}; +#if defined(ELPP_STL_LOGGING) +/// @brief Workarounds to write some STL logs +/// +/// @detail There is workaround needed to loop through some stl containers. In order to do that, we need iterable containers +/// of same type and provide iterator interface and pass it on to writeIterator(). +/// Remember, this is passed by value in constructor so that we dont change original containers. +/// This operation is as expensive as Big-O(std::min(class_.size(), base::consts::kMaxLogPerContainer)) +namespace workarounds { +/// @brief Abstract IterableContainer template that provides interface for iterable classes of type T +template +class IterableContainer { + public: + typedef typename Container::iterator iterator; + typedef typename Container::const_iterator const_iterator; + IterableContainer(void) {} + virtual ~IterableContainer(void) {} + iterator begin(void) { + return getContainer().begin(); + } + iterator end(void) { + return getContainer().end(); + } + private: + virtual Container& getContainer(void) = 0; +}; +/// @brief Implements IterableContainer and provides iterable std::priority_queue class +template, typename Comparator = std::less> +class IterablePriorityQueue : public IterableContainer, + public std::priority_queue { + public: + IterablePriorityQueue(std::priority_queue queue_) { + std::size_t count_ = 0; + while (++count_ < base::consts::kMaxLogPerContainer && !queue_.empty()) { + this->push(queue_.top()); + queue_.pop(); + } + } + private: + inline Container& getContainer(void) { + return this->c; + } +}; +/// @brief Implements IterableContainer and provides iterable std::queue class +template> +class IterableQueue : public IterableContainer, public std::queue { + public: + IterableQueue(std::queue queue_) { + std::size_t count_ = 0; + while (++count_ < base::consts::kMaxLogPerContainer && !queue_.empty()) { + this->push(queue_.front()); + queue_.pop(); + } + } + private: + inline Container& getContainer(void) { + return this->c; + } +}; +/// @brief Implements IterableContainer and provides iterable std::stack class +template> +class IterableStack : public IterableContainer, public std::stack { + public: + IterableStack(std::stack stack_) { + std::size_t count_ = 0; + while (++count_ < base::consts::kMaxLogPerContainer && !stack_.empty()) { + this->push(stack_.top()); + stack_.pop(); + } + } + private: + inline Container& getContainer(void) { + return this->c; + } +}; +} // namespace workarounds +#endif // defined(ELPP_STL_LOGGING) +// Log message builder +class MessageBuilder { + public: + MessageBuilder(void) : m_logger(nullptr), m_containerLogSeparator(ELPP_LITERAL("")) {} + void initialize(Logger* logger); + +# define ELPP_SIMPLE_LOG(LOG_TYPE)\ +MessageBuilder& operator<<(LOG_TYPE msg) {\ +m_logger->stream() << msg;\ +if (ELPP->hasFlag(LoggingFlag::AutoSpacing)) {\ +m_logger->stream() << " ";\ +}\ +return *this;\ +} + + inline MessageBuilder& operator<<(const std::string& msg) { + return operator<<(msg.c_str()); + } + ELPP_SIMPLE_LOG(char) + ELPP_SIMPLE_LOG(bool) + ELPP_SIMPLE_LOG(signed short) + ELPP_SIMPLE_LOG(unsigned short) + ELPP_SIMPLE_LOG(signed int) + ELPP_SIMPLE_LOG(unsigned int) + ELPP_SIMPLE_LOG(signed long) + ELPP_SIMPLE_LOG(unsigned long) + ELPP_SIMPLE_LOG(float) + ELPP_SIMPLE_LOG(double) + ELPP_SIMPLE_LOG(char*) + ELPP_SIMPLE_LOG(const char*) + ELPP_SIMPLE_LOG(const void*) + ELPP_SIMPLE_LOG(long double) + inline MessageBuilder& operator<<(const std::wstring& msg) { + return operator<<(msg.c_str()); + } + 
MessageBuilder& operator<<(const wchar_t* msg); + // ostream manipulators + inline MessageBuilder& operator<<(std::ostream& (*OStreamMani)(std::ostream&)) { + m_logger->stream() << OStreamMani; + return *this; + } +#define ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(temp) \ +template \ +inline MessageBuilder& operator<<(const temp& template_inst) { \ +return writeIterator(template_inst.begin(), template_inst.end(), template_inst.size()); \ +} +#define ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(temp) \ +template \ +inline MessageBuilder& operator<<(const temp& template_inst) { \ +return writeIterator(template_inst.begin(), template_inst.end(), template_inst.size()); \ +} +#define ELPP_ITERATOR_CONTAINER_LOG_THREE_ARG(temp) \ +template \ +inline MessageBuilder& operator<<(const temp& template_inst) { \ +return writeIterator(template_inst.begin(), template_inst.end(), template_inst.size()); \ +} +#define ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(temp) \ +template \ +inline MessageBuilder& operator<<(const temp& template_inst) { \ +return writeIterator(template_inst.begin(), template_inst.end(), template_inst.size()); \ +} +#define ELPP_ITERATOR_CONTAINER_LOG_FIVE_ARG(temp) \ +template \ +inline MessageBuilder& operator<<(const temp& template_inst) { \ +return writeIterator(template_inst.begin(), template_inst.end(), template_inst.size()); \ +} + +#if defined(ELPP_STL_LOGGING) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(std::vector) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(std::list) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(std::deque) + ELPP_ITERATOR_CONTAINER_LOG_THREE_ARG(std::set) + ELPP_ITERATOR_CONTAINER_LOG_THREE_ARG(std::multiset) + ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(std::map) + ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(std::multimap) + template + inline MessageBuilder& operator<<(const std::queue& queue_) { + base::workarounds::IterableQueue iterableQueue_ = + static_cast >(queue_); + return writeIterator(iterableQueue_.begin(), iterableQueue_.end(), iterableQueue_.size()); + } + template + inline MessageBuilder& operator<<(const std::stack& stack_) { + base::workarounds::IterableStack iterableStack_ = + static_cast >(stack_); + return writeIterator(iterableStack_.begin(), iterableStack_.end(), iterableStack_.size()); + } + template + inline MessageBuilder& operator<<(const std::priority_queue& priorityQueue_) { + base::workarounds::IterablePriorityQueue iterablePriorityQueue_ = + static_cast >(priorityQueue_); + return writeIterator(iterablePriorityQueue_.begin(), iterablePriorityQueue_.end(), iterablePriorityQueue_.size()); + } + template + MessageBuilder& operator<<(const std::pair& pair_) { + m_logger->stream() << ELPP_LITERAL("("); + operator << (static_cast(pair_.first)); + m_logger->stream() << ELPP_LITERAL(", "); + operator << (static_cast(pair_.second)); + m_logger->stream() << ELPP_LITERAL(")"); + return *this; + } + template + MessageBuilder& operator<<(const std::bitset& bitset_) { + m_logger->stream() << ELPP_LITERAL("["); + operator << (bitset_.to_string()); + m_logger->stream() << ELPP_LITERAL("]"); + return *this; + } +# if defined(ELPP_LOG_STD_ARRAY) + template + inline MessageBuilder& operator<<(const std::array& array) { + return writeIterator(array.begin(), array.end(), array.size()); + } +# endif // defined(ELPP_LOG_STD_ARRAY) +# if defined(ELPP_LOG_UNORDERED_MAP) + ELPP_ITERATOR_CONTAINER_LOG_FIVE_ARG(std::unordered_map) + ELPP_ITERATOR_CONTAINER_LOG_FIVE_ARG(std::unordered_multimap) +# endif // defined(ELPP_LOG_UNORDERED_MAP) +# if defined(ELPP_LOG_UNORDERED_SET) + 
ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(std::unordered_set) + ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(std::unordered_multiset) +# endif // defined(ELPP_LOG_UNORDERED_SET) +#endif // defined(ELPP_STL_LOGGING) +#if defined(ELPP_QT_LOGGING) + inline MessageBuilder& operator<<(const QString& msg) { +# if defined(ELPP_UNICODE) + m_logger->stream() << msg.toStdWString(); +# else + m_logger->stream() << msg.toStdString(); +# endif // defined(ELPP_UNICODE) + return *this; + } + inline MessageBuilder& operator<<(const QByteArray& msg) { + return operator << (QString(msg)); + } + inline MessageBuilder& operator<<(const QStringRef& msg) { + return operator<<(msg.toString()); + } + inline MessageBuilder& operator<<(qint64 msg) { +# if defined(ELPP_UNICODE) + m_logger->stream() << QString::number(msg).toStdWString(); +# else + m_logger->stream() << QString::number(msg).toStdString(); +# endif // defined(ELPP_UNICODE) + return *this; + } + inline MessageBuilder& operator<<(quint64 msg) { +# if defined(ELPP_UNICODE) + m_logger->stream() << QString::number(msg).toStdWString(); +# else + m_logger->stream() << QString::number(msg).toStdString(); +# endif // defined(ELPP_UNICODE) + return *this; + } + inline MessageBuilder& operator<<(QChar msg) { + m_logger->stream() << msg.toLatin1(); + return *this; + } + inline MessageBuilder& operator<<(const QLatin1String& msg) { + m_logger->stream() << msg.latin1(); + return *this; + } + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(QList) + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(QVector) + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(QQueue) + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(QSet) + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(QLinkedList) + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(QStack) + template + MessageBuilder& operator<<(const QPair& pair_) { + m_logger->stream() << ELPP_LITERAL("("); + operator << (static_cast(pair_.first)); + m_logger->stream() << ELPP_LITERAL(", "); + operator << (static_cast(pair_.second)); + m_logger->stream() << ELPP_LITERAL(")"); + return *this; + } + template + MessageBuilder& operator<<(const QMap& map_) { + m_logger->stream() << ELPP_LITERAL("["); + QList keys = map_.keys(); + typename QList::const_iterator begin = keys.begin(); + typename QList::const_iterator end = keys.end(); + int max_ = static_cast(base::consts::kMaxLogPerContainer); // to prevent warning + for (int index_ = 0; begin != end && index_ < max_; ++index_, ++begin) { + m_logger->stream() << ELPP_LITERAL("("); + operator << (static_cast(*begin)); + m_logger->stream() << ELPP_LITERAL(", "); + operator << (static_cast(map_.value(*begin))); + m_logger->stream() << ELPP_LITERAL(")"); + m_logger->stream() << ((index_ < keys.size() -1) ? 
m_containerLogSeparator : ELPP_LITERAL("")); + } + if (begin != end) { + m_logger->stream() << ELPP_LITERAL("..."); + } + m_logger->stream() << ELPP_LITERAL("]"); + return *this; + } + template + inline MessageBuilder& operator<<(const QMultiMap& map_) { + operator << (static_cast>(map_)); + return *this; + } + template + MessageBuilder& operator<<(const QHash& hash_) { + m_logger->stream() << ELPP_LITERAL("["); + QList keys = hash_.keys(); + typename QList::const_iterator begin = keys.begin(); + typename QList::const_iterator end = keys.end(); + int max_ = static_cast(base::consts::kMaxLogPerContainer); // prevent type warning + for (int index_ = 0; begin != end && index_ < max_; ++index_, ++begin) { + m_logger->stream() << ELPP_LITERAL("("); + operator << (static_cast(*begin)); + m_logger->stream() << ELPP_LITERAL(", "); + operator << (static_cast(hash_.value(*begin))); + m_logger->stream() << ELPP_LITERAL(")"); + m_logger->stream() << ((index_ < keys.size() -1) ? m_containerLogSeparator : ELPP_LITERAL("")); + } + if (begin != end) { + m_logger->stream() << ELPP_LITERAL("..."); + } + m_logger->stream() << ELPP_LITERAL("]"); + return *this; + } + template + inline MessageBuilder& operator<<(const QMultiHash& multiHash_) { + operator << (static_cast>(multiHash_)); + return *this; + } +#endif // defined(ELPP_QT_LOGGING) +#if defined(ELPP_BOOST_LOGGING) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(boost::container::vector) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(boost::container::stable_vector) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(boost::container::list) + ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG(boost::container::deque) + ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(boost::container::map) + ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG(boost::container::flat_map) + ELPP_ITERATOR_CONTAINER_LOG_THREE_ARG(boost::container::set) + ELPP_ITERATOR_CONTAINER_LOG_THREE_ARG(boost::container::flat_set) +#endif // defined(ELPP_BOOST_LOGGING) + + /// @brief Macro used internally that can be used externally to make containers easylogging++ friendly + /// + /// @detail This macro expands to write an ostream& operator<< for container. This container is expected to + /// have begin() and end() methods that return respective iterators + /// @param ContainerType Type of container e.g, MyList from WX_DECLARE_LIST(int, MyList); in wxwidgets + /// @param SizeMethod Method used to get size of container. + /// @param ElementInstance Instance of element to be fed out. Instance name is "elem". See WXELPP_ENABLED macro + /// for an example usage +#define MAKE_CONTAINERELPP_FRIENDLY(ContainerType, SizeMethod, ElementInstance) \ +el::base::type::ostream_t& operator<<(el::base::type::ostream_t& ss, const ContainerType& container) {\ +const el::base::type::char_t* sep = ELPP->hasFlag(el::LoggingFlag::NewLineForContainer) ? \ +ELPP_LITERAL("\n ") : ELPP_LITERAL(", ");\ +ContainerType::const_iterator elem = container.begin();\ +ContainerType::const_iterator endElem = container.end();\ +std::size_t size_ = container.SizeMethod; \ +ss << ELPP_LITERAL("[");\ +for (std::size_t i = 0; elem != endElem && i < el::base::consts::kMaxLogPerContainer; ++i, ++elem) { \ +ss << ElementInstance;\ +ss << ((i < size_ - 1) ? 
sep : ELPP_LITERAL(""));\ +}\ +if (elem != endElem) {\ +ss << ELPP_LITERAL("...");\ +}\ +ss << ELPP_LITERAL("]");\ +return ss;\ +} +#if defined(ELPP_WXWIDGETS_LOGGING) + ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG(wxVector) +# define ELPP_WX_PTR_ENABLED(ContainerType) MAKE_CONTAINERELPP_FRIENDLY(ContainerType, size(), *(*elem)) +# define ELPP_WX_ENABLED(ContainerType) MAKE_CONTAINERELPP_FRIENDLY(ContainerType, size(), (*elem)) +# define ELPP_WX_HASH_MAP_ENABLED(ContainerType) MAKE_CONTAINERELPP_FRIENDLY(ContainerType, size(), \ +ELPP_LITERAL("(") << elem->first << ELPP_LITERAL(", ") << elem->second << ELPP_LITERAL(")") +#else +# define ELPP_WX_PTR_ENABLED(ContainerType) +# define ELPP_WX_ENABLED(ContainerType) +# define ELPP_WX_HASH_MAP_ENABLED(ContainerType) +#endif // defined(ELPP_WXWIDGETS_LOGGING) + // Other classes + template + ELPP_SIMPLE_LOG(const Class&) +#undef ELPP_SIMPLE_LOG +#undef ELPP_ITERATOR_CONTAINER_LOG_ONE_ARG +#undef ELPP_ITERATOR_CONTAINER_LOG_TWO_ARG +#undef ELPP_ITERATOR_CONTAINER_LOG_THREE_ARG +#undef ELPP_ITERATOR_CONTAINER_LOG_FOUR_ARG +#undef ELPP_ITERATOR_CONTAINER_LOG_FIVE_ARG + private: + Logger* m_logger; + const base::type::char_t* m_containerLogSeparator; + + template + MessageBuilder& writeIterator(Iterator begin_, Iterator end_, std::size_t size_) { + m_logger->stream() << ELPP_LITERAL("["); + for (std::size_t i = 0; begin_ != end_ && i < base::consts::kMaxLogPerContainer; ++i, ++begin_) { + operator << (*begin_); + m_logger->stream() << ((i < size_ - 1) ? m_containerLogSeparator : ELPP_LITERAL("")); + } + if (begin_ != end_) { + m_logger->stream() << ELPP_LITERAL("..."); + } + m_logger->stream() << ELPP_LITERAL("]"); + if (ELPP->hasFlag(LoggingFlag::AutoSpacing)) { + m_logger->stream() << " "; + } + return *this; + } +}; +/// @brief Writes nothing - Used when certain log is disabled +class NullWriter : base::NoCopy { + public: + NullWriter(void) {} + + // Null manipulator + inline NullWriter& operator<<(std::ostream& (*)(std::ostream&)) { + return *this; + } + + template + inline NullWriter& operator<<(const T&) { + return *this; + } + + inline operator bool() { + return true; + } +}; +/// @brief Main entry point of each logging +class Writer : base::NoCopy { + public: + Writer(Level level, const char* file, base::type::LineNumber line, + const char* func, base::DispatchAction dispatchAction = base::DispatchAction::NormalLog, + base::type::VerboseLevel verboseLevel = 0) : + m_msg(nullptr), m_level(level), m_file(file), m_line(line), m_func(func), m_verboseLevel(verboseLevel), + m_logger(nullptr), m_proceed(false), m_dispatchAction(dispatchAction) { + } + + Writer(LogMessage* msg, base::DispatchAction dispatchAction = base::DispatchAction::NormalLog) : + m_msg(msg), m_level(msg != nullptr ? 
msg->level() : Level::Unknown), + m_line(0), m_logger(nullptr), m_proceed(false), m_dispatchAction(dispatchAction) { + } + + virtual ~Writer(void) { + processDispatch(); + } + + template + inline Writer& operator<<(const T& log) { +#if ELPP_LOGGING_ENABLED + if (m_proceed) { + m_messageBuilder << log; + } +#endif // ELPP_LOGGING_ENABLED + return *this; + } + + inline Writer& operator<<(std::ostream& (*log)(std::ostream&)) { +#if ELPP_LOGGING_ENABLED + if (m_proceed) { + m_messageBuilder << log; + } +#endif // ELPP_LOGGING_ENABLED + return *this; + } + + inline operator bool() { + return true; + } + + Writer& construct(Logger* logger, bool needLock = true); + Writer& construct(int count, const char* loggerIds, ...); + protected: + LogMessage* m_msg; + Level m_level; + const char* m_file; + const base::type::LineNumber m_line; + const char* m_func; + base::type::VerboseLevel m_verboseLevel; + Logger* m_logger; + bool m_proceed; + base::MessageBuilder m_messageBuilder; + base::DispatchAction m_dispatchAction; + std::vector m_loggerIds; + friend class el::Helpers; + + void initializeLogger(const std::string& loggerId, bool lookup = true, bool needLock = true); + void processDispatch(); + void triggerDispatch(void); +}; +class PErrorWriter : public base::Writer { + public: + PErrorWriter(Level level, const char* file, base::type::LineNumber line, + const char* func, base::DispatchAction dispatchAction = base::DispatchAction::NormalLog, + base::type::VerboseLevel verboseLevel = 0) : + base::Writer(level, file, line, func, dispatchAction, verboseLevel) { + } + + virtual ~PErrorWriter(void); +}; +} // namespace base +// Logging from Logger class. Why this is here? Because we have Storage and Writer class available +#if ELPP_VARIADIC_TEMPLATES_SUPPORTED +template +void Logger::log_(Level level, int vlevel, const char* s, const T& value, const Args&... args) { + base::MessageBuilder b; + b.initialize(this); + while (*s) { + if (*s == base::consts::kFormatSpecifierChar) { + if (*(s + 1) == base::consts::kFormatSpecifierChar) { + ++s; + } else { + if (*(s + 1) == base::consts::kFormatSpecifierCharValue) { + ++s; + b << value; + log_(level, vlevel, ++s, args...); + return; + } + } + } + b << *s++; + } + ELPP_INTERNAL_ERROR("Too many arguments provided. Unable to handle. Please provide more format specifiers", false); +} +template +void Logger::log_(Level level, int vlevel, const T& log) { + if (level == Level::Verbose) { + if (ELPP->vRegistry()->allowed(vlevel, __FILE__)) { + base::Writer(Level::Verbose, "FILE", 0, "FUNCTION", + base::DispatchAction::NormalLog, vlevel).construct(this, false) << log; + } else { + stream().str(ELPP_LITERAL("")); + releaseLock(); + } + } else { + base::Writer(level, "FILE", 0, "FUNCTION").construct(this, false) << log; + } +} +template +inline void Logger::log(Level level, const char* s, const T& value, const Args&... args) { + acquireLock(); // released in Writer! + log_(level, 0, s, value, args...); +} +template +inline void Logger::log(Level level, const T& log) { + acquireLock(); // released in Writer! + log_(level, 0, log); +} +# if ELPP_VERBOSE_LOG +template +inline void Logger::verbose(int vlevel, const char* s, const T& value, const Args&... args) { + acquireLock(); // released in Writer! + log_(el::Level::Verbose, vlevel, s, value, args...); +} +template +inline void Logger::verbose(int vlevel, const T& log) { + acquireLock(); // released in Writer! 
+ log_(el::Level::Verbose, vlevel, log); +} +# else +template +inline void Logger::verbose(int, const char*, const T&, const Args&...) { + return; +} +template +inline void Logger::verbose(int, const T&) { + return; +} +# endif // ELPP_VERBOSE_LOG +# define LOGGER_LEVEL_WRITERS(FUNCTION_NAME, LOG_LEVEL)\ +template \ +inline void Logger::FUNCTION_NAME(const char* s, const T& value, const Args&... args) {\ +log(LOG_LEVEL, s, value, args...);\ +}\ +template \ +inline void Logger::FUNCTION_NAME(const T& value) {\ +log(LOG_LEVEL, value);\ +} +# define LOGGER_LEVEL_WRITERS_DISABLED(FUNCTION_NAME, LOG_LEVEL)\ +template \ +inline void Logger::FUNCTION_NAME(const char*, const T&, const Args&...) {\ +return;\ +}\ +template \ +inline void Logger::FUNCTION_NAME(const T&) {\ +return;\ +} + +# if ELPP_INFO_LOG +LOGGER_LEVEL_WRITERS(info, Level::Info) +# else +LOGGER_LEVEL_WRITERS_DISABLED(info, Level::Info) +# endif // ELPP_INFO_LOG +# if ELPP_DEBUG_LOG +LOGGER_LEVEL_WRITERS(debug, Level::Debug) +# else +LOGGER_LEVEL_WRITERS_DISABLED(debug, Level::Debug) +# endif // ELPP_DEBUG_LOG +# if ELPP_WARNING_LOG +LOGGER_LEVEL_WRITERS(warn, Level::Warning) +# else +LOGGER_LEVEL_WRITERS_DISABLED(warn, Level::Warning) +# endif // ELPP_WARNING_LOG +# if ELPP_ERROR_LOG +LOGGER_LEVEL_WRITERS(error, Level::Error) +# else +LOGGER_LEVEL_WRITERS_DISABLED(error, Level::Error) +# endif // ELPP_ERROR_LOG +# if ELPP_FATAL_LOG +LOGGER_LEVEL_WRITERS(fatal, Level::Fatal) +# else +LOGGER_LEVEL_WRITERS_DISABLED(fatal, Level::Fatal) +# endif // ELPP_FATAL_LOG +# if ELPP_TRACE_LOG +LOGGER_LEVEL_WRITERS(trace, Level::Trace) +# else +LOGGER_LEVEL_WRITERS_DISABLED(trace, Level::Trace) +# endif // ELPP_TRACE_LOG +# undef LOGGER_LEVEL_WRITERS +# undef LOGGER_LEVEL_WRITERS_DISABLED +#endif // ELPP_VARIADIC_TEMPLATES_SUPPORTED +#if ELPP_COMPILER_MSVC +# define ELPP_VARIADIC_FUNC_MSVC(variadicFunction, variadicArgs) variadicFunction variadicArgs +# define ELPP_VARIADIC_FUNC_MSVC_RUN(variadicFunction, ...) ELPP_VARIADIC_FUNC_MSVC(variadicFunction, (__VA_ARGS__)) +# define el_getVALength(...) ELPP_VARIADIC_FUNC_MSVC_RUN(el_resolveVALength, 0, ## __VA_ARGS__,\ +10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) +#else +# if ELPP_COMPILER_CLANG +# define el_getVALength(...) el_resolveVALength(0, __VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) +# else +# define el_getVALength(...) el_resolveVALength(0, ## __VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) +# endif // ELPP_COMPILER_CLANG +#endif // ELPP_COMPILER_MSVC +#define el_resolveVALength(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N +#define ELPP_WRITE_LOG(writer, level, dispatchAction, ...) \ +writer(level, __FILE__, __LINE__, ELPP_FUNC, dispatchAction).construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#define ELPP_WRITE_LOG_IF(writer, condition, level, dispatchAction, ...) if (condition) \ +writer(level, __FILE__, __LINE__, ELPP_FUNC, dispatchAction).construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#define ELPP_WRITE_LOG_EVERY_N(writer, occasion, level, dispatchAction, ...) \ +ELPP->validateEveryNCounter(__FILE__, __LINE__, occasion) && \ +writer(level, __FILE__, __LINE__, ELPP_FUNC, dispatchAction).construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#define ELPP_WRITE_LOG_AFTER_N(writer, n, level, dispatchAction, ...) \ +ELPP->validateAfterNCounter(__FILE__, __LINE__, n) && \ +writer(level, __FILE__, __LINE__, ELPP_FUNC, dispatchAction).construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#define ELPP_WRITE_LOG_N_TIMES(writer, n, level, dispatchAction, ...) 
\ +ELPP->validateNTimesCounter(__FILE__, __LINE__, n) && \ +writer(level, __FILE__, __LINE__, ELPP_FUNC, dispatchAction).construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) +class PerformanceTrackingData { + public: + enum class DataType : base::type::EnumType { + Checkpoint = 1, Complete = 2 + }; + // Do not use constructor, will run into multiple definition error, use init(PerformanceTracker*) + explicit PerformanceTrackingData(DataType dataType) : m_performanceTracker(nullptr), + m_dataType(dataType), m_firstCheckpoint(false), m_file(""), m_line(0), m_func("") {} + inline const std::string* blockName(void) const; + inline const struct timeval* startTime(void) const; + inline const struct timeval* endTime(void) const; + inline const struct timeval* lastCheckpointTime(void) const; + inline const base::PerformanceTracker* performanceTracker(void) const { + return m_performanceTracker; + } + inline PerformanceTrackingData::DataType dataType(void) const { + return m_dataType; + } + inline bool firstCheckpoint(void) const { + return m_firstCheckpoint; + } + inline std::string checkpointId(void) const { + return m_checkpointId; + } + inline const char* file(void) const { + return m_file; + } + inline base::type::LineNumber line(void) const { + return m_line; + } + inline const char* func(void) const { + return m_func; + } + inline const base::type::string_t* formattedTimeTaken() const { + return &m_formattedTimeTaken; + } + inline const std::string& loggerId(void) const; + private: + base::PerformanceTracker* m_performanceTracker; + base::type::string_t m_formattedTimeTaken; + PerformanceTrackingData::DataType m_dataType; + bool m_firstCheckpoint; + std::string m_checkpointId; + const char* m_file; + base::type::LineNumber m_line; + const char* m_func; + inline void init(base::PerformanceTracker* performanceTracker, bool firstCheckpoint = false) { + m_performanceTracker = performanceTracker; + m_firstCheckpoint = firstCheckpoint; + } + + friend class el::base::PerformanceTracker; +}; +namespace base { +/// @brief Represents performanceTracker block of code that conditionally adds performance status to log +/// either when goes outside the scope of when checkpoint() is called +class PerformanceTracker : public base::threading::ThreadSafe, public Loggable { + public: + PerformanceTracker(const std::string& blockName, + base::TimestampUnit timestampUnit = base::TimestampUnit::Millisecond, + const std::string& loggerId = std::string(el::base::consts::kPerformanceLoggerId), + bool scopedLog = true, Level level = base::consts::kPerformanceTrackerDefaultLevel); + /// @brief Copy constructor + PerformanceTracker(const PerformanceTracker& t) : + m_blockName(t.m_blockName), m_timestampUnit(t.m_timestampUnit), m_loggerId(t.m_loggerId), m_scopedLog(t.m_scopedLog), + m_level(t.m_level), m_hasChecked(t.m_hasChecked), m_lastCheckpointId(t.m_lastCheckpointId), m_enabled(t.m_enabled), + m_startTime(t.m_startTime), m_endTime(t.m_endTime), m_lastCheckpointTime(t.m_lastCheckpointTime) { + } + virtual ~PerformanceTracker(void); + /// @brief A checkpoint for current performanceTracker block. 
+ void checkpoint(const std::string& id = std::string(), const char* file = __FILE__, + base::type::LineNumber line = __LINE__, + const char* func = ""); + inline Level level(void) const { + return m_level; + } + private: + std::string m_blockName; + base::TimestampUnit m_timestampUnit; + std::string m_loggerId; + bool m_scopedLog; + Level m_level; + bool m_hasChecked; + std::string m_lastCheckpointId; + bool m_enabled; + struct timeval m_startTime, m_endTime, m_lastCheckpointTime; + + PerformanceTracker(void); + + friend class el::PerformanceTrackingData; + friend class base::DefaultPerformanceTrackingCallback; + + const inline base::type::string_t getFormattedTimeTaken() const { + return getFormattedTimeTaken(m_startTime); + } + + const base::type::string_t getFormattedTimeTaken(struct timeval startTime) const; + + virtual inline void log(el::base::type::ostream_t& os) const { + os << getFormattedTimeTaken(); + } +}; +class DefaultPerformanceTrackingCallback : public PerformanceTrackingCallback { + protected: + void handle(const PerformanceTrackingData* data) { + m_data = data; + base::type::stringstream_t ss; + if (m_data->dataType() == PerformanceTrackingData::DataType::Complete) { + ss << ELPP_LITERAL("Executed [") << m_data->blockName()->c_str() << ELPP_LITERAL("] in [") << + *m_data->formattedTimeTaken() << ELPP_LITERAL("]"); + } else { + ss << ELPP_LITERAL("Performance checkpoint"); + if (!m_data->checkpointId().empty()) { + ss << ELPP_LITERAL(" [") << m_data->checkpointId().c_str() << ELPP_LITERAL("]"); + } + ss << ELPP_LITERAL(" for block [") << m_data->blockName()->c_str() << ELPP_LITERAL("] : [") << + *m_data->performanceTracker(); + if (!ELPP->hasFlag(LoggingFlag::DisablePerformanceTrackingCheckpointComparison) + && m_data->performanceTracker()->m_hasChecked) { + ss << ELPP_LITERAL(" ([") << *m_data->formattedTimeTaken() << ELPP_LITERAL("] from "); + if (m_data->performanceTracker()->m_lastCheckpointId.empty()) { + ss << ELPP_LITERAL("last checkpoint"); + } else { + ss << ELPP_LITERAL("checkpoint '") << m_data->performanceTracker()->m_lastCheckpointId.c_str() << ELPP_LITERAL("'"); + } + ss << ELPP_LITERAL(")]"); + } else { + ss << ELPP_LITERAL("]"); + } + } + el::base::Writer(m_data->performanceTracker()->level(), m_data->file(), m_data->line(), m_data->func()).construct(1, + m_data->loggerId().c_str()) << ss.str(); + } + private: + const PerformanceTrackingData* m_data; +}; +} // namespace base +inline const std::string* PerformanceTrackingData::blockName() const { + return const_cast(&m_performanceTracker->m_blockName); +} +inline const struct timeval* PerformanceTrackingData::startTime() const { + return const_cast(&m_performanceTracker->m_startTime); +} +inline const struct timeval* PerformanceTrackingData::endTime() const { + return const_cast(&m_performanceTracker->m_endTime); +} +inline const struct timeval* PerformanceTrackingData::lastCheckpointTime() const { + return const_cast(&m_performanceTracker->m_lastCheckpointTime); +} +inline const std::string& PerformanceTrackingData::loggerId(void) const { + return m_performanceTracker->m_loggerId; +} +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) +namespace base { +/// @brief Contains some internal debugging tools like crash handler and stack tracer +namespace debug { +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_CRASH_LOG) +class StackTrace : base::NoCopy { + public: + static const unsigned int kMaxStack = 64; + static const unsigned int kStackStart = 2; // We want to skip 
c'tor and StackTrace::generateNew() + class StackTraceEntry { + public: + StackTraceEntry(std::size_t index, const std::string& loc, const std::string& demang, const std::string& hex, + const std::string& addr); + StackTraceEntry(std::size_t index, const std::string& loc) : + m_index(index), + m_location(loc) { + } + std::size_t m_index; + std::string m_location; + std::string m_demangled; + std::string m_hex; + std::string m_addr; + friend std::ostream& operator<<(std::ostream& ss, const StackTraceEntry& si); + + private: + StackTraceEntry(void); + }; + + StackTrace(void) { + generateNew(); + } + + virtual ~StackTrace(void) { + } + + inline std::vector& getLatestStack(void) { + return m_stack; + } + + friend std::ostream& operator<<(std::ostream& os, const StackTrace& st); + + private: + std::vector m_stack; + + void generateNew(void); +}; +/// @brief Handles unexpected crashes +class CrashHandler : base::NoCopy { + public: + typedef void (*Handler)(int); + + explicit CrashHandler(bool useDefault); + explicit CrashHandler(const Handler& cHandler) { + setHandler(cHandler); + } + void setHandler(const Handler& cHandler); + + private: + Handler m_handler; +}; +#else +class CrashHandler { + public: + explicit CrashHandler(bool) {} +}; +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_CRASH_LOG) +} // namespace debug +} // namespace base +extern base::debug::CrashHandler elCrashHandler; +#define MAKE_LOGGABLE(ClassType, ClassInstance, OutputStreamInstance) \ +el::base::type::ostream_t& operator<<(el::base::type::ostream_t& OutputStreamInstance, const ClassType& ClassInstance) +/// @brief Initializes syslog with process ID, options and facility. calls closelog() on d'tor +class SysLogInitializer { + public: + SysLogInitializer(const char* processIdent, int options = 0, int facility = 0) { +#if defined(ELPP_SYSLOG) + (void)base::consts::kSysLogLoggerId; + openlog(processIdent, options, facility); +#else + ELPP_UNUSED(processIdent); + ELPP_UNUSED(options); + ELPP_UNUSED(facility); +#endif // defined(ELPP_SYSLOG) + } + virtual ~SysLogInitializer(void) { +#if defined(ELPP_SYSLOG) + closelog(); +#endif // defined(ELPP_SYSLOG) + } +}; +#define ELPP_INITIALIZE_SYSLOG(id, opt, fac) el::SysLogInitializer elSyslogInit(id, opt, fac) +/// @brief Static helpers for developers +class Helpers : base::StaticClass { + public: + /// @brief Shares logging repository (base::Storage) + static inline void setStorage(base::type::StoragePointer storage) { + ELPP = storage; + } + /// @return Main storage repository + static inline base::type::StoragePointer storage() { + return ELPP; + } + /// @brief Sets application arguments and figures out whats active for logging and whats not. + static inline void setArgs(int argc, char** argv) { + ELPP->setApplicationArguments(argc, argv); + } + /// @copydoc setArgs(int argc, char** argv) + static inline void setArgs(int argc, const char** argv) { + ELPP->setApplicationArguments(argc, const_cast(argv)); + } + /// @brief Sets thread name for current thread. Requires std::thread + static inline void setThreadName(const std::string& name) { + ELPP->setThreadName(name); + } + static inline std::string getThreadName() { + return ELPP->getThreadName(base::threading::getCurrentThreadId()); + } +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_CRASH_LOG) + /// @brief Overrides default crash handler and installs custom handler. + /// @param crashHandler A functor with no return type that takes single int argument. 
+ /// Handler is a typedef with specification: void (*Handler)(int) + static inline void setCrashHandler(const el::base::debug::CrashHandler::Handler& crashHandler) { + el::elCrashHandler.setHandler(crashHandler); + } + /// @brief Abort due to crash with signal in parameter + /// @param sig Crash signal + static void crashAbort(int sig, const char* sourceFile = "", unsigned int long line = 0); + /// @brief Logs reason of crash as per sig + /// @param sig Crash signal + /// @param stackTraceIfAvailable Includes stack trace if available + /// @param level Logging level + /// @param logger Logger to use for logging + static void logCrashReason(int sig, bool stackTraceIfAvailable = false, + Level level = Level::Fatal, const char* logger = base::consts::kDefaultLoggerId); +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_CRASH_LOG) + /// @brief Installs pre rollout callback, this callback is triggered when log file is about to be rolled out + /// (can be useful for backing up) + static inline void installPreRollOutCallback(const PreRollOutCallback& callback) { + ELPP->setPreRollOutCallback(callback); + } + /// @brief Uninstalls pre rollout callback + static inline void uninstallPreRollOutCallback(void) { + ELPP->unsetPreRollOutCallback(); + } + /// @brief Installs post log dispatch callback, this callback is triggered when log is dispatched + template + static inline bool installLogDispatchCallback(const std::string& id) { + return ELPP->installLogDispatchCallback(id); + } + /// @brief Uninstalls log dispatch callback + template + static inline void uninstallLogDispatchCallback(const std::string& id) { + ELPP->uninstallLogDispatchCallback(id); + } + template + static inline T* logDispatchCallback(const std::string& id) { + return ELPP->logDispatchCallback(id); + } +#if defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + /// @brief Installs post performance tracking callback, this callback is triggered when performance tracking is finished + template + static inline bool installPerformanceTrackingCallback(const std::string& id) { + return ELPP->installPerformanceTrackingCallback(id); + } + /// @brief Uninstalls post performance tracking handler + template + static inline void uninstallPerformanceTrackingCallback(const std::string& id) { + ELPP->uninstallPerformanceTrackingCallback(id); + } + template + static inline T* performanceTrackingCallback(const std::string& id) { + return ELPP->performanceTrackingCallback(id); + } +#endif // defined(ELPP_FEATURE_ALL) || defined(ELPP_FEATURE_PERFORMANCE_TRACKING) + /// @brief Converts template to std::string - useful for loggable classes to log containers within log(std::ostream&) const + template + static std::string convertTemplateToStdString(const T& templ) { + el::Logger* logger = + ELPP->registeredLoggers()->get(el::base::consts::kDefaultLoggerId); + if (logger == nullptr) { + return std::string(); + } + base::MessageBuilder b; + b.initialize(logger); + logger->acquireLock(); + b << templ; +#if defined(ELPP_UNICODE) + std::string s = std::string(logger->stream().str().begin(), logger->stream().str().end()); +#else + std::string s = logger->stream().str(); +#endif // defined(ELPP_UNICODE) + logger->stream().str(ELPP_LITERAL("")); + logger->releaseLock(); + return s; + } + /// @brief Returns command line arguments (pointer) provided to easylogging++ + static inline const el::base::utils::CommandLineArgs* commandLineArgs(void) { + return ELPP->commandLineArgs(); + } + /// @brief Reserve space for custom format specifiers 
for performance + /// @see std::vector::reserve + static inline void reserveCustomFormatSpecifiers(std::size_t size) { + ELPP->m_customFormatSpecifiers.reserve(size); + } + /// @brief Installs user defined format specifier and handler + static inline void installCustomFormatSpecifier(const CustomFormatSpecifier& customFormatSpecifier) { + ELPP->installCustomFormatSpecifier(customFormatSpecifier); + } + /// @brief Uninstalls user defined format specifier and handler + static inline bool uninstallCustomFormatSpecifier(const char* formatSpecifier) { + return ELPP->uninstallCustomFormatSpecifier(formatSpecifier); + } + /// @brief Returns true if custom format specifier is installed + static inline bool hasCustomFormatSpecifier(const char* formatSpecifier) { + return ELPP->hasCustomFormatSpecifier(formatSpecifier); + } + static inline void validateFileRolling(Logger* logger, Level level) { + if (ELPP == nullptr || logger == nullptr) return; + logger->m_typedConfigurations->validateFileRolling(level, ELPP->preRollOutCallback()); + } +}; +/// @brief Static helpers to deal with loggers and their configurations +class Loggers : base::StaticClass { + public: + /// @brief Gets existing or registers new logger + static Logger* getLogger(const std::string& identity, bool registerIfNotAvailable = true); + /// @brief Changes default log builder for future loggers + static void setDefaultLogBuilder(el::LogBuilderPtr& logBuilderPtr); + /// @brief Installs logger registration callback, this callback is triggered when new logger is registered + template + static inline bool installLoggerRegistrationCallback(const std::string& id) { + return ELPP->registeredLoggers()->installLoggerRegistrationCallback(id); + } + /// @brief Uninstalls log dispatch callback + template + static inline void uninstallLoggerRegistrationCallback(const std::string& id) { + ELPP->registeredLoggers()->uninstallLoggerRegistrationCallback(id); + } + template + static inline T* loggerRegistrationCallback(const std::string& id) { + return ELPP->registeredLoggers()->loggerRegistrationCallback(id); + } + /// @brief Unregisters logger - use it only when you know what you are doing, you may unregister + /// loggers initialized / used by third-party libs. 
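+  /// Illustrative call (editor's sketch, not part of the original header; the logger
+  /// id is a hypothetical example):
+  ///   el::Loggers::unregisterLogger("my_temporary_logger");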
+ static bool unregisterLogger(const std::string& identity); + /// @brief Whether or not logger with id is registered + static bool hasLogger(const std::string& identity); + /// @brief Reconfigures specified logger with new configurations + static Logger* reconfigureLogger(Logger* logger, const Configurations& configurations); + /// @brief Reconfigures logger with new configurations after looking it up using identity + static Logger* reconfigureLogger(const std::string& identity, const Configurations& configurations); + /// @brief Reconfigures logger's single configuration + static Logger* reconfigureLogger(const std::string& identity, ConfigurationType configurationType, + const std::string& value); + /// @brief Reconfigures all the existing loggers with new configurations + static void reconfigureAllLoggers(const Configurations& configurations); + /// @brief Reconfigures single configuration for all the loggers + static inline void reconfigureAllLoggers(ConfigurationType configurationType, const std::string& value) { + reconfigureAllLoggers(Level::Global, configurationType, value); + } + /// @brief Reconfigures single configuration for all the loggers for specified level + static void reconfigureAllLoggers(Level level, ConfigurationType configurationType, + const std::string& value); + /// @brief Sets default configurations. This configuration is used for future (and conditionally for existing) loggers + static void setDefaultConfigurations(const Configurations& configurations, + bool reconfigureExistingLoggers = false); + /// @brief Returns current default + static const Configurations* defaultConfigurations(void); + /// @brief Returns log stream reference pointer if needed by user + static const base::LogStreamsReferenceMapPtr logStreamsReference(void); + /// @brief Default typed configuration based on existing defaultConf + static base::TypedConfigurations defaultTypedConfigurations(void); + /// @brief Populates all logger IDs in current repository. + /// @param [out] targetList List of fill up. + static std::vector* populateAllLoggerIds(std::vector* targetList); + /// @brief Sets configurations from global configuration file. + static void configureFromGlobal(const char* globalConfigurationFilePath); + /// @brief Configures loggers using command line arg. Ensure you have already set command line args, + /// @return False if invalid argument or argument with no value provided, true if attempted to configure logger. + /// If true is returned that does not mean it has been configured successfully, it only means that it + /// has attempted to configure logger using configuration file provided in argument + static bool configureFromArg(const char* argKey); + /// @brief Flushes all loggers for all levels - Be careful if you dont know how many loggers are registered + static void flushAll(void); + /// @brief Adds logging flag used internally. + static inline void addFlag(LoggingFlag flag) { + ELPP->addFlag(flag); + } + /// @brief Removes logging flag used internally. 
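+  /// Illustrative call (editor's sketch, not part of the original header; any
+  /// el::LoggingFlag value is handled the same way):
+  ///   el::Loggers::removeFlag(el::LoggingFlag::ColoredTerminalOutput);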
+ static inline void removeFlag(LoggingFlag flag) { + ELPP->removeFlag(flag); + } + /// @brief Determines whether or not certain flag is active + static inline bool hasFlag(LoggingFlag flag) { + return ELPP->hasFlag(flag); + } + /// @brief Adds flag and removes it when scope goes out + class ScopedAddFlag { + public: + ScopedAddFlag(LoggingFlag flag) : m_flag(flag) { + Loggers::addFlag(m_flag); + } + ~ScopedAddFlag(void) { + Loggers::removeFlag(m_flag); + } + private: + LoggingFlag m_flag; + }; + /// @brief Removes flag and add it when scope goes out + class ScopedRemoveFlag { + public: + ScopedRemoveFlag(LoggingFlag flag) : m_flag(flag) { + Loggers::removeFlag(m_flag); + } + ~ScopedRemoveFlag(void) { + Loggers::addFlag(m_flag); + } + private: + LoggingFlag m_flag; + }; + /// @brief Sets hierarchy for logging. Needs to enable logging flag (HierarchicalLogging) + static void setLoggingLevel(Level level) { + ELPP->setLoggingLevel(level); + } + /// @brief Sets verbose level on the fly + static void setVerboseLevel(base::type::VerboseLevel level); + /// @brief Gets current verbose level + static base::type::VerboseLevel verboseLevel(void); + /// @brief Sets vmodules as specified (on the fly) + static void setVModules(const char* modules); + /// @brief Clears vmodules + static void clearVModules(void); +}; +class VersionInfo : base::StaticClass { + public: + /// @brief Current version number + static const std::string version(void); + + /// @brief Release date of current version + static const std::string releaseDate(void); +}; +} // namespace el +#undef VLOG_IS_ON +/// @brief Determines whether verbose logging is on for specified level current file. +#define VLOG_IS_ON(verboseLevel) (ELPP->vRegistry()->allowed(verboseLevel, __FILE__)) +#undef TIMED_BLOCK +#undef TIMED_SCOPE +#undef TIMED_SCOPE_IF +#undef TIMED_FUNC +#undef TIMED_FUNC_IF +#undef ELPP_MIN_UNIT +#if defined(ELPP_PERFORMANCE_MICROSECONDS) +# define ELPP_MIN_UNIT el::base::TimestampUnit::Microsecond +#else +# define ELPP_MIN_UNIT el::base::TimestampUnit::Millisecond +#endif // (defined(ELPP_PERFORMANCE_MICROSECONDS)) +/// @brief Performance tracked scope. Performance gets written when goes out of scope using +/// 'performance' logger. +/// +/// @detail Please note in order to check the performance at a certain time you can use obj->checkpoint(); +/// @see el::base::PerformanceTracker +/// @see el::base::PerformanceTracker::checkpoint +// Note: Do not surround this definition with null macro because of obj instance +#define TIMED_SCOPE_IF(obj, blockname, condition) el::base::type::PerformanceTrackerPtr obj( condition ? \ + new el::base::PerformanceTracker(blockname, ELPP_MIN_UNIT) : nullptr ) +#define TIMED_SCOPE(obj, blockname) TIMED_SCOPE_IF(obj, blockname, true) +#define TIMED_BLOCK(obj, blockName) for (struct { int i; el::base::type::PerformanceTrackerPtr timer; } obj = { 0, \ + el::base::type::PerformanceTrackerPtr(new el::base::PerformanceTracker(blockName, ELPP_MIN_UNIT)) }; obj.i < 1; ++obj.i) +/// @brief Performance tracked function. Performance gets written when goes out of scope using +/// 'performance' logger. 
+/// +/// @detail Please note in order to check the performance at a certain time you can use obj->checkpoint(); +/// @see el::base::PerformanceTracker +/// @see el::base::PerformanceTracker::checkpoint +#define TIMED_FUNC_IF(obj,condition) TIMED_SCOPE_IF(obj, ELPP_FUNC, condition) +#define TIMED_FUNC(obj) TIMED_SCOPE(obj, ELPP_FUNC) +#undef PERFORMANCE_CHECKPOINT +#undef PERFORMANCE_CHECKPOINT_WITH_ID +#define PERFORMANCE_CHECKPOINT(obj) obj->checkpoint(std::string(), __FILE__, __LINE__, ELPP_FUNC) +#define PERFORMANCE_CHECKPOINT_WITH_ID(obj, id) obj->checkpoint(id, __FILE__, __LINE__, ELPP_FUNC) +#undef ELPP_COUNTER +#undef ELPP_COUNTER_POS +/// @brief Gets hit counter for file/line +#define ELPP_COUNTER (ELPP->hitCounters()->getCounter(__FILE__, __LINE__)) +/// @brief Gets hit counter position for file/line, -1 if not registered yet +#define ELPP_COUNTER_POS (ELPP_COUNTER == nullptr ? -1 : ELPP_COUNTER->hitCounts()) +// Undef levels to support LOG(LEVEL) +#undef INFO +#undef WARNING +#undef DEBUG +#undef ERROR +#undef FATAL +#undef TRACE +#undef VERBOSE +// Undef existing +#undef CINFO +#undef CWARNING +#undef CDEBUG +#undef CFATAL +#undef CERROR +#undef CTRACE +#undef CVERBOSE +#undef CINFO_IF +#undef CWARNING_IF +#undef CDEBUG_IF +#undef CERROR_IF +#undef CFATAL_IF +#undef CTRACE_IF +#undef CVERBOSE_IF +#undef CINFO_EVERY_N +#undef CWARNING_EVERY_N +#undef CDEBUG_EVERY_N +#undef CERROR_EVERY_N +#undef CFATAL_EVERY_N +#undef CTRACE_EVERY_N +#undef CVERBOSE_EVERY_N +#undef CINFO_AFTER_N +#undef CWARNING_AFTER_N +#undef CDEBUG_AFTER_N +#undef CERROR_AFTER_N +#undef CFATAL_AFTER_N +#undef CTRACE_AFTER_N +#undef CVERBOSE_AFTER_N +#undef CINFO_N_TIMES +#undef CWARNING_N_TIMES +#undef CDEBUG_N_TIMES +#undef CERROR_N_TIMES +#undef CFATAL_N_TIMES +#undef CTRACE_N_TIMES +#undef CVERBOSE_N_TIMES +// Normal logs +#if ELPP_INFO_LOG +# define CINFO(writer, dispatchAction, ...) ELPP_WRITE_LOG(writer, el::Level::Info, dispatchAction, __VA_ARGS__) +#else +# define CINFO(writer, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_INFO_LOG +#if ELPP_WARNING_LOG +# define CWARNING(writer, dispatchAction, ...) ELPP_WRITE_LOG(writer, el::Level::Warning, dispatchAction, __VA_ARGS__) +#else +# define CWARNING(writer, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_WARNING_LOG +#if ELPP_DEBUG_LOG +# define CDEBUG(writer, dispatchAction, ...) ELPP_WRITE_LOG(writer, el::Level::Debug, dispatchAction, __VA_ARGS__) +#else +# define CDEBUG(writer, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_DEBUG_LOG +#if ELPP_ERROR_LOG +# define CERROR(writer, dispatchAction, ...) ELPP_WRITE_LOG(writer, el::Level::Error, dispatchAction, __VA_ARGS__) +#else +# define CERROR(writer, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_ERROR_LOG +#if ELPP_FATAL_LOG +# define CFATAL(writer, dispatchAction, ...) ELPP_WRITE_LOG(writer, el::Level::Fatal, dispatchAction, __VA_ARGS__) +#else +# define CFATAL(writer, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_FATAL_LOG +#if ELPP_TRACE_LOG +# define CTRACE(writer, dispatchAction, ...) ELPP_WRITE_LOG(writer, el::Level::Trace, dispatchAction, __VA_ARGS__) +#else +# define CTRACE(writer, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_TRACE_LOG +#if ELPP_VERBOSE_LOG +# define CVERBOSE(writer, vlevel, dispatchAction, ...) 
if (VLOG_IS_ON(vlevel)) writer(\ +el::Level::Verbose, __FILE__, __LINE__, ELPP_FUNC, dispatchAction, vlevel).construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#else +# define CVERBOSE(writer, vlevel, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_VERBOSE_LOG +// Conditional logs +#if ELPP_INFO_LOG +# define CINFO_IF(writer, condition_, dispatchAction, ...) \ +ELPP_WRITE_LOG_IF(writer, (condition_), el::Level::Info, dispatchAction, __VA_ARGS__) +#else +# define CINFO_IF(writer, condition_, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_INFO_LOG +#if ELPP_WARNING_LOG +# define CWARNING_IF(writer, condition_, dispatchAction, ...)\ +ELPP_WRITE_LOG_IF(writer, (condition_), el::Level::Warning, dispatchAction, __VA_ARGS__) +#else +# define CWARNING_IF(writer, condition_, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_WARNING_LOG +#if ELPP_DEBUG_LOG +# define CDEBUG_IF(writer, condition_, dispatchAction, ...)\ +ELPP_WRITE_LOG_IF(writer, (condition_), el::Level::Debug, dispatchAction, __VA_ARGS__) +#else +# define CDEBUG_IF(writer, condition_, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_DEBUG_LOG +#if ELPP_ERROR_LOG +# define CERROR_IF(writer, condition_, dispatchAction, ...)\ +ELPP_WRITE_LOG_IF(writer, (condition_), el::Level::Error, dispatchAction, __VA_ARGS__) +#else +# define CERROR_IF(writer, condition_, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_ERROR_LOG +#if ELPP_FATAL_LOG +# define CFATAL_IF(writer, condition_, dispatchAction, ...)\ +ELPP_WRITE_LOG_IF(writer, (condition_), el::Level::Fatal, dispatchAction, __VA_ARGS__) +#else +# define CFATAL_IF(writer, condition_, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_FATAL_LOG +#if ELPP_TRACE_LOG +# define CTRACE_IF(writer, condition_, dispatchAction, ...)\ +ELPP_WRITE_LOG_IF(writer, (condition_), el::Level::Trace, dispatchAction, __VA_ARGS__) +#else +# define CTRACE_IF(writer, condition_, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_TRACE_LOG +#if ELPP_VERBOSE_LOG +# define CVERBOSE_IF(writer, condition_, vlevel, dispatchAction, ...) if (VLOG_IS_ON(vlevel) && (condition_)) writer( \ +el::Level::Verbose, __FILE__, __LINE__, ELPP_FUNC, dispatchAction, vlevel).construct(el_getVALength(__VA_ARGS__), __VA_ARGS__) +#else +# define CVERBOSE_IF(writer, condition_, vlevel, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_VERBOSE_LOG +// Occasional logs +#if ELPP_INFO_LOG +# define CINFO_EVERY_N(writer, occasion, dispatchAction, ...)\ +ELPP_WRITE_LOG_EVERY_N(writer, occasion, el::Level::Info, dispatchAction, __VA_ARGS__) +#else +# define CINFO_EVERY_N(writer, occasion, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_INFO_LOG +#if ELPP_WARNING_LOG +# define CWARNING_EVERY_N(writer, occasion, dispatchAction, ...)\ +ELPP_WRITE_LOG_EVERY_N(writer, occasion, el::Level::Warning, dispatchAction, __VA_ARGS__) +#else +# define CWARNING_EVERY_N(writer, occasion, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_WARNING_LOG +#if ELPP_DEBUG_LOG +# define CDEBUG_EVERY_N(writer, occasion, dispatchAction, ...)\ +ELPP_WRITE_LOG_EVERY_N(writer, occasion, el::Level::Debug, dispatchAction, __VA_ARGS__) +#else +# define CDEBUG_EVERY_N(writer, occasion, dispatchAction, ...) 
el::base::NullWriter() +#endif // ELPP_DEBUG_LOG +#if ELPP_ERROR_LOG +# define CERROR_EVERY_N(writer, occasion, dispatchAction, ...)\ +ELPP_WRITE_LOG_EVERY_N(writer, occasion, el::Level::Error, dispatchAction, __VA_ARGS__) +#else +# define CERROR_EVERY_N(writer, occasion, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_ERROR_LOG +#if ELPP_FATAL_LOG +# define CFATAL_EVERY_N(writer, occasion, dispatchAction, ...)\ +ELPP_WRITE_LOG_EVERY_N(writer, occasion, el::Level::Fatal, dispatchAction, __VA_ARGS__) +#else +# define CFATAL_EVERY_N(writer, occasion, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_FATAL_LOG +#if ELPP_TRACE_LOG +# define CTRACE_EVERY_N(writer, occasion, dispatchAction, ...)\ +ELPP_WRITE_LOG_EVERY_N(writer, occasion, el::Level::Trace, dispatchAction, __VA_ARGS__) +#else +# define CTRACE_EVERY_N(writer, occasion, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_TRACE_LOG +#if ELPP_VERBOSE_LOG +# define CVERBOSE_EVERY_N(writer, occasion, vlevel, dispatchAction, ...)\ +CVERBOSE_IF(writer, ELPP->validateEveryNCounter(__FILE__, __LINE__, occasion), vlevel, dispatchAction, __VA_ARGS__) +#else +# define CVERBOSE_EVERY_N(writer, occasion, vlevel, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_VERBOSE_LOG +// After N logs +#if ELPP_INFO_LOG +# define CINFO_AFTER_N(writer, n, dispatchAction, ...)\ +ELPP_WRITE_LOG_AFTER_N(writer, n, el::Level::Info, dispatchAction, __VA_ARGS__) +#else +# define CINFO_AFTER_N(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_INFO_LOG +#if ELPP_WARNING_LOG +# define CWARNING_AFTER_N(writer, n, dispatchAction, ...)\ +ELPP_WRITE_LOG_AFTER_N(writer, n, el::Level::Warning, dispatchAction, __VA_ARGS__) +#else +# define CWARNING_AFTER_N(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_WARNING_LOG +#if ELPP_DEBUG_LOG +# define CDEBUG_AFTER_N(writer, n, dispatchAction, ...)\ +ELPP_WRITE_LOG_AFTER_N(writer, n, el::Level::Debug, dispatchAction, __VA_ARGS__) +#else +# define CDEBUG_AFTER_N(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_DEBUG_LOG +#if ELPP_ERROR_LOG +# define CERROR_AFTER_N(writer, n, dispatchAction, ...)\ +ELPP_WRITE_LOG_AFTER_N(writer, n, el::Level::Error, dispatchAction, __VA_ARGS__) +#else +# define CERROR_AFTER_N(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_ERROR_LOG +#if ELPP_FATAL_LOG +# define CFATAL_AFTER_N(writer, n, dispatchAction, ...)\ +ELPP_WRITE_LOG_AFTER_N(writer, n, el::Level::Fatal, dispatchAction, __VA_ARGS__) +#else +# define CFATAL_AFTER_N(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_FATAL_LOG +#if ELPP_TRACE_LOG +# define CTRACE_AFTER_N(writer, n, dispatchAction, ...)\ +ELPP_WRITE_LOG_AFTER_N(writer, n, el::Level::Trace, dispatchAction, __VA_ARGS__) +#else +# define CTRACE_AFTER_N(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_TRACE_LOG +#if ELPP_VERBOSE_LOG +# define CVERBOSE_AFTER_N(writer, n, vlevel, dispatchAction, ...)\ +CVERBOSE_IF(writer, ELPP->validateAfterNCounter(__FILE__, __LINE__, n), vlevel, dispatchAction, __VA_ARGS__) +#else +# define CVERBOSE_AFTER_N(writer, n, vlevel, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_VERBOSE_LOG +// N Times logs +#if ELPP_INFO_LOG +# define CINFO_N_TIMES(writer, n, dispatchAction, ...)\ +ELPP_WRITE_LOG_N_TIMES(writer, n, el::Level::Info, dispatchAction, __VA_ARGS__) +#else +# define CINFO_N_TIMES(writer, n, dispatchAction, ...) 
el::base::NullWriter() +#endif // ELPP_INFO_LOG +#if ELPP_WARNING_LOG +# define CWARNING_N_TIMES(writer, n, dispatchAction, ...)\ +ELPP_WRITE_LOG_N_TIMES(writer, n, el::Level::Warning, dispatchAction, __VA_ARGS__) +#else +# define CWARNING_N_TIMES(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_WARNING_LOG +#if ELPP_DEBUG_LOG +# define CDEBUG_N_TIMES(writer, n, dispatchAction, ...)\ +ELPP_WRITE_LOG_N_TIMES(writer, n, el::Level::Debug, dispatchAction, __VA_ARGS__) +#else +# define CDEBUG_N_TIMES(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_DEBUG_LOG +#if ELPP_ERROR_LOG +# define CERROR_N_TIMES(writer, n, dispatchAction, ...)\ +ELPP_WRITE_LOG_N_TIMES(writer, n, el::Level::Error, dispatchAction, __VA_ARGS__) +#else +# define CERROR_N_TIMES(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_ERROR_LOG +#if ELPP_FATAL_LOG +# define CFATAL_N_TIMES(writer, n, dispatchAction, ...)\ +ELPP_WRITE_LOG_N_TIMES(writer, n, el::Level::Fatal, dispatchAction, __VA_ARGS__) +#else +# define CFATAL_N_TIMES(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_FATAL_LOG +#if ELPP_TRACE_LOG +# define CTRACE_N_TIMES(writer, n, dispatchAction, ...)\ +ELPP_WRITE_LOG_N_TIMES(writer, n, el::Level::Trace, dispatchAction, __VA_ARGS__) +#else +# define CTRACE_N_TIMES(writer, n, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_TRACE_LOG +#if ELPP_VERBOSE_LOG +# define CVERBOSE_N_TIMES(writer, n, vlevel, dispatchAction, ...)\ +CVERBOSE_IF(writer, ELPP->validateNTimesCounter(__FILE__, __LINE__, n), vlevel, dispatchAction, __VA_ARGS__) +#else +# define CVERBOSE_N_TIMES(writer, n, vlevel, dispatchAction, ...) el::base::NullWriter() +#endif // ELPP_VERBOSE_LOG +// +// Custom Loggers - Requires (level, dispatchAction, loggerId/s) +// +// undef existing +#undef CLOG +#undef CLOG_VERBOSE +#undef CVLOG +#undef CLOG_IF +#undef CLOG_VERBOSE_IF +#undef CVLOG_IF +#undef CLOG_EVERY_N +#undef CVLOG_EVERY_N +#undef CLOG_AFTER_N +#undef CVLOG_AFTER_N +#undef CLOG_N_TIMES +#undef CVLOG_N_TIMES +// Normal logs +#define CLOG(LEVEL, ...)\ +C##LEVEL(el::base::Writer, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CVLOG(vlevel, ...) 
CVERBOSE(el::base::Writer, vlevel, el::base::DispatchAction::NormalLog, __VA_ARGS__) +// Conditional logs +#define CLOG_IF(condition, LEVEL, ...)\ +C##LEVEL##_IF(el::base::Writer, condition, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CVLOG_IF(condition, vlevel, ...)\ +CVERBOSE_IF(el::base::Writer, condition, vlevel, el::base::DispatchAction::NormalLog, __VA_ARGS__) +// Hit counts based logs +#define CLOG_EVERY_N(n, LEVEL, ...)\ +C##LEVEL##_EVERY_N(el::base::Writer, n, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CVLOG_EVERY_N(n, vlevel, ...)\ +CVERBOSE_EVERY_N(el::base::Writer, n, vlevel, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CLOG_AFTER_N(n, LEVEL, ...)\ +C##LEVEL##_AFTER_N(el::base::Writer, n, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CVLOG_AFTER_N(n, vlevel, ...)\ +CVERBOSE_AFTER_N(el::base::Writer, n, vlevel, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CLOG_N_TIMES(n, LEVEL, ...)\ +C##LEVEL##_N_TIMES(el::base::Writer, n, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CVLOG_N_TIMES(n, vlevel, ...)\ +CVERBOSE_N_TIMES(el::base::Writer, n, vlevel, el::base::DispatchAction::NormalLog, __VA_ARGS__) +// +// Default Loggers macro using CLOG(), CLOG_VERBOSE() and CVLOG() macros +// +// undef existing +#undef LOG +#undef VLOG +#undef LOG_IF +#undef VLOG_IF +#undef LOG_EVERY_N +#undef VLOG_EVERY_N +#undef LOG_AFTER_N +#undef VLOG_AFTER_N +#undef LOG_N_TIMES +#undef VLOG_N_TIMES +#undef ELPP_CURR_FILE_LOGGER_ID +#if defined(ELPP_DEFAULT_LOGGER) +# define ELPP_CURR_FILE_LOGGER_ID ELPP_DEFAULT_LOGGER +#else +# define ELPP_CURR_FILE_LOGGER_ID el::base::consts::kDefaultLoggerId +#endif +#undef ELPP_TRACE +#define ELPP_TRACE CLOG(TRACE, ELPP_CURR_FILE_LOGGER_ID) +// Normal logs +#define LOG(LEVEL) CLOG(LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define VLOG(vlevel) CVLOG(vlevel, ELPP_CURR_FILE_LOGGER_ID) +// Conditional logs +#define LOG_IF(condition, LEVEL) CLOG_IF(condition, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define VLOG_IF(condition, vlevel) CVLOG_IF(condition, vlevel, ELPP_CURR_FILE_LOGGER_ID) +// Hit counts based logs +#define LOG_EVERY_N(n, LEVEL) CLOG_EVERY_N(n, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define VLOG_EVERY_N(n, vlevel) CVLOG_EVERY_N(n, vlevel, ELPP_CURR_FILE_LOGGER_ID) +#define LOG_AFTER_N(n, LEVEL) CLOG_AFTER_N(n, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define VLOG_AFTER_N(n, vlevel) CVLOG_AFTER_N(n, vlevel, ELPP_CURR_FILE_LOGGER_ID) +#define LOG_N_TIMES(n, LEVEL) CLOG_N_TIMES(n, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define VLOG_N_TIMES(n, vlevel) CVLOG_N_TIMES(n, vlevel, ELPP_CURR_FILE_LOGGER_ID) +// Generic PLOG() +#undef CPLOG +#undef CPLOG_IF +#undef PLOG +#undef PLOG_IF +#undef DCPLOG +#undef DCPLOG_IF +#undef DPLOG +#undef DPLOG_IF +#define CPLOG(LEVEL, ...)\ +C##LEVEL(el::base::PErrorWriter, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define CPLOG_IF(condition, LEVEL, ...)\ +C##LEVEL##_IF(el::base::PErrorWriter, condition, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define DCPLOG(LEVEL, ...)\ +if (ELPP_DEBUG_LOG) C##LEVEL(el::base::PErrorWriter, el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define DCPLOG_IF(condition, LEVEL, ...)\ +C##LEVEL##_IF(el::base::PErrorWriter, (ELPP_DEBUG_LOG) && (condition), el::base::DispatchAction::NormalLog, __VA_ARGS__) +#define PLOG(LEVEL) CPLOG(LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define PLOG_IF(condition, LEVEL) CPLOG_IF(condition, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DPLOG(LEVEL) DCPLOG(LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define 
DPLOG_IF(condition, LEVEL) DCPLOG_IF(condition, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +// Generic SYSLOG() +#undef CSYSLOG +#undef CSYSLOG_IF +#undef CSYSLOG_EVERY_N +#undef CSYSLOG_AFTER_N +#undef CSYSLOG_N_TIMES +#undef SYSLOG +#undef SYSLOG_IF +#undef SYSLOG_EVERY_N +#undef SYSLOG_AFTER_N +#undef SYSLOG_N_TIMES +#undef DCSYSLOG +#undef DCSYSLOG_IF +#undef DCSYSLOG_EVERY_N +#undef DCSYSLOG_AFTER_N +#undef DCSYSLOG_N_TIMES +#undef DSYSLOG +#undef DSYSLOG_IF +#undef DSYSLOG_EVERY_N +#undef DSYSLOG_AFTER_N +#undef DSYSLOG_N_TIMES +#if defined(ELPP_SYSLOG) +# define CSYSLOG(LEVEL, ...)\ +C##LEVEL(el::base::Writer, el::base::DispatchAction::SysLog, __VA_ARGS__) +# define CSYSLOG_IF(condition, LEVEL, ...)\ +C##LEVEL##_IF(el::base::Writer, condition, el::base::DispatchAction::SysLog, __VA_ARGS__) +# define CSYSLOG_EVERY_N(n, LEVEL, ...) C##LEVEL##_EVERY_N(el::base::Writer, n, el::base::DispatchAction::SysLog, __VA_ARGS__) +# define CSYSLOG_AFTER_N(n, LEVEL, ...) C##LEVEL##_AFTER_N(el::base::Writer, n, el::base::DispatchAction::SysLog, __VA_ARGS__) +# define CSYSLOG_N_TIMES(n, LEVEL, ...) C##LEVEL##_N_TIMES(el::base::Writer, n, el::base::DispatchAction::SysLog, __VA_ARGS__) +# define SYSLOG(LEVEL) CSYSLOG(LEVEL, el::base::consts::kSysLogLoggerId) +# define SYSLOG_IF(condition, LEVEL) CSYSLOG_IF(condition, LEVEL, el::base::consts::kSysLogLoggerId) +# define SYSLOG_EVERY_N(n, LEVEL) CSYSLOG_EVERY_N(n, LEVEL, el::base::consts::kSysLogLoggerId) +# define SYSLOG_AFTER_N(n, LEVEL) CSYSLOG_AFTER_N(n, LEVEL, el::base::consts::kSysLogLoggerId) +# define SYSLOG_N_TIMES(n, LEVEL) CSYSLOG_N_TIMES(n, LEVEL, el::base::consts::kSysLogLoggerId) +# define DCSYSLOG(LEVEL, ...) if (ELPP_DEBUG_LOG) C##LEVEL(el::base::Writer, el::base::DispatchAction::SysLog, __VA_ARGS__) +# define DCSYSLOG_IF(condition, LEVEL, ...)\ +C##LEVEL##_IF(el::base::Writer, (ELPP_DEBUG_LOG) && (condition), el::base::DispatchAction::SysLog, __VA_ARGS__) +# define DCSYSLOG_EVERY_N(n, LEVEL, ...)\ +if (ELPP_DEBUG_LOG) C##LEVEL##_EVERY_N(el::base::Writer, n, el::base::DispatchAction::SysLog, __VA_ARGS__) +# define DCSYSLOG_AFTER_N(n, LEVEL, ...)\ +if (ELPP_DEBUG_LOG) C##LEVEL##_AFTER_N(el::base::Writer, n, el::base::DispatchAction::SysLog, __VA_ARGS__) +# define DCSYSLOG_N_TIMES(n, LEVEL, ...)\ +if (ELPP_DEBUG_LOG) C##LEVEL##_EVERY_N(el::base::Writer, n, el::base::DispatchAction::SysLog, __VA_ARGS__) +# define DSYSLOG(LEVEL) DCSYSLOG(LEVEL, el::base::consts::kSysLogLoggerId) +# define DSYSLOG_IF(condition, LEVEL) DCSYSLOG_IF(condition, LEVEL, el::base::consts::kSysLogLoggerId) +# define DSYSLOG_EVERY_N(n, LEVEL) DCSYSLOG_EVERY_N(n, LEVEL, el::base::consts::kSysLogLoggerId) +# define DSYSLOG_AFTER_N(n, LEVEL) DCSYSLOG_AFTER_N(n, LEVEL, el::base::consts::kSysLogLoggerId) +# define DSYSLOG_N_TIMES(n, LEVEL) DCSYSLOG_N_TIMES(n, LEVEL, el::base::consts::kSysLogLoggerId) +#else +# define CSYSLOG(LEVEL, ...) el::base::NullWriter() +# define CSYSLOG_IF(condition, LEVEL, ...) el::base::NullWriter() +# define CSYSLOG_EVERY_N(n, LEVEL, ...) el::base::NullWriter() +# define CSYSLOG_AFTER_N(n, LEVEL, ...) el::base::NullWriter() +# define CSYSLOG_N_TIMES(n, LEVEL, ...) el::base::NullWriter() +# define SYSLOG(LEVEL) el::base::NullWriter() +# define SYSLOG_IF(condition, LEVEL) el::base::NullWriter() +# define SYSLOG_EVERY_N(n, LEVEL) el::base::NullWriter() +# define SYSLOG_AFTER_N(n, LEVEL) el::base::NullWriter() +# define SYSLOG_N_TIMES(n, LEVEL) el::base::NullWriter() +# define DCSYSLOG(LEVEL, ...) 
el::base::NullWriter() +# define DCSYSLOG_IF(condition, LEVEL, ...) el::base::NullWriter() +# define DCSYSLOG_EVERY_N(n, LEVEL, ...) el::base::NullWriter() +# define DCSYSLOG_AFTER_N(n, LEVEL, ...) el::base::NullWriter() +# define DCSYSLOG_N_TIMES(n, LEVEL, ...) el::base::NullWriter() +# define DSYSLOG(LEVEL) el::base::NullWriter() +# define DSYSLOG_IF(condition, LEVEL) el::base::NullWriter() +# define DSYSLOG_EVERY_N(n, LEVEL) el::base::NullWriter() +# define DSYSLOG_AFTER_N(n, LEVEL) el::base::NullWriter() +# define DSYSLOG_N_TIMES(n, LEVEL) el::base::NullWriter() +#endif // defined(ELPP_SYSLOG) +// +// Custom Debug Only Loggers - Requires (level, loggerId/s) +// +// undef existing +#undef DCLOG +#undef DCVLOG +#undef DCLOG_IF +#undef DCVLOG_IF +#undef DCLOG_EVERY_N +#undef DCVLOG_EVERY_N +#undef DCLOG_AFTER_N +#undef DCVLOG_AFTER_N +#undef DCLOG_N_TIMES +#undef DCVLOG_N_TIMES +// Normal logs +#define DCLOG(LEVEL, ...) if (ELPP_DEBUG_LOG) CLOG(LEVEL, __VA_ARGS__) +#define DCLOG_VERBOSE(vlevel, ...) if (ELPP_DEBUG_LOG) CLOG_VERBOSE(vlevel, __VA_ARGS__) +#define DCVLOG(vlevel, ...) if (ELPP_DEBUG_LOG) CVLOG(vlevel, __VA_ARGS__) +// Conditional logs +#define DCLOG_IF(condition, LEVEL, ...) if (ELPP_DEBUG_LOG) CLOG_IF(condition, LEVEL, __VA_ARGS__) +#define DCVLOG_IF(condition, vlevel, ...) if (ELPP_DEBUG_LOG) CVLOG_IF(condition, vlevel, __VA_ARGS__) +// Hit counts based logs +#define DCLOG_EVERY_N(n, LEVEL, ...) if (ELPP_DEBUG_LOG) CLOG_EVERY_N(n, LEVEL, __VA_ARGS__) +#define DCVLOG_EVERY_N(n, vlevel, ...) if (ELPP_DEBUG_LOG) CVLOG_EVERY_N(n, vlevel, __VA_ARGS__) +#define DCLOG_AFTER_N(n, LEVEL, ...) if (ELPP_DEBUG_LOG) CLOG_AFTER_N(n, LEVEL, __VA_ARGS__) +#define DCVLOG_AFTER_N(n, vlevel, ...) if (ELPP_DEBUG_LOG) CVLOG_AFTER_N(n, vlevel, __VA_ARGS__) +#define DCLOG_N_TIMES(n, LEVEL, ...) if (ELPP_DEBUG_LOG) CLOG_N_TIMES(n, LEVEL, __VA_ARGS__) +#define DCVLOG_N_TIMES(n, vlevel, ...) 
if (ELPP_DEBUG_LOG) CVLOG_N_TIMES(n, vlevel, __VA_ARGS__) +// +// Default Debug Only Loggers macro using CLOG(), CLOG_VERBOSE() and CVLOG() macros +// +#if !defined(ELPP_NO_DEBUG_MACROS) +// undef existing +#undef DLOG +#undef DVLOG +#undef DLOG_IF +#undef DVLOG_IF +#undef DLOG_EVERY_N +#undef DVLOG_EVERY_N +#undef DLOG_AFTER_N +#undef DVLOG_AFTER_N +#undef DLOG_N_TIMES +#undef DVLOG_N_TIMES +// Normal logs +#define DLOG(LEVEL) DCLOG(LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DVLOG(vlevel) DCVLOG(vlevel, ELPP_CURR_FILE_LOGGER_ID) +// Conditional logs +#define DLOG_IF(condition, LEVEL) DCLOG_IF(condition, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DVLOG_IF(condition, vlevel) DCVLOG_IF(condition, vlevel, ELPP_CURR_FILE_LOGGER_ID) +// Hit counts based logs +#define DLOG_EVERY_N(n, LEVEL) DCLOG_EVERY_N(n, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DVLOG_EVERY_N(n, vlevel) DCVLOG_EVERY_N(n, vlevel, ELPP_CURR_FILE_LOGGER_ID) +#define DLOG_AFTER_N(n, LEVEL) DCLOG_AFTER_N(n, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DVLOG_AFTER_N(n, vlevel) DCVLOG_AFTER_N(n, vlevel, ELPP_CURR_FILE_LOGGER_ID) +#define DLOG_N_TIMES(n, LEVEL) DCLOG_N_TIMES(n, LEVEL, ELPP_CURR_FILE_LOGGER_ID) +#define DVLOG_N_TIMES(n, vlevel) DCVLOG_N_TIMES(n, vlevel, ELPP_CURR_FILE_LOGGER_ID) +#endif // defined(ELPP_NO_DEBUG_MACROS) +#if !defined(ELPP_NO_CHECK_MACROS) +// Check macros +#undef CCHECK +#undef CPCHECK +#undef CCHECK_EQ +#undef CCHECK_NE +#undef CCHECK_LT +#undef CCHECK_GT +#undef CCHECK_LE +#undef CCHECK_GE +#undef CCHECK_BOUNDS +#undef CCHECK_NOTNULL +#undef CCHECK_STRCASEEQ +#undef CCHECK_STRCASENE +#undef CHECK +#undef PCHECK +#undef CHECK_EQ +#undef CHECK_NE +#undef CHECK_LT +#undef CHECK_GT +#undef CHECK_LE +#undef CHECK_GE +#undef CHECK_BOUNDS +#undef CHECK_NOTNULL +#undef CHECK_STRCASEEQ +#undef CHECK_STRCASENE +#define CCHECK(condition, ...) CLOG_IF(!(condition), FATAL, __VA_ARGS__) << "Check failed: [" << #condition << "] " +#define CPCHECK(condition, ...) CPLOG_IF(!(condition), FATAL, __VA_ARGS__) << "Check failed: [" << #condition << "] " +#define CHECK(condition) CCHECK(condition, ELPP_CURR_FILE_LOGGER_ID) +#define PCHECK(condition) CPCHECK(condition, ELPP_CURR_FILE_LOGGER_ID) +#define CCHECK_EQ(a, b, ...) CCHECK(a == b, __VA_ARGS__) +#define CCHECK_NE(a, b, ...) CCHECK(a != b, __VA_ARGS__) +#define CCHECK_LT(a, b, ...) CCHECK(a < b, __VA_ARGS__) +#define CCHECK_GT(a, b, ...) CCHECK(a > b, __VA_ARGS__) +#define CCHECK_LE(a, b, ...) CCHECK(a <= b, __VA_ARGS__) +#define CCHECK_GE(a, b, ...) CCHECK(a >= b, __VA_ARGS__) +#define CCHECK_BOUNDS(val, min, max, ...) CCHECK(val >= min && val <= max, __VA_ARGS__) +#define CHECK_EQ(a, b) CCHECK_EQ(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_NE(a, b) CCHECK_NE(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_LT(a, b) CCHECK_LT(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_GT(a, b) CCHECK_GT(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_LE(a, b) CCHECK_LE(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_GE(a, b) CCHECK_GE(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_BOUNDS(val, min, max) CCHECK_BOUNDS(val, min, max, ELPP_CURR_FILE_LOGGER_ID) +#define CCHECK_NOTNULL(ptr, ...) CCHECK((ptr) != nullptr, __VA_ARGS__) +#define CCHECK_STREQ(str1, str2, ...) CLOG_IF(!el::base::utils::Str::cStringEq(str1, str2), FATAL, __VA_ARGS__) \ +<< "Check failed: [" << #str1 << " == " << #str2 << "] " +#define CCHECK_STRNE(str1, str2, ...) 
CLOG_IF(el::base::utils::Str::cStringEq(str1, str2), FATAL, __VA_ARGS__) \ +<< "Check failed: [" << #str1 << " != " << #str2 << "] " +#define CCHECK_STRCASEEQ(str1, str2, ...) CLOG_IF(!el::base::utils::Str::cStringCaseEq(str1, str2), FATAL, __VA_ARGS__) \ +<< "Check failed: [" << #str1 << " == " << #str2 << "] " +#define CCHECK_STRCASENE(str1, str2, ...) CLOG_IF(el::base::utils::Str::cStringCaseEq(str1, str2), FATAL, __VA_ARGS__) \ +<< "Check failed: [" << #str1 << " != " << #str2 << "] " +#define CHECK_NOTNULL(ptr) CCHECK_NOTNULL((ptr), ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_STREQ(str1, str2) CCHECK_STREQ(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_STRNE(str1, str2) CCHECK_STRNE(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_STRCASEEQ(str1, str2) CCHECK_STRCASEEQ(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define CHECK_STRCASENE(str1, str2) CCHECK_STRCASENE(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#undef DCCHECK +#undef DCCHECK_EQ +#undef DCCHECK_NE +#undef DCCHECK_LT +#undef DCCHECK_GT +#undef DCCHECK_LE +#undef DCCHECK_GE +#undef DCCHECK_BOUNDS +#undef DCCHECK_NOTNULL +#undef DCCHECK_STRCASEEQ +#undef DCCHECK_STRCASENE +#undef DCPCHECK +#undef DCHECK +#undef DCHECK_EQ +#undef DCHECK_NE +#undef DCHECK_LT +#undef DCHECK_GT +#undef DCHECK_LE +#undef DCHECK_GE +#undef DCHECK_BOUNDS_ +#undef DCHECK_NOTNULL +#undef DCHECK_STRCASEEQ +#undef DCHECK_STRCASENE +#undef DPCHECK +#define DCCHECK(condition, ...) if (ELPP_DEBUG_LOG) CCHECK(condition, __VA_ARGS__) +#define DCCHECK_EQ(a, b, ...) if (ELPP_DEBUG_LOG) CCHECK_EQ(a, b, __VA_ARGS__) +#define DCCHECK_NE(a, b, ...) if (ELPP_DEBUG_LOG) CCHECK_NE(a, b, __VA_ARGS__) +#define DCCHECK_LT(a, b, ...) if (ELPP_DEBUG_LOG) CCHECK_LT(a, b, __VA_ARGS__) +#define DCCHECK_GT(a, b, ...) if (ELPP_DEBUG_LOG) CCHECK_GT(a, b, __VA_ARGS__) +#define DCCHECK_LE(a, b, ...) if (ELPP_DEBUG_LOG) CCHECK_LE(a, b, __VA_ARGS__) +#define DCCHECK_GE(a, b, ...) if (ELPP_DEBUG_LOG) CCHECK_GE(a, b, __VA_ARGS__) +#define DCCHECK_BOUNDS(val, min, max, ...) if (ELPP_DEBUG_LOG) CCHECK_BOUNDS(val, min, max, __VA_ARGS__) +#define DCCHECK_NOTNULL(ptr, ...) if (ELPP_DEBUG_LOG) CCHECK_NOTNULL((ptr), __VA_ARGS__) +#define DCCHECK_STREQ(str1, str2, ...) if (ELPP_DEBUG_LOG) CCHECK_STREQ(str1, str2, __VA_ARGS__) +#define DCCHECK_STRNE(str1, str2, ...) if (ELPP_DEBUG_LOG) CCHECK_STRNE(str1, str2, __VA_ARGS__) +#define DCCHECK_STRCASEEQ(str1, str2, ...) if (ELPP_DEBUG_LOG) CCHECK_STRCASEEQ(str1, str2, __VA_ARGS__) +#define DCCHECK_STRCASENE(str1, str2, ...) if (ELPP_DEBUG_LOG) CCHECK_STRCASENE(str1, str2, __VA_ARGS__) +#define DCPCHECK(condition, ...) 
if (ELPP_DEBUG_LOG) CPCHECK(condition, __VA_ARGS__) +#define DCHECK(condition) DCCHECK(condition, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_EQ(a, b) DCCHECK_EQ(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_NE(a, b) DCCHECK_NE(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_LT(a, b) DCCHECK_LT(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_GT(a, b) DCCHECK_GT(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_LE(a, b) DCCHECK_LE(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_GE(a, b) DCCHECK_GE(a, b, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_BOUNDS(val, min, max) DCCHECK_BOUNDS(val, min, max, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_NOTNULL(ptr) DCCHECK_NOTNULL((ptr), ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_STREQ(str1, str2) DCCHECK_STREQ(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_STRNE(str1, str2) DCCHECK_STRNE(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_STRCASEEQ(str1, str2) DCCHECK_STRCASEEQ(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define DCHECK_STRCASENE(str1, str2) DCCHECK_STRCASENE(str1, str2, ELPP_CURR_FILE_LOGGER_ID) +#define DPCHECK(condition) DCPCHECK(condition, ELPP_CURR_FILE_LOGGER_ID) +#endif // defined(ELPP_NO_CHECK_MACROS) +#if defined(ELPP_DISABLE_DEFAULT_CRASH_HANDLING) +# define ELPP_USE_DEF_CRASH_HANDLER false +#else +# define ELPP_USE_DEF_CRASH_HANDLER true +#endif // defined(ELPP_DISABLE_DEFAULT_CRASH_HANDLING) +#define ELPP_CRASH_HANDLER_INIT +#define ELPP_INIT_EASYLOGGINGPP(val) \ +namespace el { \ +namespace base { \ +el::base::type::StoragePointer elStorage(val); \ +} \ +el::base::debug::CrashHandler elCrashHandler(ELPP_USE_DEF_CRASH_HANDLER); \ +} + +#if ELPP_ASYNC_LOGGING +# define INITIALIZE_EASYLOGGINGPP ELPP_INIT_EASYLOGGINGPP(new el::base::Storage(el::LogBuilderPtr(new el::base::DefaultLogBuilder()),\ +new el::base::AsyncDispatchWorker())) +#else +# define INITIALIZE_EASYLOGGINGPP ELPP_INIT_EASYLOGGINGPP(new el::base::Storage(el::LogBuilderPtr(new el::base::DefaultLogBuilder()))) +#endif // ELPP_ASYNC_LOGGING +#define INITIALIZE_NULL_EASYLOGGINGPP \ +namespace el {\ +namespace base {\ +el::base::type::StoragePointer elStorage;\ +}\ +el::base::debug::CrashHandler elCrashHandler(ELPP_USE_DEF_CRASH_HANDLER);\ +} +#define SHARE_EASYLOGGINGPP(initializedStorage)\ +namespace el {\ +namespace base {\ +el::base::type::StoragePointer elStorage(initializedStorage);\ +}\ +el::base::debug::CrashHandler elCrashHandler(ELPP_USE_DEF_CRASH_HANDLER);\ +} + +#if defined(ELPP_UNICODE) +# define START_EASYLOGGINGPP(argc, argv) el::Helpers::setArgs(argc, argv); std::locale::global(std::locale("")) +#else +# define START_EASYLOGGINGPP(argc, argv) el::Helpers::setArgs(argc, argv) +#endif // defined(ELPP_UNICODE) +#endif // EASYLOGGINGPP_H diff --git a/hexl_v2/include/hexl/eltwise/eltwise-add-mod.hpp b/hexl_v2/include/hexl/eltwise/eltwise-add-mod.hpp new file mode 100644 index 00000000..cb2df110 --- /dev/null +++ b/hexl_v2/include/hexl/eltwise/eltwise-add-mod.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Adds two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] operand2 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. 
Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus); + +/// @brief Adds a vector and scalar elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] operand2 Scalar to add. Must be less +/// than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/eltwise/eltwise-cmp-add.hpp b/hexl_v2/include/hexl/eltwise/eltwise-cmp-add.hpp new file mode 100644 index 00000000..27e514ff --- /dev/null +++ b/hexl_v2/include/hexl/eltwise/eltwise-cmp-add.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional addition. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare; stores result +/// @param[in] n Number of elements in \p operand1 +/// @param[in] cmp Comparison operation +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to conditionally add +/// @details Computes result[i] = cmp(operand1[i], bound) ? operand1[i] + +/// diff : operand1[i] for all \f$i=0, ..., n-1\f$. +void EltwiseCmpAdd(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp b/hexl_v2/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp new file mode 100644 index 00000000..07ba3d23 --- /dev/null +++ b/hexl_v2/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional modular subtraction. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] modulus Modulus to reduce by +/// @param[in] cmp Comparison function +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to subtract by +/// @details Computes \p operand1[i] = (\p cmp(\p operand1, \p bound)) ? 
(\p +/// operand1 - \p diff) mod \p modulus : \p operand1 mod \p modulus for all i=0, +/// ..., n-1 +void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, uint64_t bound, + uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/eltwise/eltwise-fma-mod.hpp b/hexl_v2/include/hexl/eltwise/eltwise-fma-mod.hpp new file mode 100644 index 00000000..03651a42 --- /dev/null +++ b/hexl_v2/include/hexl/eltwise/eltwise-fma-mod.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes fused multiply-add (\p arg1 * \p arg2 + \p arg3) mod \p +/// modulus element-wise, broadcasting scalars to vectors. +/// @param[out] result Stores the result +/// @param[in] arg1 Vector to multiply +/// @param[in] arg2 Scalar to multiply +/// @param[in] arg3 Vector to add. Will not add if \p arg3 == nullptr +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$ [2, 2^{61} - 1]\f$ +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * modulus). Must be 1, 2, 4, or 8. +void EltwiseFMAMod(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/eltwise/eltwise-mult-mod.hpp b/hexl_v2/include/hexl/eltwise/eltwise-mult-mod.hpp new file mode 100644 index 00000000..e4d2dbd7 --- /dev/null +++ b/hexl_v2/include/hexl/eltwise/eltwise-mult-mod.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. 
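+///
+/// Illustrative call (a minimal sketch added for clarity; the modulus and data are
+/// arbitrary example values, and the inputs are already fully reduced so
+/// input_mod_factor = 1):
+///   std::vector<uint64_t> op1{1, 2, 3, 4};
+///   std::vector<uint64_t> op2{5, 6, 7, 8};
+///   std::vector<uint64_t> out(op1.size());
+///   EltwiseMultMod(out.data(), op1.data(), op2.data(), op1.size(), 769, 1);
+///   // out = {5, 12, 21, 32}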
+/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +void EltwiseMultMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/eltwise/eltwise-reduce-mod.hpp b/hexl_v2/include/hexl/eltwise/eltwise-reduce-mod.hpp new file mode 100644 index 00000000..c23abde2 --- /dev/null +++ b/hexl_v2/include/hexl/eltwise/eltwise-reduce-mod.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Performs elementwise modular reduction +/// @param[out] result Stores the result +/// @param[in] operand Data on which to compute the elementwise modular +/// reduction +/// @param[in] n Number of elements in operand +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be modulus, 1, 2 or 4. input_mod_factor=modulus +/// means, input range is [0, p * p]. Barrett reduction will be used in this +/// case. input_mod_factor > output_mod_factor +/// @param[in] output_mod_factor output elements will be in [0, +/// output_mod_factor * modulus) Must be 1 or 2. For input_mod_factor=0, +/// output_mod_factor will be set to 1. +void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, + uint64_t modulus, uint64_t input_mod_factor, + uint64_t output_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/eltwise/eltwise-sub-mod.hpp b/hexl_v2/include/hexl/eltwise/eltwise-sub-mod.hpp new file mode 100644 index 00000000..bd286e47 --- /dev/null +++ b/hexl_v2/include/hexl/eltwise/eltwise-sub-mod.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Subtracts two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Vector of elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus); + +/// @brief Subtracts a scalar from a vector elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. 
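+///
+/// Illustrative call (a minimal sketch added for clarity; the modulus and data are
+/// arbitrary example values):
+///   std::vector<uint64_t> op1{1, 2, 3, 4};
+///   std::vector<uint64_t> out(op1.size());
+///   EltwiseSubMod(out.data(), op1.data(), /*operand2=*/3, op1.size(), /*modulus=*/769);
+///   // out = {767, 768, 0, 1}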
+void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp b/hexl_v2/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp new file mode 100644 index 00000000..28a2dddf --- /dev/null +++ b/hexl_v2/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp @@ -0,0 +1,402 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +// ************************************ T1 ************************************ + +// ComplexLoadFwdInterleavedT1: +// Assumes ComplexLoadFwdInterleavedT2 was used before. +// Given input: 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0 +// Returns +// *out1 = (14, 12, 10, 8, 6, 4, 2, 0); +// *out2 = (15, 13, 11, 9, 7, 5, 3, 1); +// +// Given output: 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0 +inline void ComplexLoadFwdInterleavedT1(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512i vperm_idx = _mm512_set_epi64(6, 7, 4, 5, 2, 3, 0, 1); + + const __m512d* arg_512 = reinterpret_cast(arg); + + // 13 12 9 8 5 4 1 0 + __m512d v_7to0 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15 14 11 10 7 6 3 2 + __m512d v_15to8 = _mm512_loadu_pd(arg_512); + + // 12, 13, 8, 9, 4, 5, 0, 1 + __m512d perm_1 = _mm512_permutexvar_pd(vperm_idx, v_7to0); + // 14, 15, 10, 11, 6, 7, 2, 3 + __m512d perm_2 = _mm512_permutexvar_pd(vperm_idx, v_15to8); + + // 14, 12, 10, 8, 6, 4, 2, 0 + *out1 = _mm512_mask_blend_pd(0xaa, v_7to0, perm_2); + // 15, 13, 11, 9, 7, 5, 3, 1 + *out2 = _mm512_mask_blend_pd(0x55, v_15to8, perm_1); +} + +// ComplexWriteFwdInterleavedT1: +// Assumes ComplexLoadFwdInterleavedT1 was used before. 
+// Given inputs: +// 15i, 13i, 11i, 9i, 7i, 5i, 3i, 1i, 15r, 13r, 11r, 9r, 7r, 5r, 3r, 1r, +// 14i, 12i, 10i, 8i, 6i, 4i, 2i, 0i, 14r, 12r, 10r, 8r, 6r, 4r, 2r, 0r +// As seen with internal indexes: +// @param arg_yr = (15r, 14r, 13r, 12r, 11r, 10r, 9r, 8r); +// @param arg_xr = ( 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r); +// @param arg_yi = (15i, 14i, 13i, 12i, 11i, 10i, 9i, 8i); +// @param arg_xi = ( 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i); +// Writes out = +// {15i, 15r, 7i, 7r, 14i, 14r, 6i, 6r, 13i, 13r, 5i, 5r, 12i, 12r, 4i, 4r, +// 11i, 11r, 3i, 3r, 10i, 10r, 2i, 2r, 9i, 9r, 1i, 1r, 8i, 8r, 0i, 0r} +// +// Given output: +// 15i, 15r, 14i, 14r, 13i, 13r, 12i, 12r, 11i, 11r, 10i, 10r, 9i, 9r, 8i, 8r, +// 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +inline void ComplexWriteFwdInterleavedT1(__m512d arg_xr, __m512d arg_yr, + __m512d arg_xi, __m512d arg_yi, + __m512d* out) { + const __m512i vperm_4hi_4lo_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i v_X_out_idx = _mm512_set_epi64(3, 1, 7, 5, 2, 0, 6, 4); + const __m512i v_Y_out_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + // Real part + // in: 14r, 12r, 10r, 8r, 6r, 4r, 2r, 0r + // -> 6r, 4r, 2r, 0r, 14r, 12r, 10r, 8r + arg_xr = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, arg_xr); + + // arg_yr: 15r, 13r, 11r, 9r, 7r, 5r, 3r, 1r + // -> 6r, 4r, 2r, 0r, 7r, 5r, 3r, 1r + __m512d perm_1 = _mm512_mask_blend_pd(0x0f, arg_xr, arg_yr); + // -> 15r, 13r, 11r, 9r, 14r, 12r, 10r, 8r + __m512d perm_2 = _mm512_mask_blend_pd(0xf0, arg_xr, arg_yr); + + // 7r, 3r, 6r, 2r, 5r, 1r, 4r, 0r + arg_xr = _mm512_permutexvar_pd(v_X_out_idx, perm_1); + // 15r, 11r, 14r, 10r, 13r, 9r, 12r, 8r + arg_yr = _mm512_permutexvar_pd(v_Y_out_idx, perm_2); + + // Imaginary part + // in: 14i, 12i, 10i, 8i, 6i, 4i, 2i, 0i + // -> 6i, 4i, 2i, 0i, 14i, 12i, 10i, 8i + arg_xi = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, arg_xi); + + // arg_yr: 15i, 13i, 11i, 9i, 7i, 5i, 3i, 1i + // -> 6i, 4i, 2i, 0i, 7i, 5i, 3i, 1i + perm_1 = _mm512_mask_blend_pd(0x0f, arg_xi, arg_yi); + // -> 15i, 13i, 11i, 9i, 14i, 12i, 10i, 8i + perm_2 = _mm512_mask_blend_pd(0xf0, arg_xi, arg_yi); + + // 7i, 3i, 6i, 2i, 5i, 1i, 4i, 0i + arg_xi = _mm512_permutexvar_pd(v_X_out_idx, perm_1); + // 15i, 11i, 14i, 10i, 13i, 9i, 12i, 8i + arg_yi = _mm512_permutexvar_pd(v_Y_out_idx, perm_2); + + // Merge + // 00000000 > 3i 3r 2i 2r 1i 1r 0i 0r + __m512d out1 = _mm512_shuffle_pd(arg_xr, arg_xi, 0x00); + // 11111111 > 7i 7r 6i 6r 5i 5r 4i 4r + __m512d out2 = _mm512_shuffle_pd(arg_xr, arg_xi, 0xff); + + // 00000000 > 11i 11r 10i 10r 9i 9r 8i 8r + __m512d out3 = _mm512_shuffle_pd(arg_yr, arg_yi, 0x00); + // 11111111 > 15i 15r 14i 14r 13i 13r 12i 12r + __m512d out4 = _mm512_shuffle_pd(arg_yr, arg_yi, 0xff); + + _mm512_storeu_pd(out++, out1); + _mm512_storeu_pd(out++, out2); + _mm512_storeu_pd(out++, out3); + _mm512_storeu_pd(out++, out4); +} + +// ComplexLoadInvInterleavedT1: +// Given input: 15i 15r 14i 14r 13i 13r 12i 12r 11i 11r 10i 10r 9i 9r 8i 8r +// 7i 7r 6i 6r 5i 5r 4i 4r 3i 3r 2i 2r 1i 1r 0i 0r +// Returns +// *out1_r = (14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r); +// *out1_i = (14i, 10i, 6i, 2i, 12i, 8i, 4i, 0i); +// *out2_r = (15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r); +// *out2_i = (15i, 11i, 7i, 3i, 13i, 9i, 5i, 1i); +// +// Given output: +// 15i, 11i, 7i, 3i, 13i, 9i, 5i, 1i, 15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r, +// 14i, 10i, 6i, 2i, 12i, 8i, 4i, 0i, 14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r +inline void ComplexLoadInvInterleavedT1(const double* arg, __m512d* out1_r, + __m512d* out1_i, __m512d* out2_r, + __m512d* 
out2_i) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 3i 3r 2i 2r 1i 1r 0i 0r + __m512d v_3to0 = _mm512_loadu_pd(arg_512++); + // 7i 7r 6i 6r 5i 5r 4i 4r + __m512d v_7to4 = _mm512_loadu_pd(arg_512++); + // 11i 11r 10i 10r 9i 9r 8i 8r + __m512d v_11to8 = _mm512_loadu_pd(arg_512++); + // 15i 15r 14i 14r 13i 13r 12i 12r + __m512d v_15to12 = _mm512_loadu_pd(arg_512++); + + // 00000000 > 7r 3r 6r 2r 5r 1r 4r 0r + __m512d v_7to0_r = _mm512_shuffle_pd(v_3to0, v_7to4, 0x00); + // 11111111 > 7i 3i 6i 2i 5i 1i 4i 0i + __m512d v_7to0_i = _mm512_shuffle_pd(v_3to0, v_7to4, 0xff); + // 00000000 > 15r 11r 14r 10r 13r 9r 12r 8r + __m512d v_15to8_r = _mm512_shuffle_pd(v_11to8, v_15to12, 0x00); + // 11111111 > 15i 11i 14i 10i 13i 9i 12i 8i + __m512d v_15to8_i = _mm512_shuffle_pd(v_11to8, v_15to12, 0xff); + + // real + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + // 6 2 7 3 4 0 5 1 + __m512d v1r = _mm512_permutexvar_pd(v1_perm_idx, v_7to0_r); + // 14 10 15 11 12 8 13 9 + __m512d v2r = _mm512_permutexvar_pd(v1_perm_idx, v_15to8_r); + // 11001100 > 14 10 6 2 12 8 4 0 + *out1_r = _mm512_mask_blend_pd(0xcc, v_7to0_r, v2r); + // 11001100 > 15 11 7 3 13 9 5 1 + *out2_r = _mm512_mask_blend_pd(0xcc, v1r, v_15to8_r); + + // imag + // 6 2 7 3 4 0 5 1 + __m512d v1i = _mm512_permutexvar_pd(v1_perm_idx, v_7to0_i); + // 14 10 15 11 12 8 13 9 + __m512d v2i = _mm512_permutexvar_pd(v1_perm_idx, v_15to8_i); + // 11001100 > 14 10 6 2 12 8 4 0 + *out1_i = _mm512_mask_blend_pd(0xcc, v_7to0_i, v2i); + // 11001100 > 15 11 7 3 13 9 5 1 + *out2_i = _mm512_mask_blend_pd(0xcc, v1i, v_15to8_i); +} + +// ************************************ T2 ************************************ + +// ComplexLoadFwdInterleavedT2: +// Assumes ComplexLoadFwdInterleavedT4 was used before. +// Given input: 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0 +// Returns +// *out1 = (13, 12, 9, 8, 5, 4, 1, 0) +// *out2 = (15, 14, 11, 10, 7, 6, 3, 2) +// +// Given output: 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0 +inline void ComplexLoadFwdInterleavedT2(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // Values were swapped in T4 + // 11, 10, 9, 8, 3, 2, 1, 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15, 14, 13, 12, 7, 6, 5, 4 + __m512d v2 = _mm512_loadu_pd(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + + __m512d v1_perm = _mm512_permutexvar_pd(v1_perm_idx, v1); + __m512d v2_perm = _mm512_permutexvar_pd(v1_perm_idx, v2); + + *out1 = _mm512_mask_blend_pd(0xcc, v1, v2_perm); + *out2 = _mm512_mask_blend_pd(0xcc, v1_perm, v2); +} + +// ComplexLoadInvInterleavedT2: +// Assumes ComplexLoadInvInterleavedT1 was used before. 
+// Given input: 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0 +// Returns +// *out1 = (13, 9, 5, 1, 12, 8, 4, 0) +// *out2 = (15, 11, 7, 3, 14, 10, 6, 2) +// +// Given output: 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0 +inline void ComplexLoadInvInterleavedT2(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 14 10 6 2 12 8 4 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15 11 7 3 13 9 5 1 + __m512d v2 = _mm512_loadu_pd(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + + // 12 8 4 0 14 10 6 2 + __m512d v1_perm = _mm512_permutexvar_pd(v1_perm_idx, v1); + // 13 9 5 1 15 11 7 3 + __m512d v2_perm = _mm512_permutexvar_pd(v1_perm_idx, v2); + + // 11110000 > 13 9 5 1 12 8 4 0 + *out1 = _mm512_mask_blend_pd(0xf0, v1, v2_perm); + // 11110000 > 15 11 7 3 14 10 6 2 + *out2 = _mm512_mask_blend_pd(0xf0, v1_perm, v2); +} + +// ************************************ T4 ************************************ + +// Complex LoadFwdInterleavedT4: +// Given input: 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +// Returns +// *out1 = (11, 10, 9, 8, 3, 2, 1, 0) +// *out2 = (15, 14, 13, 12, 7, 6, 5, 4) +// +// Given output: 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0 +inline void ComplexLoadFwdInterleavedT4(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + __m512d v_7to0 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + __m512d v_15to8 = _mm512_loadu_pd(arg_512); + __m512d perm_hi = _mm512_permutexvar_pd(vperm2_idx, v_15to8); + *out1 = _mm512_mask_blend_pd(0x0f, perm_hi, v_7to0); + *out2 = _mm512_mask_blend_pd(0xf0, perm_hi, v_7to0); + *out2 = _mm512_permutexvar_pd(vperm2_idx, *out2); +} + +// ComplexLoadInvInterleavedT4: +// Assumes ComplexLoadInvInterleavedT2 was used before. +// Given input: 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0 +// Returns +// *out1 = (11, 9, 3, 1, 10, 8, 2, 0) +// *out2 = (15, 13, 7, 5, 14, 12, 6, 4) +// +// Given output: 15, 13, 7, 5, 14, 12, 6, 4, 11, 9, 3, 1, 10, 8, 2, 0 + +inline void ComplexLoadInvInterleavedT4(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 13, 9, 5, 1, 12, 8, 4, 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15, 11, 7, 3, 14, 10, 6, 2 + __m512d v2 = _mm512_loadu_pd(arg_512); + + // 00000000 > 11 9 3 1 10 8 2 0 + *out1 = _mm512_shuffle_pd(v1, v2, 0x00); + // 11111111 > 15 13 7 5 14 12 6 4 + *out2 = _mm512_shuffle_pd(v1, v2, 0xff); +} + +// ComplexWriteInvInterleavedT4: +// Assuming ComplexLoadInvInterleavedT4 was used before. 
+// Given inputs: 15, 13, 7, 5, 14, 12, 6, 4, 11, 9, 3, 1, 10, 8, 2, 0 +// Seen Internally: +// @param arg1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// @param arg2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 11, 14, 10, 7, 3, 6, 2, +// 13, 9, 12, 8, 5, 1, 4, 0} +// +// Given output: 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +inline void ComplexWriteInvInterleavedT4(__m512d arg1, __m512d arg2, + __m512d* out) { + const __m512i vperm_4hi_4lo_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i vperm1 = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + const __m512i vperm2 = _mm512_set_epi64(5, 1, 4, 0, 7, 3, 6, 2); + + // in: 11 9 3 1 10 8 2 0 + // -> 11 10 9 8 3 2 1 0 + arg1 = _mm512_permutexvar_pd(vperm1, arg1); + // in: 15 13 7 5 14 12 6 4 + // -> 7 6 5 4 15 14 13 12 + arg2 = _mm512_permutexvar_pd(vperm2, arg2); + + // 7 6 5 4 3 2 1 0 + __m512d out1 = _mm512_mask_blend_pd(0xf0, arg1, arg2); + // 11 10 9 8 15 14 13 12 + __m512d out2 = _mm512_mask_blend_pd(0x0f, arg1, arg2); + // 15 14 13 12 11 10 9 8 + out2 = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, out2); + + _mm512_storeu_pd(out, out1); + out += 2; + _mm512_storeu_pd(out, out2); +} + +// ************************************ T8 ************************************ + +// ComplexLoadFwdInterleavedT8: +// Given inputs: 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +// Seen Internally: +// v_X1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// v_X2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 13, 11, 9, 7, 5, 3, 1, +// 14, 12, 10, 8, 6, 4, 2, 0} +// +// Given output: 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i, 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r +inline void ComplexLoadFwdInterleavedT8(const __m512d* arg_x, + const __m512d* arg_y, __m512d* out1_r, + __m512d* out1_i, __m512d* out2_r, + __m512d* out2_i) { + const __m512i v_perm_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + // 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r + __m512d v_X1 = _mm512_loadu_pd(arg_x++); + // 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r + __m512d v_X2 = _mm512_loadu_pd(arg_x); + // 7r, 3r, 6r, 2r, 5r, 1r, 4r, 0r + *out1_r = _mm512_shuffle_pd(v_X1, v_X2, 0x00); + // 7i, 3i, 6i, 2i, 5i, 1i, 4i, 0i + *out1_i = _mm512_shuffle_pd(v_X1, v_X2, 0xff); + // 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r + *out1_r = _mm512_permutexvar_pd(v_perm_idx, *out1_r); + // 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i + *out1_i = _mm512_permutexvar_pd(v_perm_idx, *out1_i); + + __m512d v_Y1 = _mm512_loadu_pd(arg_y++); + __m512d v_Y2 = _mm512_loadu_pd(arg_y); + *out2_r = _mm512_shuffle_pd(v_Y1, v_Y2, 0x00); + *out2_i = _mm512_shuffle_pd(v_Y1, v_Y2, 0xff); + *out2_r = _mm512_permutexvar_pd(v_perm_idx, *out2_r); + *out2_i = _mm512_permutexvar_pd(v_perm_idx, *out2_i); +} + +// ComplexWriteInvInterleavedT8: +// Assuming ComplexLoadInvInterleavedT4 was used before. 
+// Given inputs: 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i, 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r +// Seen Internally: +// @param arg1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// @param arg2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 7, 14, 6, 13, 5, 12, 4, +// 11, 3, 10, 2, 9, 1, 8, 0} +// +// Given output: 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +inline void ComplexWriteInvInterleavedT8(__m512d* v_X_real, __m512d* v_X_imag, + __m512d* v_Y_real, __m512d* v_Y_imag, + __m512d* v_X_pt, __m512d* v_Y_pt) { + const __m512i vperm = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + // in: 7r 6r 5r 4r 3r 2r 1r 0r + // -> 7r 3r 6r 2r 5r 1r 4r 0r + *v_X_real = _mm512_permutexvar_pd(vperm, *v_X_real); + // in: 7i 6i 5i 4i 3i 2i 1i 0i + // -> 7i 3i 6i 2i 5i 1i 4i 0i + *v_X_imag = _mm512_permutexvar_pd(vperm, *v_X_imag); + // in: 15r 14r 13r 12r 11r 10r 9r 8r + // -> 15r 11r 14r 10r 13r 9r 12r 8r + *v_Y_real = _mm512_permutexvar_pd(vperm, *v_Y_real); + // in: 15i 14i 13i 12i 11i 10i 9i 8i + // -> 15i 11i 14i 10i 13i 9i 12i 8i + *v_Y_imag = _mm512_permutexvar_pd(vperm, *v_Y_imag); + + // 00000000 > 3i 3r 2i 2r 1i 1r 0i 0r + __m512d v_X1 = _mm512_shuffle_pd(*v_X_real, *v_X_imag, 0x00); + // 11111111 > 7i 7r 6i 6r 5i 5r 4i 4r + __m512d v_X2 = _mm512_shuffle_pd(*v_X_real, *v_X_imag, 0xff); + // 00000000 > 11i 11r 10i 10r 9i 9r 8i 8r + __m512d v_Y1 = _mm512_shuffle_pd(*v_Y_real, *v_Y_imag, 0x00); + // 11111111 > 15i 15r 14i 14r 13i 13r 12i 12r + __m512d v_Y2 = _mm512_shuffle_pd(*v_Y_real, *v_Y_imag, 0xff); + + _mm512_storeu_pd(v_X_pt++, v_X1); + _mm512_storeu_pd(v_X_pt, v_X2); + _mm512_storeu_pd(v_Y_pt++, v_Y1); + _mm512_storeu_pd(v_Y_pt, v_Y2); +} +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/experimental/fft-like/fft-like-native.hpp b/hexl_v2/include/hexl/experimental/fft-like/fft-like-native.hpp new file mode 100644 index 00000000..7e02492d --- /dev/null +++ b/hexl_v2/include/hexl/experimental/fft-like/fft-like-native.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Radix-2 native C++ FFT like implementation of the forward FFT like +/// @param[out] result Output data. Overwritten with FFT like output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity. In +/// bit-reversed order +/// @param[in] scale Scale applied to output data +void Forward_FFTLike_ToBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* root_of_unity_powers, const uint64_t n, + const double* scale = nullptr); + +/// @brief Radix-2 native C++ FFT like implementation of the inverse FFT like +/// @param[out] result Output data. Overwritten with FFT like output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity. +/// In bit-reversed order. 
+/// @param[in] scale Scale applied to output data +void Inverse_FFTLike_FromBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* inv_root_of_unity_powers, const uint64_t n, + const double* scale = nullptr); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/experimental/fft-like/fft-like.hpp b/hexl_v2/include/hexl/experimental/fft-like/fft-like.hpp new file mode 100644 index 00000000..334de246 --- /dev/null +++ b/hexl_v2/include/hexl/experimental/fft-like/fft-like.hpp @@ -0,0 +1,147 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/experimental/fft-like/fwd-fft-like-avx512.hpp" +#include "hexl/experimental/fft-like/inv-fft-like-avx512.hpp" +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/allocator.hpp" + +namespace intel { +namespace hexl { + +/// @brief Performs linear forward and inverse FFT like transform +/// for CKKS encoding and decoding. +class FFTLike { + public: + /// @brief Helper class for custom memory allocation + template + struct AllocatorAdapter + : public AllocatorInterface> { + explicit AllocatorAdapter(Adaptee&& _a, Args&&... args); + AllocatorAdapter(const Adaptee& _a, Args&... args); + + // interface implementation + void* allocate_impl(size_t bytes_count); + void deallocate_impl(void* p, size_t n); + + private: + Adaptee alloc; + }; + + /// @brief Initializes an empty CKKS_FTT object + FFTLike() = default; + + /// @brief Destructs the CKKS_FTT object + ~FFTLike() = default; + + /// @brief Initializes an FFTLike object with degree \p degree and scalar + /// \p in_scalar. + /// @param[in] degree also known as N. Size of the FFT like transform. Must be + /// a power of 2 + /// @param[in] in_scalar Scalar value to calculate scale and inv scale + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @details Performs pre-computation necessary for forward and inverse + /// transforms + FFTLike(uint64_t degree, double* in_scalar, + std::shared_ptr alloc_ptr = {}); + + template + FFTLike(uint64_t degree, double* in_scalar, Allocator&& a, + AllocatorArgs&&... args) + : FFTLike( + degree, in_scalar, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Compute forward FFT like. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the FFT like + /// @param[in] in_scale Scale applied to output values + void ComputeForwardFFTLike(std::complex* result, + const std::complex* operand, + const double* in_scale = nullptr); + + /// @brief Compute inverse FFT like. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the FFT like + /// @param[in] in_scale Scale applied to output values + void ComputeInverseFFTLike(std::complex* result, + const std::complex* operand, + const double* in_scale = nullptr); + + /// @brief Construct floating-point values from CRT-composed polynomial with + /// integer coefficients. 
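+  /// @details Typically used during CKKS decoding, before the forward
+  /// FFT-like transform is applied to recover the encoded values.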
+ /// @param[out] res Stores the result + /// @param[in] plain Plaintext + /// @param[in] threshold Upper half threshold with respect to the total + /// coefficient modulus + /// @param[in] decryption_modulus Product of all primes in the coefficient + /// modulus + /// @param[in] inv_scale Scale applied to output values + /// @param[in] mod_size Size of coefficient modulus parameter + /// @param[in] coeff_count Degree of the polynomial modulus parameter + void BuildFloatingPoints(std::complex* res, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, size_t mod_size, + size_t coeff_count); + + /// @brief Returns the root of unity power at bit-reversed index i. + /// @param[in] i Index + std::complex GetComplexRootOfUnity(size_t i) { + return GetComplexRootsOfUnity()[i]; + } + + /// @brief Returns the root of unity in bit-reversed order + const AlignedVector64>& GetComplexRootsOfUnity() const { + return m_complex_roots_of_unity; + } + + /// @brief Returns the root of unity power at bit-reversed index i. + /// @param[in] i Index + std::complex GetInvComplexRootOfUnity(size_t i) { + return GetInvComplexRootsOfUnity()[i]; + } + + /// @brief Returns the inverse root of unity in bit-reversed order + const AlignedVector64>& GetInvComplexRootsOfUnity() + const { + return m_inv_complex_roots_of_unity; + } + + /// @brief Returns the degree N + uint64_t GetDegree() const { return m_degree; } + + private: + // Computes 1~(n-1)-th powers and inv powers of the primitive 2n-th root + void ComputeComplexRootsOfUnity(); + + uint64_t m_degree; // N: size of FFT like transform, should be power of 2 + + double* scalar; // Pointer to scalar used for scale/inv_scale calculation + + double scale; // Scale value use for encoding (inv fft-like) + + double inv_scale; // Scale value use in decoding (fwd fft-like) + + std::shared_ptr m_alloc; + + AlignedAllocator m_aligned_alloc; + + uint64_t m_degree_bits; // log_2(m_degree) + + // Contains 0~(n-1)-th powers of the 2n-th primitive root. + AlignedVector64> m_complex_roots_of_unity; + + // Contains 0~(n-1)-th inv powers of the 2n-th primitive inv root. + AlignedVector64> m_inv_complex_roots_of_unity; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp b/hexl_v2/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp new file mode 100644 index 00000000..aba4ca4d --- /dev/null +++ b/hexl_v2/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp @@ -0,0 +1,59 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/fft-like/fft-like.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the forward FFT like +/// @param[out] result_cmplx_intrlvd Output data. Overwritten with FFT like +/// output. Result is a vector of double with interleaved real and imaginary +/// numbers. +/// @param[in] operand_cmplx_intrlvd Input data. A vector of double with +/// interleaved real and imaginary numbers. +/// @param[in] roots_of_unity_cmplx_intrlvd Powers of 2n'th root of unity. In +/// bit-reversed order. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. 
+/// @param[in] scale Scale applied to output values +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// FFT like, where all the butterflies in a given stage are processed before +/// any butterflies in the next stage. The base case is small enough to fit in +/// the smallest cache. Larger FFTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. This reduces the number of cache misses, improving +/// performance on larger transform sizes. +void Forward_FFTLike_ToBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* roots_of_unity_cmplx_intrlvd, const uint64_t n, + const double* scale = nullptr, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +/// @brief Construct floating-point values from CRT-composed polynomial with +/// integer coefficients in AVX512. +/// @param[out] res_cmplx_intrlvd Stores the result +/// @param[in] plain Plaintext +/// @param[in] threshold Upper half threshold with respect to the total +/// coefficient modulus +/// @param[in] decryption_modulus Product of all primes in the coefficient +/// modulus +/// @param[in] inv_scale Scale applied to output values +/// @param[in] mod_size Size of coefficient modulus parameter +/// @param[in] coeff_count Degree of the polynomial modulus parameter +void BuildFloatingPointsAVX512(double* res_cmplx_intrlvd, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, const size_t mod_size, + const size_t coeff_count); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp b/hexl_v2/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp new file mode 100644 index 00000000..487e2828 --- /dev/null +++ b/hexl_v2/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/fft-like/fft-like.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the inverse FFT like +/// @param[out] result_cmplx_intrlvd Output data. Overwritten with FFT like +/// output. Result is a vector of double with interleaved real and imaginary +/// numbers. +/// @param[in] operand_cmplx_intrlvd Input data. A vector of double with +/// interleaved real and imaginary numbers. +/// @param[in] inv_roots_of_unity_cmplx_intrlvd Powers of 2n'th root of unity. +/// In bit-reversed order. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] scale Scale applied to output values +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// FFT like, where all the butterflies in a given stage are processed before +/// any butterflies in the next stage. The base case is small enough to fit in +/// the smallest cache. Larger FFTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. 
This reduces the number of cache misses, improving +/// performance on larger transform sizes. +void Inverse_FFTLike_FromBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* inv_root_of_unity_cmplxintrlvd, const uint64_t n, + const double* scale = nullptr, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/experimental/misc/lr-mat-vec-mult.hpp b/hexl_v2/include/hexl/experimental/misc/lr-mat-vec-mult.hpp new file mode 100644 index 00000000..df03df92 --- /dev/null +++ b/hexl_v2/include/hexl/experimental/misc/lr-mat-vec-mult.hpp @@ -0,0 +1,31 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes transposed linear regression +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (3 * n * num_moduli) elements +/// @param[in] operand1 Vector of ciphertext representing a matrix that encodes +/// a transposed logistic regression model. Has (num_weights * 2 * n * +/// num_moduli) elements. +/// @param[in] operand2 Vector of ciphertext representing a matrix that encodes +/// at most n/2 input samples with feature size num_weights. Has (num_weights * +/// 2 * n * num_moduli) elements. +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +/// @param[in] num_weights Feature size of the linear/logistic regression model +void LinRegMatrixVectorMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli, + uint64_t num_weights); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/experimental/seal/dyadic-multiply-internal.hpp b/hexl_v2/include/hexl/experimental/seal/dyadic-multiply-internal.hpp new file mode 100644 index 00000000..310a46b0 --- /dev/null +++ b/hexl_v2/include/hexl/experimental/seal/dyadic-multiply-internal.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { +namespace internal { + +/// @brief Computes dyadic multiplication +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (2 * n * num_moduli) elements +/// @param[in] operand1 First ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] operand2 Second ciphertext argument. Has (2 * n * num_moduli) +/// elements. 
+/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli); + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/experimental/seal/dyadic-multiply.hpp b/hexl_v2/include/hexl/experimental/seal/dyadic-multiply.hpp new file mode 100644 index 00000000..f7eacfdf --- /dev/null +++ b/hexl_v2/include/hexl/experimental/seal/dyadic-multiply.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes dyadic multiplication +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (2 * n * num_moduli) elements +/// @param[in] operand1 First ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] operand2 Second ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/experimental/seal/key-switch-internal.hpp b/hexl_v2/include/hexl/experimental/seal/key-switch-internal.hpp new file mode 100644 index 00000000..8fc9d53e --- /dev/null +++ b/hexl_v2/include/hexl/experimental/seal/key-switch-internal.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { +namespace internal { + +/// @brief Computes key switching in-place +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (n * decomp_modulus_size * key_component_count) elements +/// @param[in] t_target_iter_ptr Pointer to the last component of the input +/// ciphertext +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] decomp_modulus_size Number of moduli in the ciphertext at its +/// current level, excluding one auxiliary prime. +/// @param[in] key_modulus_size Number of moduli in the ciphertext at its top +/// level, including one auxiliary prime. +/// @param[in] rns_modulus_size Number of moduli in the ciphertext at its +/// current level, including one auxiliary prime. rns_modulus_size == +/// decomp_modulus_size + 1 +/// @param[in] key_component_count Number of components in the resulting +/// ciphertext, e.g. key_component_count == 2. +/// @param[in] moduli Array of word-sized coefficient moduli. There must be +/// key_modulus_size moduli in the array +/// @param[in] k_switch_keys Array of evaluation key data. 
Has +/// decomp_modulus_size entries, each with +/// coeff_count * ((key_modulus_size - 1)+ (key_component_count - 1) * +/// (key_modulus_size) + 1) entries +/// @param[in] modswitch_factors Array of modulus switch factors +/// @param[in] root_of_unity_powers_ptr Array of root of unity powers +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr = nullptr); + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/experimental/seal/key-switch.hpp b/hexl_v2/include/hexl/experimental/seal/key-switch.hpp new file mode 100644 index 00000000..9eda159c --- /dev/null +++ b/hexl_v2/include/hexl/experimental/seal/key-switch.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes key switching in-place +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (n * decomp_modulus_size * key_component_count) elements +/// @param[in] t_target_iter_ptr Pointer to the last component of the input +/// ciphertext +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] decomp_modulus_size Number of moduli in the ciphertext at its +/// current level, excluding one auxiliary prime. +/// @param[in] key_modulus_size Number of moduli in the ciphertext at its top +/// level, including one auxiliary prime. +/// @param[in] rns_modulus_size Number of moduli in the ciphertext at its +/// current level, including one auxiliary prime. rns_modulus_size == +/// decomp_modulus_size + 1 +/// @param[in] key_component_count Number of components in the resulting +/// ciphertext, e.g. key_component_count == 2. +/// @param[in] moduli Array of word-sized coefficient moduli. There must be +/// key_modulus_size moduli in the array +/// @param[in] k_switch_keys Array of evaluation key data. 
Has +/// decomp_modulus_size entries, each with +/// coeff_count * ((key_modulus_size - 1)+ (key_component_count - 1) * +/// (key_modulus_size) + 1) entries +/// @param[in] modswitch_factors Array of modulus switch factors +/// @param[in] root_of_unity_powers_ptr Array of root of unity powers +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr = nullptr); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/experimental/seal/locks.hpp b/hexl_v2/include/hexl/experimental/seal/locks.hpp new file mode 100644 index 00000000..4595f4e5 --- /dev/null +++ b/hexl_v2/include/hexl/experimental/seal/locks.hpp @@ -0,0 +1,35 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +namespace intel { +namespace hexl { + +using Lock = std::shared_mutex; +using WriteLock = std::unique_lock; +using ReadLock = std::shared_lock; + +class RWLock { + public: + RWLock() = default; + inline ReadLock AcquireRead() { return ReadLock(rw_mutex); } + inline WriteLock AcquireWrite() { return WriteLock(rw_mutex); } + inline ReadLock TryAcquireRead() noexcept { + return ReadLock(rw_mutex, std::try_to_lock); + } + inline WriteLock TryAcquireWrite() noexcept { + return WriteLock(rw_mutex, std::try_to_lock); + } + + private: + RWLock(const RWLock& copy) = delete; + RWLock& operator=(const RWLock& assign) = delete; + Lock rw_mutex{}; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/experimental/seal/ntt-cache.hpp b/hexl_v2/include/hexl/experimental/seal/ntt-cache.hpp new file mode 100644 index 00000000..8f6c1046 --- /dev/null +++ b/hexl_v2/include/hexl/experimental/seal/ntt-cache.hpp @@ -0,0 +1,56 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/seal/locks.hpp" +#include "ntt/ntt-internal.hpp" + +namespace intel { +namespace hexl { + +struct HashPair { + template + std::size_t operator()(const std::pair& p) const { + auto hash1 = std::hash{}(p.first); + auto hash2 = std::hash{}(p.second); + return hash_combine(hash1, hash2); + } + + // Golden Ratio Hashing with seeds + static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { + lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + return lhs; + } +}; + +NTT& GetNTT(size_t N, uint64_t modulus) { + static std::unordered_map, NTT, HashPair> + ntt_cache; + static RWLock ntt_cache_locker; + + std::pair key{N, modulus}; + + // Enable shared access to NTT already present + { + ReadLock reader_lock(ntt_cache_locker.AcquireRead()); + auto ntt_it = ntt_cache.find(key); + if (ntt_it != ntt_cache.end()) { + return ntt_it->second; + } + } + + // Deal with NTT not yet present + WriteLock write_lock(ntt_cache_locker.AcquireWrite()); + + // Check ntt_cache for value (may be added by another thread) + auto ntt_it = ntt_cache.find(key); + if (ntt_it == ntt_cache.end()) { + NTT ntt(N, modulus); + ntt_it = ntt_cache.emplace(std::move(key), std::move(ntt)).first; + } + return ntt_it->second; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/hexl.hpp b/hexl_v2/include/hexl/hexl.hpp new file mode 100644 index 00000000..6f07ae57 --- /dev/null +++ 
b/hexl_v2/include/hexl/hexl.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/eltwise/eltwise-cmp-add.hpp" +#include "hexl/eltwise/eltwise-cmp-sub-mod.hpp" +#include "hexl/eltwise/eltwise-fma-mod.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/eltwise/eltwise-sub-mod.hpp" +#include "hexl/experimental/fft-like/fft-like.hpp" +#include "hexl/experimental/misc/lr-mat-vec-mult.hpp" +#include "hexl/experimental/seal/dyadic-multiply-internal.hpp" +#include "hexl/experimental/seal/dyadic-multiply.hpp" +#include "hexl/experimental/seal/key-switch-internal.hpp" +#include "hexl/experimental/seal/key-switch.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" +#include "hexl/util/defines.hpp" +#include "hexl/util/types.hpp" +#include "hexl/util/util.hpp" diff --git a/hexl_v2/include/hexl/logging/logging.hpp b/hexl_v2/include/hexl/logging/logging.hpp new file mode 100644 index 00000000..af5bfcd8 --- /dev/null +++ b/hexl_v2/include/hexl/logging/logging.hpp @@ -0,0 +1,43 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "hexl/util/defines.hpp" + +// Wrap HEXL_VLOG with HEXL_DEBUG; this ensures no logging overhead in +// release mode +#ifdef HEXL_DEBUG + +// TODO(fboemer) Enable if needed +// #define ELPP_THREAD_SAFE +#define ELPP_CUSTOM_COUT std::cerr +#define ELPP_STL_LOGGING +#define ELPP_LOG_STD_ARRAY +#define ELPP_LOG_UNORDERED_MAP +#define ELPP_LOG_UNORDERED_SET +#define ELPP_NO_LOG_TO_FILE +#define ELPP_DISABLE_DEFAULT_CRASH_HANDLING +#define ELPP_WINSOCK2 + +#include + +#define HEXL_VLOG(N, rest) \ + do { \ + if (VLOG_IS_ON(N)) { \ + VLOG(N) << rest; \ + } \ + } while (0); + +#else + +#define HEXL_VLOG(N, rest) \ + {} + +#define START_EASYLOGGINGPP(X, Y) \ + {} + +#endif diff --git a/hexl_v2/include/hexl/ntt/ntt.hpp b/hexl_v2/include/hexl/ntt/ntt.hpp new file mode 100644 index 00000000..93ccba72 --- /dev/null +++ b/hexl_v2/include/hexl/ntt/ntt.hpp @@ -0,0 +1,296 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/allocator.hpp" + +namespace intel { +namespace hexl { + +/// @brief Performs negacyclic forward and inverse number-theoretic transform +/// (NTT), commonly used in RLWE cryptography. +/// @details The number-theoretic transform (NTT) specializes the discrete +/// Fourier transform (DFT) to the finite field \f$ \mathbb{Z}_q[X] / (X^N + 1) +/// \f$. +class NTT { + public: + /// @brief Helper class for custom memory allocation + template + struct AllocatorAdapter + : public AllocatorInterface> { + explicit AllocatorAdapter(Adaptee&& _a, Args&&... args); + AllocatorAdapter(const Adaptee& _a, Args&... args); + + // interface implementation + void* allocate_impl(size_t bytes_count); + void deallocate_impl(void* p, size_t n); + + private: + Adaptee alloc; + }; + + /// @brief Initializes an empty NTT object + NTT() = default; + + /// @brief Destructs the NTT object + ~NTT() = default; + + /// @brief Initializes an NTT object with degree \p degree and modulus \p q. + /// @param[in] degree also known as N. Size of the NTT transform. 
Must be a + /// power of + /// 2 + /// @param[in] q Prime modulus. Must satisfy \f$ q == 1 \mod 2N \f$ + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @brief Performs pre-computation necessary for forward and inverse + /// transforms + NTT(uint64_t degree, uint64_t q, + std::shared_ptr alloc_ptr = {}); + + template + NTT(uint64_t degree, uint64_t q, Allocator&& a, AllocatorArgs&&... args) + : NTT(degree, q, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Initializes an NTT object with degree \p degree and modulus + /// \p q. + /// @param[in] degree also known as N. Size of the NTT transform. Must be a + /// power of + /// 2 + /// @param[in] q Prime modulus. Must satisfy \f$ q == 1 \mod 2N \f$ + /// @param[in] root_of_unity 2N'th root of unity in \f$ \mathbb{Z_q} \f$. + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @details Performs pre-computation necessary for forward and inverse + /// transforms + NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, + std::shared_ptr alloc_ptr = {}); + + template + NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, Allocator&& a, + AllocatorArgs&&... args) + : NTT(degree, q, root_of_unity, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Returns true if arguments satisfy constraints for negacyclic NTT + /// @param[in] degree N. Size of the transform, i.e. the polynomial degree. + /// Must be a power of two. + /// @param[in] modulus Prime modulus q. Must satisfy q mod 2N = 1 + static bool CheckArguments(uint64_t degree, uint64_t modulus); + + /// @brief Compute forward NTT. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the NTT + /// @param[in] input_mod_factor Assume input \p operand are in [0, + /// input_mod_factor * q). Must be 1, 2 or 4. + /// @param[in] output_mod_factor Returns output \p result in [0, + /// output_mod_factor * q). Must be 1 or 4. + void ComputeForward(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, uint64_t output_mod_factor); + + /// Compute inverse NTT. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the NTT + /// @param[in] input_mod_factor Assume input \p operand are in [0, + /// input_mod_factor * q). Must be 1 or 2. + /// @param[in] output_mod_factor Returns output \p result in [0, + /// output_mod_factor * q). Must be 1 or 2. + void ComputeInverse(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, uint64_t output_mod_factor); + + /// @brief Returns the minimal 2N'th root of unity + uint64_t GetMinimalRootOfUnity() const { return m_w; } + + /// @brief Returns the degree N + uint64_t GetDegree() const { return m_degree; } + + /// @brief Returns the word-sized prime modulus + uint64_t GetModulus() const { return m_q; } + + /// @brief Returns the root of unity powers in bit-reversed order + const AlignedVector64& GetRootOfUnityPowers() const { + return m_root_of_unity_powers; + } + + /// @brief Returns the root of unity power at bit-reversed index i. 
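+  /// @param[in] i Index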
+ uint64_t GetRootOfUnityPower(size_t i) { return GetRootOfUnityPowers()[i]; } + + /// @brief Returns 32-bit pre-conditioned root of unity powers in + /// bit-reversed order + const AlignedVector64& GetPrecon32RootOfUnityPowers() const { + return m_precon32_root_of_unity_powers; + } + + /// @brief Returns 64-bit pre-conditioned root of unity powers in + /// bit-reversed order + const AlignedVector64& GetPrecon64RootOfUnityPowers() const { + return m_precon64_root_of_unity_powers; + } + + /// @brief Returns the root of unity powers in bit-reversed order with + /// modifications for use by AVX512 implementation + const AlignedVector64& GetAVX512RootOfUnityPowers() const { + return m_avx512_root_of_unity_powers; + } + + /// @brief Returns 32-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon32RootOfUnityPowers() const { + return m_avx512_precon32_root_of_unity_powers; + } + + /// @brief Returns 52-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon52RootOfUnityPowers() const { + return m_avx512_precon52_root_of_unity_powers; + } + + /// @brief Returns 64-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon64RootOfUnityPowers() const { + return m_avx512_precon64_root_of_unity_powers; + } + + /// @brief Returns the inverse root of unity powers in bit-reversed order + const AlignedVector64& GetInvRootOfUnityPowers() const { + return m_inv_root_of_unity_powers; + } + + /// @brief Returns the inverse root of unity power at bit-reversed index i. + uint64_t GetInvRootOfUnityPower(size_t i) { + return GetInvRootOfUnityPowers()[i]; + } + + /// @brief Returns the vector of 32-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. + const AlignedVector64& GetPrecon32InvRootOfUnityPowers() const { + return m_precon32_inv_root_of_unity_powers; + } + + /// @brief Returns the vector of 52-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. + const AlignedVector64& GetPrecon52InvRootOfUnityPowers() const { + return m_precon52_inv_root_of_unity_powers; + } + + /// @brief Returns the vector of 64-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. 
+ const AlignedVector64& GetPrecon64InvRootOfUnityPowers() const { + return m_precon64_inv_root_of_unity_powers; + } + + /// @brief Maximum power of 2 in degree + static size_t MaxDegreeBits() { return 20; } + + /// @brief Maximum number of bits in modulus; + static size_t MaxModulusBits() { return 62; } + + /// @brief Default bit shift used in Barrett precomputation + static const size_t s_default_shift_bits{64}; + + /// @brief Bit shift used in Barrett precomputation when AVX512-IFMA + /// acceleration is enabled + static const size_t s_ifma_shift_bits{52}; + + /// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the + /// forward transform + static const size_t s_max_fwd_32_modulus{1ULL << (32 - 2)}; + + /// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the + /// inverse transform + static const size_t s_max_inv_32_modulus{1ULL << (32 - 2)}; + + /// @brief Maximum modulus to use AVX512-IFMA acceleration for the forward + /// transform + static const size_t s_max_fwd_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)}; + + /// @brief Maximum modulus to use AVX512-IFMA acceleration for the inverse + /// transform + static const size_t s_max_inv_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)}; + + /// @brief Maximum modulus to use AVX512-DQ acceleration for the inverse + /// transform + static const size_t s_max_inv_dq_modulus{1ULL << (s_default_shift_bits - 2)}; + + static size_t s_max_fwd_modulus(int bit_shift) { + if (bit_shift == 32) { + return s_max_fwd_32_modulus; + } else if (bit_shift == 52) { + return s_max_fwd_ifma_modulus; + } else if (bit_shift == 64) { + return 1ULL << MaxModulusBits(); + } + HEXL_CHECK(false, "Invalid bit_shift " << bit_shift); + return 0; + } + + static size_t s_max_inv_modulus(int bit_shift) { + if (bit_shift == 32) { + return s_max_inv_32_modulus; + } else if (bit_shift == 52) { + return s_max_inv_ifma_modulus; + } else if (bit_shift == 64) { + return 1ULL << MaxModulusBits(); + } + HEXL_CHECK(false, "Invalid bit_shift " << bit_shift); + return 0; + } + + private: + void ComputeRootOfUnityPowers(); + + uint64_t m_degree; // N: size of NTT transform, should be power of 2 + uint64_t m_q; // prime modulus. 
Must satisfy q == 1 mod 2n + + uint64_t m_degree_bits; // log_2(m_degree) + + uint64_t m_w_inv; // Inverse of minimal root of unity + uint64_t m_w; // A 2N'th root of unity + + std::shared_ptr m_alloc; + + AlignedAllocator m_aligned_alloc; + + // powers of the minimal root of unity + AlignedVector64 m_root_of_unity_powers; + // vector of floor(W * 2**32 / m_q), with W the root of unity powers + AlignedVector64 m_precon32_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the root of unity powers + AlignedVector64 m_precon64_root_of_unity_powers; + + // powers of the minimal root of unity adjusted for use in AVX512 + // implementations + AlignedVector64 m_avx512_root_of_unity_powers; + // vector of floor(W * 2**32 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon32_root_of_unity_powers; + // vector of floor(W * 2**52 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon52_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon64_root_of_unity_powers; + + // vector of floor(W * 2**32 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon32_inv_root_of_unity_powers; + // vector of floor(W * 2**52 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon52_inv_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon64_inv_root_of_unity_powers; + + AlignedVector64 m_inv_root_of_unity_powers; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/number-theory/number-theory.hpp b/hexl_v2/include/hexl/number-theory/number-theory.hpp new file mode 100644 index 00000000..da8d1d2a --- /dev/null +++ b/hexl_v2/include/hexl/number-theory/number-theory.hpp @@ -0,0 +1,342 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" + +namespace intel { +namespace hexl { + +/// @brief Pre-computes a Barrett factor with which modular multiplication can +/// be performed more efficiently +class MultiplyFactor { + public: + MultiplyFactor() = default; + + /// @brief Computes and stores the Barrett factor floor((operand << bit_shift) + /// / modulus). This is useful when modular multiplication of the form + /// (x * operand) mod modulus is performed with same modulus and operand + /// several times. Note, passing operand=1 can be used to pre-compute a + /// Barrett factor for multiplications of the form (x * y) mod modulus, where + /// only the modulus is re-used across calls to modular multiplication. + MultiplyFactor(uint64_t operand, uint64_t bit_shift, uint64_t modulus) + : m_operand(operand) { + HEXL_CHECK(operand <= modulus, "operand " << operand + << " must be less than modulus " + << modulus); + HEXL_CHECK(bit_shift == 32 || bit_shift == 52 || bit_shift == 64, + "Unsupported BitShift " << bit_shift); + uint64_t op_hi = operand >> (64 - bit_shift); + uint64_t op_lo = (bit_shift == 64) ? 
0 : (operand << bit_shift); + + m_barrett_factor = DivideUInt128UInt64Lo(op_hi, op_lo, modulus); + } + + /// @brief Returns the pre-computed Barrett factor + inline uint64_t BarrettFactor() const { return m_barrett_factor; } + + /// @brief Returns the operand corresponding to the Barrett factor + inline uint64_t Operand() const { return m_operand; } + + private: + uint64_t m_operand; + uint64_t m_barrett_factor; +}; + +/// @brief Returns whether or not num is a power of two +inline bool IsPowerOfTwo(uint64_t num) { return num && !(num & (num - 1)); } + +/// @brief Returns floor(log2(x)) +inline uint64_t Log2(uint64_t x) { return MSB(x); } + +inline bool IsPowerOfFour(uint64_t num) { + return IsPowerOfTwo(num) && (Log2(num) % 2 == 0); +} + +/// @brief Returns the maximum value that can be represented using \p bits bits +inline uint64_t MaximumValue(uint64_t bits) { + HEXL_CHECK(bits <= 64, "MaximumValue requires bits <= 64; got " << bits); + if (bits == 64) { + return (std::numeric_limits::max)(); + } + return (1ULL << bits) - 1; +} + +/// @brief Reverses the bits +/// @param[in] x Input to reverse +/// @param[in] bit_width Number of bits in the input; must be >= MSB(x) +/// @return The bit-reversed representation of \p x using \p bit_width bits +uint64_t ReverseBits(uint64_t x, uint64_t bit_width); + +/// @brief Returns x^{-1} mod modulus +/// @details Requires x % modulus != 0 +uint64_t InverseMod(uint64_t x, uint64_t modulus); + +/// @brief Returns (x * y) mod modulus +/// @details Assumes x, y < modulus +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns (x * y) mod modulus +/// @param[in] x +/// @param[in] y +/// @param[in] y_precon 64-bit precondition factor floor(2**64 / modulus) +/// @param[in] modulus +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon, + uint64_t modulus); + +/// @brief Returns (x + y) mod modulus +/// @details Assumes x, y < modulus +uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns (x - y) mod modulus +/// @details Assumes x, y < modulus +uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns base^exp mod modulus +uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus); + +/// @brief Returns whether or not root is a degree-th root of unity mod modulus +/// @param[in] root Root of unity to check +/// @param[in] degree Degree of root of unity; must be a power of two +/// @param[in] modulus Modulus of finite field +bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus); + +/// @brief Tries to return a primitive degree-th root of unity +/// @details Returns 0 or throws an error if no root is found +uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus); + +/// @brief Returns whether or not root is a degree-th root of unity +/// @param[in] degree Must be a power of two +/// @param[in] modulus Modulus of finite field +uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus); + +/// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 * +/// modulus] +/// @param[in] x +/// @param[in] y_operand also denoted y +/// @param[in] modulus +/// @param[in] y_barrett_factor Pre-computed Barrett reduction factor floor((y +/// << BitShift) / modulus) +template +inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y_operand, + uint64_t y_barrett_factor, uint64_t modulus) { + HEXL_CHECK(y_operand < modulus, "y_operand " << y_operand + << " must be less than modulus " + << modulus); + HEXL_CHECK( + modulus <= 
MaximumValue(BitShift), + "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift)); + HEXL_CHECK(x <= MaximumValue(BitShift), + "Operand " << x << " exceeds bound " << MaximumValue(BitShift)); + + uint64_t Q = MultiplyUInt64Hi(x, y_barrett_factor); + return y_operand * x - Q * modulus; +} + +/// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 * +/// modulus] +/// @param[in] x +/// @param[in] y +/// @param[in] modulus +template +inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y, uint64_t modulus) { + HEXL_CHECK(BitShift == 64 || BitShift == 52, + "Unsupported BitShift " << BitShift); + HEXL_CHECK(x <= MaximumValue(BitShift), + "Operand " << x << " exceeds bound " << MaximumValue(BitShift)); + HEXL_CHECK(y < modulus, + "y " << y << " must be less than modulus " << modulus); + HEXL_CHECK( + modulus <= MaximumValue(BitShift), + "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift)); + + uint64_t y_barrett = MultiplyFactor(y, BitShift, modulus).BarrettFactor(); + return MultiplyModLazy(x, y, y_barrett, modulus); +} + +/// @brief Adds two unsigned 64-bit integers +/// @param operand1 Number to add +/// @param operand2 Number to add +/// @param result Stores the sum +/// @return The carry bit +inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2, + uint64_t* result) { + *result = operand1 + operand2; + return static_cast(*result < operand1); +} + +/// @brief Returns whether or not the input is prime +bool IsPrime(uint64_t n); + +/// @brief Generates a list of num_primes primes in the range [2^(bit_size), +// 2^(bit_size+1)]. Ensures each prime q satisfies +// q % (2*ntt_size+1)) == 1 +/// @param[in] num_primes Number of primes to generate +/// @param[in] bit_size Bit size of each prime +/// @param[in] prefer_small_primes When true, returns primes starting from +/// 2^(bit_size); when false, returns primes starting from 2^(bit_size+1) +/// @param[in] ntt_size N such that each prime q satisfies q % (2N) == 1. N must +/// be a power of two less than 2^bit_size. +std::vector GeneratePrimes(size_t num_primes, size_t bit_size, + bool prefer_small_primes, + size_t ntt_size = 1); + +/// @brief Returns input mod modulus, computed via 64-bit Barrett reduction +/// @param[in] input +/// @param[in] modulus +/// @param[in] q_barr floor(2^64 / modulus) +template +uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr) { + HEXL_CHECK(modulus != 0, "modulus == 0"); + uint64_t q = MultiplyUInt64Hi<64>(input, q_barr); + uint64_t q_times_input = input - q * modulus; + if (OutputModFactor == 2) { + return q_times_input; + } else { + return (q_times_input >= modulus) ? 
q_times_input - modulus : q_times_input; + } +} + +/// @brief Returns x mod modulus, assuming x < InputModFactor * modulus +/// @param[in] x +/// @param[in] modulus also denoted q +/// @param[in] twice_modulus 2 * q; must not be nullptr if InputModFactor == 4 +/// or 8 +/// @param[in] four_times_modulus 4 * q; must not be nullptr if InputModFactor +/// == 8 +template +uint64_t ReduceMod(uint64_t x, uint64_t modulus, + const uint64_t* twice_modulus = nullptr, + const uint64_t* four_times_modulus = nullptr) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || + InputModFactor == 4 || InputModFactor == 8, + "InputModFactor should be 1, 2, 4, or 8"); + if (InputModFactor == 1) { + return x; + } + if (InputModFactor == 2) { + if (x >= modulus) { + x -= modulus; + } + return x; + } + if (InputModFactor == 4) { + HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr"); + if (x >= *twice_modulus) { + x -= *twice_modulus; + } + if (x >= modulus) { + x -= modulus; + } + return x; + } + if (InputModFactor == 8) { + HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr"); + HEXL_CHECK(four_times_modulus != nullptr, + "four_times_modulus should not be nullptr"); + + if (x >= *four_times_modulus) { + x -= *four_times_modulus; + } + if (x >= *twice_modulus) { + x -= *twice_modulus; + } + if (x >= modulus) { + x -= modulus; + } + return x; + } + HEXL_CHECK(false, "Should be unreachable"); + return x; +} + +/// @brief Returns Montgomery form of ab mod q, computed via the REDC algorithm, +/// also known as Montgomery reduction. +/// @param[in] r +/// @param[in] q with R = 2^r such that gcd(R, q) = 1. R > q. +/// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R. +/// @param[in] mod_R_msk take r last bits to apply mod R. +/// @param[in] T_hi of T = ab in the range [0, Rq − 1]. +/// @param[in] T_lo of T. +/// @return Unsigned long int in the range [0, q − 1] such that S ≡ TR^−1 mod q +template +inline uint64_t MontgomeryReduce(uint64_t T_hi, uint64_t T_lo, uint64_t q, + int r, uint64_t mod_R_msk, uint64_t inv_mod) { + HEXL_CHECK(BitShift == 64 || BitShift == 52, + "Unsupported BitShift " << BitShift); + HEXL_CHECK((1ULL << r) > static_cast(q), + "R value should be greater than q = " << static_cast(q)); + + uint64_t mq_hi; + uint64_t mq_lo; + + uint64_t m = ((T_lo & mod_R_msk) * inv_mod) & mod_R_msk; + MultiplyUInt64(m, q, &mq_hi, &mq_lo); + + if (BitShift == 52) { + mq_hi = (mq_hi << 12) | (mq_lo >> 52); + mq_lo &= (1ULL << 52) - 1; + } + + uint64_t t_hi; + uint64_t t_lo; + + // first 64bit block + t_lo = T_lo + mq_lo; + unsigned int carry = static_cast(t_lo < T_lo); + t_hi = T_hi + mq_hi + carry; + + t_hi = t_hi << (BitShift - r); + t_lo = t_lo >> r; + t_lo = t_hi + t_lo; + + return (t_lo >= q) ? 
(t_lo - q) : t_lo; +} + +/// @brief Hensel's Lemma for 2-adic numbers +/// Find solution for qX + 1 = 0 mod 2^r +/// @param[in] r +/// @param[in] q such that gcd(2, q) = 1 +/// @return Unsigned long int in [0, 2^r − 1] such that q*x ≡ −1 mod 2^r +inline uint64_t HenselLemma2adicRoot(uint32_t r, uint64_t q) { + uint64_t a_prev = 1; + uint64_t c = 2; + uint64_t mod_mask = 3; + + // Root: + // f(x) = qX + 1 and a_(0) = 1 then f(1) ≡ 0 mod 2 + // General Case: + // - a_(n) ≡ a_(n-1) mod 2^(n) + // => a_(n) = a_(n-1) + 2^(n)*t + // - Find 't' such that f(a_(n)) = 0 mod 2^(n+1) + // First case in for: + // - a_(1) ≡ 1 mod 2 or a_(1) = 1 + 2t + // - Find 't' so f(a_(1)) ≡ 0 mod 4 => q(1 + 2t) + 1 ≡ 0 mod 4 + for (uint64_t k = 2; k <= r; k++) { + uint64_t f = 0; + uint64_t t = 0; + uint64_t a = 0; + + do { + a = a_prev + c * t++; + f = q * a + 1ULL; + } while (f & mod_mask); // f(a) ≡ 0 mod 2^(k) + + // Update vars + mod_mask = mod_mask * 2 + 1ULL; + c *= 2; + a_prev = a; + } + + return a_prev; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/util/aligned-allocator.hpp b/hexl_v2/include/hexl/util/aligned-allocator.hpp new file mode 100644 index 00000000..d175c734 --- /dev/null +++ b/hexl_v2/include/hexl/util/aligned-allocator.hpp @@ -0,0 +1,110 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/allocator.hpp" +#include "hexl/util/defines.hpp" + +namespace intel { +namespace hexl { + +/// @brief Allocater implementation using malloc and free +struct MallocStrategy : AllocatorBase { + void* allocate(size_t bytes_count) final { return std::malloc(bytes_count); } + + void deallocate(void* p, size_t n) final { + HEXL_UNUSED(n); + std::free(p); + } +}; + +using AllocatorStrategyPtr = std::shared_ptr; +extern AllocatorStrategyPtr mallocStrategy; + +/// @brief Allocates memory aligned to Alignment-byte sized boundaries +/// @details Alignment must be a power of two +template +class AlignedAllocator { + public: + template + friend class AlignedAllocator; + + using value_type = T; + + explicit AlignedAllocator(AllocatorStrategyPtr strategy = nullptr) noexcept + : m_alloc_impl((strategy != nullptr) ? 
strategy : mallocStrategy) {} + + AlignedAllocator(const AlignedAllocator& src) = default; + AlignedAllocator& operator=(const AlignedAllocator& src) = default; + + template + AlignedAllocator(const AlignedAllocator& src) + : m_alloc_impl(src.m_alloc_impl) {} + + ~AlignedAllocator() {} + + template + struct rebind { + using other = AlignedAllocator; + }; + + bool operator==(const AlignedAllocator&) { return true; } + + bool operator!=(const AlignedAllocator&) { return false; } + + /// @brief Allocates \p n elements aligned to Alignment-byte boundaries + /// @return Pointer to the aligned allocated memory + T* allocate(size_t n) { + if (!IsPowerOfTwo(Alignment)) { + return nullptr; + } + // Allocate enough space to ensure the alignment can be satisfied + size_t buffer_size = sizeof(T) * n + Alignment; + // Additionally, allocate a prefix to store the memory location of the + // unaligned buffer + size_t alloc_size = buffer_size + sizeof(void*); + void* buffer = m_alloc_impl->allocate(alloc_size); + if (!buffer) { + return nullptr; + } + + // Reserve first location for pointer to originally-allocated space + void* aligned_buffer = static_cast(buffer) + sizeof(void*); + std::align(Alignment, sizeof(T) * n, aligned_buffer, buffer_size); + if (!aligned_buffer) { + return nullptr; + } + + // Store allocated buffer address at aligned_buffer - sizeof(void*). + void* store_buffer_addr = + static_cast(aligned_buffer) - sizeof(void*); + *(static_cast(store_buffer_addr)) = buffer; + + return static_cast(aligned_buffer); + } + + void deallocate(T* p, size_t n) { + if (!p) { + return; + } + void* store_buffer_addr = (reinterpret_cast(p) - sizeof(void*)); + void* free_address = *(static_cast(store_buffer_addr)); + m_alloc_impl->deallocate(free_address, n); + } + + private: + AllocatorStrategyPtr m_alloc_impl; +}; + +/// @brief 64-byte aligned memory allocator +template +using AlignedVector64 = std::vector >; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/util/allocator.hpp b/hexl_v2/include/hexl/util/allocator.hpp new file mode 100644 index 00000000..5f4a7a31 --- /dev/null +++ b/hexl_v2/include/hexl/util/allocator.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Base class for custom memory allocator +struct AllocatorBase { + virtual ~AllocatorBase() noexcept {} + + /// @brief Allocates byte_count bytes of memory + /// @param[in] bytes_count Number of bytes to allocate + /// @return A pointer to the allocated memory + virtual void* allocate(size_t bytes_count) = 0; + + /// @brief Deallocate memory + /// @param[in] p Pointer to memory to deallocate + /// @param[in] n Number of bytes to deallocate + virtual void deallocate(void* p, size_t n) = 0; +}; + +/// @brief Helper memory allocation struct which delegates implementation to +/// AllocatorImpl +template +struct AllocatorInterface : public AllocatorBase { + /// @brief Override interface and delegate implementation to AllocatorImpl + void* allocate(size_t bytes_count) override { + return static_cast(this)->allocate_impl(bytes_count); + } + + /// @brief Override interface and delegate implementation to AllocatorImpl + void deallocate(void* p, size_t n) override { + static_cast(this)->deallocate_impl(p, n); + } + + private: + // in case AllocatorImpl doesn't provide implementations, use default null + // behavior + void* allocate_impl(size_t bytes_count) { + HEXL_UNUSED(bytes_count); + return 
nullptr; + } + void deallocate_impl(void* p, size_t n) { + HEXL_UNUSED(p); + HEXL_UNUSED(n); + } +}; +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/util/check.hpp b/hexl_v2/include/hexl/util/check.hpp new file mode 100644 index 00000000..386eba89 --- /dev/null +++ b/hexl_v2/include/hexl/util/check.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/types.hpp" + +// Create logging/debug macros with no run-time overhead unless HEXL_DEBUG is +// enabled +#ifdef HEXL_DEBUG +#include "hexl/logging/logging.hpp" + +/// @brief If input condition is not true, logs the expression and throws an +/// error +/// @param[in] cond A boolean indication the condition +/// @param[in] expr The expression to be logged +#define HEXL_CHECK(cond, expr) \ + if (!(cond)) { \ + LOG(ERROR) << expr << " in function: " << __FUNCTION__ \ + << " in file: " __FILE__ << ":" << __LINE__; \ + throw std::runtime_error("Error. Check log output"); \ + } + +/// @brief If input has an element >= bound, logs the expression and throws an +/// error +/// @param[in] arg Input container which supports the [] operator. +/// @param[in] n Size of input +/// @param[in] bound Upper bound on the input +/// @param[in] expr The expression to be logged +#define HEXL_CHECK_BOUNDS(arg, n, bound, expr) \ + for (size_t hexl_check_idx = 0; hexl_check_idx < n; ++hexl_check_idx) { \ + HEXL_CHECK((arg)[hexl_check_idx] < bound, expr); \ + } + +#else // HEXL_DEBUG=OFF + +#define HEXL_CHECK(cond, expr) \ + {} +#define HEXL_CHECK_BOUNDS(...) \ + {} + +#endif // HEXL_DEBUG diff --git a/hexl_v2/include/hexl/util/clang.hpp b/hexl_v2/include/hexl/util/clang.hpp new file mode 100644 index 00000000..958bea7b --- /dev/null +++ b/hexl_v2/include/hexl/util/clang.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/types.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_USE_CLANG +// Return x * y as 128-bit integer +// Correctness if x * y < 128 bits +inline uint128_t MultiplyUInt64(uint64_t x, uint64_t y) { + return uint128_t(x) * uint128_t(y); +} + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint128_t n = (static_cast(input_hi) << 64) | + (static_cast(input_lo)); + + return n % modulus; + // TODO(fboemer): actually use barrett reduction if performance-critical +} + +// Returns low 64bit of 128b/64b where x1=high 64b, x0=low 64b +inline uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y) { + uint128_t n = + (static_cast(x1) << 64) | (static_cast(x0)); + uint128_t q = n / y; + + return static_cast(q); +} + +// Multiplies x * y as 128-bit integer. 
+// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + uint128_t prod = MultiplyUInt64(x, y); + *prod_hi = static_cast(prod >> 64); + *prod_lo = static_cast(prod); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + uint128_t product = static_cast(x) * y; + return static_cast(product >> BitShift); +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + return static_cast(std::log2l(input)); +} + +#define HEXL_LOOP_UNROLL_4 _Pragma("clang loop unroll_count(4)") +#define HEXL_LOOP_UNROLL_8 _Pragma("clang loop unroll_count(8)") + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/util/compiler.hpp b/hexl_v2/include/hexl/util/compiler.hpp new file mode 100644 index 00000000..7dd077df --- /dev/null +++ b/hexl_v2/include/hexl/util/compiler.hpp @@ -0,0 +1,14 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/util/defines.hpp" + +#ifdef HEXL_USE_MSVC +#include "hexl/util/msvc.hpp" +#elif defined HEXL_USE_GNU +#include "hexl/util/gcc.hpp" +#elif defined HEXL_USE_CLANG +#include "hexl/util/clang.hpp" +#endif diff --git a/hexl_v2/include/hexl/util/defines.hpp b/hexl_v2/include/hexl/util/defines.hpp new file mode 100644 index 00000000..93db376e --- /dev/null +++ b/hexl_v2/include/hexl/util/defines.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +/* #undef HEXL_USE_MSVC */ +#define HEXL_USE_GNU +/* #undef HEXL_USE_CLANG */ + +/* #undef HEXL_DEBUG */ + +// Avoid unused variable warnings +#define HEXL_UNUSED(x) (void)(x) diff --git a/hexl_v2/include/hexl/util/gcc.hpp b/hexl_v2/include/hexl/util/gcc.hpp new file mode 100644 index 00000000..828e3836 --- /dev/null +++ b/hexl_v2/include/hexl/util/gcc.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/types.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_USE_GNU +// Return x * y as 128-bit integer +// Correctness if x * y < 128 bits +inline uint128_t MultiplyUInt64(uint64_t x, uint64_t y) { + return uint128_t(x) * uint128_t(y); +} + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint128_t n = (static_cast(input_hi) << 64) | + (static_cast(input_lo)); + + return static_cast(n % modulus); + // TODO(fboemer): actually use barrett reduction if performance-critical +} + +// Returns low 64bit of 128b/64b where x1=high 64b, x0=low 64b +inline uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y) { + uint128_t n = + (static_cast(x1) << 64) | (static_cast(x0)); + uint128_t q = n / y; + + return static_cast(q); +} + +// Multiplies x * y as 128-bit integer. 
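For orientation only (not part of the patch), here is a minimal sketch of how the GCC-path 128-bit helpers above compose into a modular multiplication. The include path follows the `hexl_v2/include` layout added by this diff; the operand values and the Mersenne-prime modulus are arbitrary illustrations.

```cpp
// Sketch: (x * y) mod p using MultiplyUInt64 and BarrettReduce128 from the
// GCC path above. Assumes the hexl_v2 include directory is on the include path.
#include <cstdint>
#include <iostream>

#include "hexl/util/compiler.hpp"  // dispatches to gcc.hpp when HEXL_USE_GNU is defined

int main() {
  uint64_t x = 0x0123456789abcdefULL;
  uint64_t y = 0xfedcba9876543210ULL;
  uint64_t p = (1ULL << 61) - 1;  // Mersenne prime 2^61 - 1, used here as the modulus

  uint64_t prod_hi = 0;
  uint64_t prod_lo = 0;
  intel::hexl::MultiplyUInt64(x, y, &prod_hi, &prod_lo);  // full 128-bit product
  uint64_t r = intel::hexl::BarrettReduce128(prod_hi, prod_lo, p);  // reduce mod p

  std::cout << "x * y mod p = " << r << std::endl;
  return 0;
}
```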
+// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + uint128_t prod = MultiplyUInt64(x, y); + *prod_hi = static_cast(prod >> 64); + *prod_lo = static_cast(prod); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + uint128_t product = MultiplyUInt64(x, y); + return static_cast(product >> BitShift); +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + return static_cast(std::log2l(input)); +} + +#define HEXL_LOOP_UNROLL_4 _Pragma("GCC unroll 4") +#define HEXL_LOOP_UNROLL_8 _Pragma("GCC unroll 8") + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/util/msvc.hpp b/hexl_v2/include/hexl/util/msvc.hpp new file mode 100644 index 00000000..0ada2d45 --- /dev/null +++ b/hexl_v2/include/hexl/util/msvc.hpp @@ -0,0 +1,289 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#ifdef HEXL_USE_MSVC + +#define NOMINMAX // Avoid errors with std::min/std::max +#undef min +#undef max + +#include +#include +#include + +#include + +#include "hexl/util/check.hpp" + +#pragma intrinsic(_addcarry_u64, _BitScanReverse64, _subborrow_u64, _udiv128, \ + _umul128) + +#undef TRUE +#undef FALSE + +namespace intel { +namespace hexl { + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint64_t remainder; + _udiv128(input_hi, input_lo, modulus, &remainder); + + return remainder; +} + +// Multiplies x * y as 128-bit integer. +// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + *prod_lo = _umul128(x, y, prod_hi); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid BitShift " << BitShift << "; expected 52 or 64"); + uint64_t prod_hi; + uint64_t prod_lo = _umul128(x, y, &prod_hi); + uint64_t result_hi; + uint64_t result_lo; + RightShift128(&result_hi, &result_lo, prod_hi, prod_lo, BitShift); + return result_lo; +} + +/// @brief Computes Left Shift op as 128-bit unsigned integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +/// @param[in] op_hi Stores high 64 bits of input +/// @param[in] op_lo Stores low 64 bits of input +inline void LeftShift128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op_hi, const uint64_t op_lo, + const uint64_t shift_value) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + HEXL_CHECK(shift_value <= 128, + "shift_value cannot be greater than 128 " << shift_value); + + if (shift_value == 0) { + *result_hi = op_hi; + *result_lo = op_lo; + } else if (shift_value == 64) { + *result_hi = op_lo; + *result_lo = 0ULL; + } else if (shift_value == 128) { + *result_hi = 0ULL; + *result_lo = 0ULL; + } else if (shift_value >= 1 && shift_value <= 63) { + *result_hi = (op_hi << shift_value) | (op_lo >> (64 - shift_value)); + *result_lo = op_lo << shift_value; + } else if (shift_value >= 65 && shift_value < 128) { + 
*result_hi = op_lo << (shift_value - 64); + *result_lo = 0ULL; + } +} + +/// @brief Computes Right Shift op as 128-bit unsigned integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +/// @param[in] op_hi Stores high 64 bits of input +/// @param[in] op_lo Stores low 64 bits of input +inline void RightShift128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op_hi, const uint64_t op_lo, + const uint64_t shift_value) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + HEXL_CHECK(shift_value <= 128, + "shift_value cannot be greater than 128 " << shift_value); + + if (shift_value == 0) { + *result_hi = op_hi; + *result_lo = op_lo; + } else if (shift_value == 64) { + *result_hi = 0ULL; + *result_lo = op_hi; + } else if (shift_value == 128) { + *result_hi = 0ULL; + *result_lo = 0ULL; + } else if (shift_value >= 1 && shift_value <= 63) { + *result_hi = op_hi >> shift_value; + *result_lo = (op_hi << (64 - shift_value)) | (op_lo >> shift_value); + } else if (shift_value >= 65 && shift_value < 128) { + *result_hi = 0ULL; + *result_lo = op_hi >> (shift_value - 64); + } +} + +/// @brief Adds op1 + op2 as 128-bit integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +inline void AddWithCarry128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op1_hi, const uint64_t op1_lo, + const uint64_t op2_hi, const uint64_t op2_lo) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + + // first 64bit block + *result_lo = op1_lo + op2_lo; + unsigned char carry = static_cast(*result_lo < op1_lo); + + // second 64bit block + _addcarry_u64(carry, op1_hi, op2_hi, result_hi); +} + +/// @brief Subtracts op1 - op2 as 128-bit integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +inline void SubWithCarry128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op1_hi, const uint64_t op1_lo, + const uint64_t op2_hi, const uint64_t op2_lo) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + + unsigned char borrow; + + // first 64bit block + *result_lo = op1_lo - op2_lo; + borrow = static_cast(op2_lo > op1_lo); + + // second 64bit block + _subborrow_u64(borrow, op1_hi, op2_hi, result_hi); +} + +/// @brief Computes and returns significant bit count +/// @param[in] value Input element at most 128 bits long +inline uint64_t SignificantBitLength(const uint64_t* value) { + HEXL_CHECK(value != nullptr, "Require value != nullptr"); + + unsigned long count = 0; // NOLINT(runtime/int) + + // second 64bit block + _BitScanReverse64(&count, *(value + 1)); + if (count >= 0 && *(value + 1) > 0) { + return static_cast(count) + 1 + 64; + } + + // first 64bit block + _BitScanReverse64(&count, *value); + if (count >= 0 && *(value) > 0) { + return static_cast(count) + 1; + } + return 0; +} + +/// @brief Checks if input is negative number +/// @param[in] input Input element to check for sign +inline bool CheckSign(const uint64_t* input) { + HEXL_CHECK(input != nullptr, "Require input != nullptr"); + + uint64_t input_temp[2]{0, 0}; + RightShift128(&input_temp[1], &input_temp[0], input[1], input[0], 127); + return (input_temp[0] == 1); +} + +/// @brief Divides 
numerator by denominator +/// @param[out] quotient Stores quotient as two 64-bit blocks after division +/// @param[in] numerator +/// @param[in] denominator +inline void DivideUInt128UInt64(uint64_t* quotient, const uint64_t* numerator, + const uint64_t denominator) { + HEXL_CHECK(quotient != nullptr, "Require quotient != nullptr"); + HEXL_CHECK(numerator != nullptr, "Require numerator != nullptr"); + HEXL_CHECK(denominator != 0, "denominator cannot be 0 " << denominator); + + // get bit count of divisor + uint64_t numerator_bits = SignificantBitLength(numerator); + const uint64_t numerator_bits_const = numerator_bits; + const uint64_t uint_128_bit = 128ULL; + + uint64_t MASK[2]{0x0000000000000001, 0x0000000000000000}; + uint64_t remainder[2]{0, 0}; + uint64_t quotient_temp[2]{0, 0}; + uint64_t denominator_temp[2]{denominator, 0}; + + quotient[0] = numerator[0]; + quotient[1] = numerator[1]; + + // align numerator + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], + (uint_128_bit - numerator_bits_const)); + + while (numerator_bits) { + // if remainder is negative + if (CheckSign(remainder)) { + LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1); + RightShift128("ient_temp[1], "ient_temp[0], quotient[1], + quotient[0], (uint_128_bit - 1)); + remainder[0] = remainder[0] | quotient_temp[0]; + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1); + // remainder=remainder+denominator_temp + AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } else { // if remainder is positive + LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1); + RightShift128("ient_temp[1], "ient_temp[0], quotient[1], + quotient[0], (uint_128_bit - 1)); + remainder[0] = remainder[0] | quotient_temp[0]; + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1); + // remainder=remainder-denominator_temp + SubWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } + + // if remainder is positive set MSB of quotient[0]=1 + if (!CheckSign(remainder)) { + MASK[0] = 0x0000000000000001; + MASK[1] = 0x0000000000000000; + LeftShift128(&MASK[1], &MASK[0], MASK[1], MASK[0], + (uint_128_bit - numerator_bits_const)); + quotient[0] = quotient[0] | MASK[0]; + quotient[1] = quotient[1] | MASK[1]; + } + quotient_temp[0] = 0; + quotient_temp[1] = 0; + numerator_bits--; + } + + if (CheckSign(remainder)) { + // remainder=remainder+denominator_temp + AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } + RightShift128("ient[1], "ient[0], quotient[1], quotient[0], + (uint_128_bit - numerator_bits_const)); +} + +/// @brief Returns low of dividing numerator by denominator +/// @param[in] numerator_hi Stores high 64 bit of numerator +/// @param[in] numerator_lo Stores low 64 bit of numerator +/// @param[in] denominator Stores denominator +inline uint64_t DivideUInt128UInt64Lo(const uint64_t numerator_hi, + const uint64_t numerator_lo, + const uint64_t denominator) { + uint64_t numerator[2]{numerator_lo, numerator_hi}; + uint64_t quotient[2]{0, 0}; + + DivideUInt128UInt64(quotient, numerator, denominator); + return quotient[0]; +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + unsigned long index{0}; // NOLINT(runtime/int) + _BitScanReverse64(&index, input); + return index; +} + +#define HEXL_LOOP_UNROLL_4 \ + {} +#define HEXL_LOOP_UNROLL_8 
\ + {} + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/include/hexl/util/types.hpp b/hexl_v2/include/hexl/util/types.hpp new file mode 100644 index 00000000..2d2d8551 --- /dev/null +++ b/hexl_v2/include/hexl/util/types.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/defines.hpp" + +#if defined(HEXL_USE_GNU) || defined(HEXL_USE_CLANG) +__extension__ typedef __int128 int128_t; +__extension__ typedef unsigned __int128 uint128_t; +#endif diff --git a/hexl_v2/include/hexl/util/util.hpp b/hexl_v2/include/hexl/util/util.hpp new file mode 100644 index 00000000..bf878a98 --- /dev/null +++ b/hexl_v2/include/hexl/util/util.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +#undef TRUE // MSVC defines TRUE +#undef FALSE // MSVC defines FALSE + +/// @enum CMPINT +/// @brief Represents binary operations between two boolean values +enum class CMPINT { + EQ = 0, ///< Equal + LT = 1, ///< Less than + LE = 2, ///< Less than or equal + FALSE = 3, ///< False + NE = 4, ///< Not equal + NLT = 5, ///< Not less than + NLE = 6, ///< Not less than or equal + TRUE = 7 ///< True +}; + +/// @brief Returns the logical negation of a binary operation +/// @param[in] cmp The binary operation to negate +inline CMPINT Not(CMPINT cmp) { + switch (cmp) { + case CMPINT::EQ: + return CMPINT::NE; + case CMPINT::LT: + return CMPINT::NLT; + case CMPINT::LE: + return CMPINT::NLE; + case CMPINT::FALSE: + return CMPINT::TRUE; + case CMPINT::NE: + return CMPINT::EQ; + case CMPINT::NLT: + return CMPINT::LT; + case CMPINT::NLE: + return CMPINT::LE; + case CMPINT::TRUE: + return CMPINT::FALSE; + default: + return CMPINT::FALSE; + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_v2/lib/cmake/hexl-1.2.5/HEXLConfig.cmake b/hexl_v2/lib/cmake/hexl-1.2.5/HEXLConfig.cmake new file mode 100644 index 00000000..d3c012b5 --- /dev/null +++ b/hexl_v2/lib/cmake/hexl-1.2.5/HEXLConfig.cmake @@ -0,0 +1,59 @@ +# Copyright (C) 2020 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# This will define the following variables: +# +# HEXL_FOUND - True if the system has the Intel HEXL library +# HEXL_VERSION - The full major.minor.patch version number +# HEXL_VERSION_MAJOR - The major version number +# HEXL_VERSION_MINOR - The minor version number +# HEXL_VERSION_PATCH - The patch version number + + +####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() ####### +####### Any changes to this file will be overwritten by the next CMake run #### +####### The input file was HEXLConfig.cmake.in ######## + +get_filename_component(PACKAGE_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) + +macro(set_and_check _var _file) + set(${_var} "${_file}") + if(NOT EXISTS "${_file}") + message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !") + endif() +endmacro() + +macro(check_required_components _NAME) + foreach(comp ${${_NAME}_FIND_COMPONENTS}) + if(NOT ${_NAME}_${comp}_FOUND) + if(${_NAME}_FIND_REQUIRED_${comp}) + set(${_NAME}_FOUND FALSE) + endif() + endif() + endforeach() +endmacro() + +#################################################################################### + +include(CMakeFindDependencyMacro) +find_package(CpuFeatures CONFIG) +if(NOT CpuFeatures_FOUND) + message(WARNING "Could not find pre-installed CpuFeatures; 
using CpuFeatures packaged with HEXL") +endif() + +include(${CMAKE_CURRENT_LIST_DIR}/HEXLTargets.cmake) + +# Defines HEXL_FOUND: If Intel HEXL library was found +if(TARGET HEXL::hexl) + set(HEXL_FOUND TRUE) + message(STATUS "Intel HEXL found") +else() + message(STATUS "Intel HEXL not found") +endif() + +set(HEXL_VERSION "1.2.5") +set(HEXL_VERSION_MAJOR "1") +set(HEXL_VERSION_MINOR "2") +set(HEXL_VERSION_PATCH "5") + +set(HEXL_DEBUG "OFF") diff --git a/hexl_v2/lib/cmake/hexl-1.2.5/HEXLConfigVersion.cmake b/hexl_v2/lib/cmake/hexl-1.2.5/HEXLConfigVersion.cmake new file mode 100644 index 00000000..98b46110 --- /dev/null +++ b/hexl_v2/lib/cmake/hexl-1.2.5/HEXLConfigVersion.cmake @@ -0,0 +1,88 @@ +# This is a basic version file for the Config-mode of find_package(). +# It is used by write_basic_package_version_file() as input file for configure_file() +# to create a version-file which can be installed along a config.cmake file. +# +# The created file sets PACKAGE_VERSION_EXACT if the current version string and +# the requested version string are exactly the same and it sets +# PACKAGE_VERSION_COMPATIBLE if the current version is equal to the requested version. +# The tweak version component is ignored. +# The variable CVF_VERSION must be set before calling configure_file(). + + +if (PACKAGE_FIND_VERSION_RANGE) + message(AUTHOR_WARNING + "`find_package()` specify a version range but the version strategy " + "(ExactVersion) of the module `${PACKAGE_FIND_NAME}` is incompatible " + "with this request. Only the lower endpoint of the range will be used.") +endif() + +set(PACKAGE_VERSION "1.2.5") + +if("1.2.5" MATCHES "^([0-9]+)\\.([0-9]+)\\.([0-9]+)") # strip the tweak version + set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(CVF_VERSION_MINOR "${CMAKE_MATCH_2}") + set(CVF_VERSION_PATCH "${CMAKE_MATCH_3}") + + if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}") + endif() + if(NOT CVF_VERSION_MINOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" CVF_VERSION_MINOR "${CVF_VERSION_MINOR}") + endif() + if(NOT CVF_VERSION_PATCH VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" CVF_VERSION_PATCH "${CVF_VERSION_PATCH}") + endif() + + set(CVF_VERSION_NO_TWEAK "${CVF_VERSION_MAJOR}.${CVF_VERSION_MINOR}.${CVF_VERSION_PATCH}") +else() + set(CVF_VERSION_NO_TWEAK "1.2.5") +endif() + +if(PACKAGE_FIND_VERSION MATCHES "^([0-9]+)\\.([0-9]+)\\.([0-9]+)") # strip the tweak version + set(REQUESTED_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(REQUESTED_VERSION_MINOR "${CMAKE_MATCH_2}") + set(REQUESTED_VERSION_PATCH "${CMAKE_MATCH_3}") + + if(NOT REQUESTED_VERSION_MAJOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" REQUESTED_VERSION_MAJOR "${REQUESTED_VERSION_MAJOR}") + endif() + if(NOT REQUESTED_VERSION_MINOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" REQUESTED_VERSION_MINOR "${REQUESTED_VERSION_MINOR}") + endif() + if(NOT REQUESTED_VERSION_PATCH VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" REQUESTED_VERSION_PATCH "${REQUESTED_VERSION_PATCH}") + endif() + + set(REQUESTED_VERSION_NO_TWEAK + "${REQUESTED_VERSION_MAJOR}.${REQUESTED_VERSION_MINOR}.${REQUESTED_VERSION_PATCH}") +else() + set(REQUESTED_VERSION_NO_TWEAK "${PACKAGE_FIND_VERSION}") +endif() + +if(REQUESTED_VERSION_NO_TWEAK STREQUAL CVF_VERSION_NO_TWEAK) + set(PACKAGE_VERSION_COMPATIBLE TRUE) +else() + set(PACKAGE_VERSION_COMPATIBLE FALSE) +endif() + +if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION) + set(PACKAGE_VERSION_EXACT TRUE) +endif() + + +# if the installed project requested no 
architecture check, don't perform the check +if("FALSE") + return() +endif() + +# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it: +if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "") + return() +endif() + +# check that the installed version has the same 32/64bit-ness as the one which is currently searching: +if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8") + math(EXPR installedBits "8 * 8") + set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)") + set(PACKAGE_VERSION_UNSUITABLE TRUE) +endif() diff --git a/hexl_v2/lib/cmake/hexl-1.2.5/HEXLTargets-release.cmake b/hexl_v2/lib/cmake/hexl-1.2.5/HEXLTargets-release.cmake new file mode 100644 index 00000000..c736f1c1 --- /dev/null +++ b/hexl_v2/lib/cmake/hexl-1.2.5/HEXLTargets-release.cmake @@ -0,0 +1,19 @@ +#---------------------------------------------------------------- +# Generated CMake target import file for configuration "Release". +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Import target "HEXL::hexl" for configuration "Release" +set_property(TARGET HEXL::hexl APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE) +set_target_properties(HEXL::hexl PROPERTIES + IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libhexl.so.1.2.5" + IMPORTED_SONAME_RELEASE "libhexl.so.1.2.5" + ) + +list(APPEND _IMPORT_CHECK_TARGETS HEXL::hexl ) +list(APPEND _IMPORT_CHECK_FILES_FOR_HEXL::hexl "${_IMPORT_PREFIX}/lib/libhexl.so.1.2.5" ) + +# Commands beyond this point should not need to know the version. +set(CMAKE_IMPORT_FILE_VERSION) diff --git a/hexl_v2/lib/cmake/hexl-1.2.5/HEXLTargets.cmake b/hexl_v2/lib/cmake/hexl-1.2.5/HEXLTargets.cmake new file mode 100644 index 00000000..7132d40a --- /dev/null +++ b/hexl_v2/lib/cmake/hexl-1.2.5/HEXLTargets.cmake @@ -0,0 +1,100 @@ +# Generated by CMake + +if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.6) + message(FATAL_ERROR "CMake >= 2.6.0 required") +endif() +cmake_policy(PUSH) +cmake_policy(VERSION 2.6...3.20) +#---------------------------------------------------------------- +# Generated CMake target import file. +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Protect against multiple inclusion, which would fail when already imported targets are added once more. +set(_targetsDefined) +set(_targetsNotDefined) +set(_expectedTargets) +foreach(_expectedTarget HEXL::hexl) + list(APPEND _expectedTargets ${_expectedTarget}) + if(NOT TARGET ${_expectedTarget}) + list(APPEND _targetsNotDefined ${_expectedTarget}) + endif() + if(TARGET ${_expectedTarget}) + list(APPEND _targetsDefined ${_expectedTarget}) + endif() +endforeach() +if("${_targetsDefined}" STREQUAL "${_expectedTargets}") + unset(_targetsDefined) + unset(_targetsNotDefined) + unset(_expectedTargets) + set(CMAKE_IMPORT_FILE_VERSION) + cmake_policy(POP) + return() +endif() +if(NOT "${_targetsDefined}" STREQUAL "") + message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_targetsDefined}\nTargets not yet defined: ${_targetsNotDefined}\n") +endif() +unset(_targetsDefined) +unset(_targetsNotDefined) +unset(_expectedTargets) + + +# Compute the installation prefix relative to this file. 
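# Worked illustration of the prefix computation that follows, assuming the
# install prefix recorded in lib/pkgconfig/hexl.pc for this tree:
#   CMAKE_CURRENT_LIST_FILE = /home/eidf018/eidf018/s1820742psd/hexl/hexl_v2/lib/cmake/hexl-1.2.5/HEXLTargets.cmake
#   after the 1st get_filename_component(... PATH): .../hexl_v2/lib/cmake/hexl-1.2.5
#   after the 2nd: .../hexl_v2/lib/cmake
#   after the 3rd: .../hexl_v2/lib
#   after the 4th: /home/eidf018/eidf018/s1820742psd/hexl/hexl_v2   (= _IMPORT_PREFIX)
# The imported HEXL::hexl target then resolves its headers and libhexl.so
# relative to this prefix, which keeps the install tree relocatable.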
+get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +if(_IMPORT_PREFIX STREQUAL "/") + set(_IMPORT_PREFIX "") +endif() + +# Create imported target HEXL::hexl +add_library(HEXL::hexl SHARED IMPORTED) + +set_target_properties(HEXL::hexl PROPERTIES + INTERFACE_COMPILE_OPTIONS "-Wno-unknown-warning;-Wno-unknown-warning-option" + INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include" + INTERFACE_LINK_LIBRARIES "OpenMP::OpenMP_CXX" +) + +if(CMAKE_VERSION VERSION_LESS 2.8.12) + message(FATAL_ERROR "This file relies on consumers using CMake 2.8.12 or greater.") +endif() + +# Load information for each installed configuration. +get_filename_component(_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) +file(GLOB CONFIG_FILES "${_DIR}/HEXLTargets-*.cmake") +foreach(f ${CONFIG_FILES}) + include(${f}) +endforeach() + +# Cleanup temporary variables. +set(_IMPORT_PREFIX) + +# Loop over all imported files and verify that they actually exist +foreach(target ${_IMPORT_CHECK_TARGETS} ) + foreach(file ${_IMPORT_CHECK_FILES_FOR_${target}} ) + if(NOT EXISTS "${file}" ) + message(FATAL_ERROR "The imported target \"${target}\" references the file + \"${file}\" +but this file does not exist. Possible reasons include: +* The file was deleted, renamed, or moved to another location. +* An install or uninstall procedure did not complete successfully. +* The installation package was faulty and contained + \"${CMAKE_CURRENT_LIST_FILE}\" +but not all the files it references. +") + endif() + endforeach() + unset(_IMPORT_CHECK_FILES_FOR_${target}) +endforeach() +unset(_IMPORT_CHECK_TARGETS) + +# This file does not depend on other imported targets which have +# been exported from the same project but in a separate export set. + +# Commands beyond this point should not need to know the version. 
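For reference, a hypothetical downstream `CMakeLists.txt` that consumes this install tree might look as follows. The package name `HEXL`, the version `1.2.5`, and the `HEXL::hexl` and `OpenMP::OpenMP_CXX` targets come from the config files above; the project name, install path, and source file are placeholders.

```cmake
# Hypothetical consumer of the hexl_v2 install tree produced by the README steps.
cmake_minimum_required(VERSION 3.13)
project(hexl_consumer CXX)

# Configure with e.g. -DCMAKE_PREFIX_PATH=/path/to/install/hexl_v2 so that
# lib/cmake/hexl-1.2.5/HEXLConfig.cmake is found.
find_package(HEXL 1.2.5 REQUIRED)

# The imported HEXL::hexl target's link interface names OpenMP::OpenMP_CXX,
# so the consumer resolves OpenMP itself before linking.
find_package(OpenMP REQUIRED)

add_executable(my_app main.cpp)
target_link_libraries(my_app PRIVATE HEXL::hexl)
```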
+set(CMAKE_IMPORT_FILE_VERSION) +cmake_policy(POP) diff --git a/hexl_v2/lib/libeasyloggingpp.a b/hexl_v2/lib/libeasyloggingpp.a new file mode 100644 index 00000000..8570c76b Binary files /dev/null and b/hexl_v2/lib/libeasyloggingpp.a differ diff --git a/hexl_v2/lib/libhexl.so b/hexl_v2/lib/libhexl.so new file mode 120000 index 00000000..af5173f3 --- /dev/null +++ b/hexl_v2/lib/libhexl.so @@ -0,0 +1 @@ +libhexl.so.1.2.5 \ No newline at end of file diff --git a/hexl_v2/lib/libhexl.so.1.2.5 b/hexl_v2/lib/libhexl.so.1.2.5 new file mode 100644 index 00000000..e593bd6c Binary files /dev/null and b/hexl_v2/lib/libhexl.so.1.2.5 differ diff --git a/hexl_v2/lib/libhexl_debug.so b/hexl_v2/lib/libhexl_debug.so new file mode 120000 index 00000000..c54310f0 --- /dev/null +++ b/hexl_v2/lib/libhexl_debug.so @@ -0,0 +1 @@ +libhexl_debug.so.1.2.5 \ No newline at end of file diff --git a/hexl_v2/lib/libhexl_debug.so.1.2.5 b/hexl_v2/lib/libhexl_debug.so.1.2.5 new file mode 100644 index 00000000..26fe0bae Binary files /dev/null and b/hexl_v2/lib/libhexl_debug.so.1.2.5 differ diff --git a/hexl_v2/lib/pkgconfig/hexl.pc b/hexl_v2/lib/pkgconfig/hexl.pc new file mode 100644 index 00000000..c9a1e5d1 --- /dev/null +++ b/hexl_v2/lib/pkgconfig/hexl.pc @@ -0,0 +1,13 @@ +# Copyright (C) 2020 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +prefix=/home/eidf018/eidf018/s1820742psd/hexl/hexl_v2 +libdir=${prefix}/lib +includedir=${prefix}/include + +Name: Intel HEXL +Version: 1.2.5 +Description: Intel® HEXL is an open-source library which provides efficient implementations of integer arithmetic on Galois fields. + +Libs: -L${libdir} -lhexl +Cflags: -I${includedir} diff --git a/hexl_v3/include/hexl/eltwise/eltwise-add-mod.hpp b/hexl_v3/include/hexl/eltwise/eltwise-add-mod.hpp new file mode 100644 index 00000000..cb2df110 --- /dev/null +++ b/hexl_v3/include/hexl/eltwise/eltwise-add-mod.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Adds two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] operand2 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus); + +/// @brief Adds a vector and scalar elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to add. Each element must be less +/// than the modulus +/// @param[in] operand2 Scalar to add. Must be less +/// than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] + operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. 
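As a usage illustration (not part of the patch), the vector-vector overload declared above can be called as in the following sketch; the buffer contents and modulus are arbitrary, and all inputs are kept below the modulus as the documentation requires.

```cpp
// Sketch: element-wise modular addition with EltwiseAddMod.
#include <cstdint>
#include <vector>

#include "hexl/eltwise/eltwise-add-mod.hpp"

int main() {
  const uint64_t modulus = 769;
  std::vector<uint64_t> op1{1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<uint64_t> op2{1, 3, 5, 7, 2, 4, 6, 8};
  std::vector<uint64_t> result(op1.size(), 0);

  // result[i] = (op1[i] + op2[i]) mod 769
  intel::hexl::EltwiseAddMod(result.data(), op1.data(), op2.data(),
                             op1.size(), modulus);
  return 0;
}
```

With the pkg-config file from this install tree on `PKG_CONFIG_PATH`, a command along the lines of `g++ main.cpp $(pkg-config --cflags --libs hexl) -fopenmp` is expected to build it, though the exact flags depend on the local setup.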
+void EltwiseAddMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/eltwise/eltwise-cmp-add.hpp b/hexl_v3/include/hexl/eltwise/eltwise-cmp-add.hpp new file mode 100644 index 00000000..27e514ff --- /dev/null +++ b/hexl_v3/include/hexl/eltwise/eltwise-cmp-add.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional addition. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare; stores result +/// @param[in] n Number of elements in \p operand1 +/// @param[in] cmp Comparison operation +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to conditionally add +/// @details Computes result[i] = cmp(operand1[i], bound) ? operand1[i] + +/// diff : operand1[i] for all \f$i=0, ..., n-1\f$. +void EltwiseCmpAdd(uint64_t* result, const uint64_t* operand1, uint64_t n, + CMPINT cmp, uint64_t bound, uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp b/hexl_v3/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp new file mode 100644 index 00000000..07ba3d23 --- /dev/null +++ b/hexl_v3/include/hexl/eltwise/eltwise-cmp-sub-mod.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/util.hpp" + +namespace intel { +namespace hexl { + +/// @brief Computes element-wise conditional modular subtraction. +/// @param[out] result Stores the result +/// @param[in] operand1 Vector of elements to compare +/// @param[in] n Number of elements in \p operand1 +/// @param[in] modulus Modulus to reduce by +/// @param[in] cmp Comparison function +/// @param[in] bound Scalar to compare against +/// @param[in] diff Scalar to subtract by +/// @details Computes \p operand1[i] = (\p cmp(\p operand1, \p bound)) ? (\p +/// operand1 - \p diff) mod \p modulus : \p operand1 mod \p modulus for all i=0, +/// ..., n-1 +void EltwiseCmpSubMod(uint64_t* result, const uint64_t* operand1, uint64_t n, + uint64_t modulus, CMPINT cmp, uint64_t bound, + uint64_t diff); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/eltwise/eltwise-fma-mod.hpp b/hexl_v3/include/hexl/eltwise/eltwise-fma-mod.hpp new file mode 100644 index 00000000..03651a42 --- /dev/null +++ b/hexl_v3/include/hexl/eltwise/eltwise-fma-mod.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes fused multiply-add (\p arg1 * \p arg2 + \p arg3) mod \p +/// modulus element-wise, broadcasting scalars to vectors. +/// @param[out] result Stores the result +/// @param[in] arg1 Vector to multiply +/// @param[in] arg2 Scalar to multiply +/// @param[in] arg3 Vector to add. Will not add if \p arg3 == nullptr +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$ [2, 2^{61} - 1]\f$ +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * modulus). Must be 1, 2, 4, or 8. 
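For illustration only, a minimal sketch of calling the fused multiply-add documented above with `input_mod_factor = 1` (all inputs already reduced below the modulus); the values and modulus are arbitrary.

```cpp
// Sketch: result[i] = (arg1[i] * scalar + arg3[i]) mod modulus via EltwiseFMAMod.
#include <cstdint>
#include <vector>

#include "hexl/eltwise/eltwise-fma-mod.hpp"

int main() {
  const uint64_t modulus = 65537;
  const uint64_t scalar = 3;
  std::vector<uint64_t> arg1{10, 20, 30, 40};
  std::vector<uint64_t> arg3{5, 6, 7, 8};
  std::vector<uint64_t> result(arg1.size(), 0);

  // Pass nullptr for arg3 to skip the addition, as the parameter
  // documentation above notes.
  intel::hexl::EltwiseFMAMod(result.data(), arg1.data(), scalar, arg3.data(),
                             arg1.size(), modulus, /*input_mod_factor=*/1);
  return 0;
}
```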
+void EltwiseFMAMod(uint64_t* result, const uint64_t* arg1, uint64_t arg2, + const uint64_t* arg3, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/eltwise/eltwise-mult-mod.hpp b/hexl_v3/include/hexl/eltwise/eltwise-mult-mod.hpp new file mode 100644 index 00000000..e4d2dbd7 --- /dev/null +++ b/hexl_v3/include/hexl/eltwise/eltwise-mult-mod.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Multiplies two vectors elementwise with modular reduction +/// @param[in] result Result of element-wise multiplication +/// @param[in] operand1 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] operand2 Vector of elements to multiply. Each element must be +/// less than the modulus. +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be 1, 2 or 4. +/// @details Computes \p result[i] = (\p operand1[i] * \p operand2[i]) mod \p +/// modulus for i=0, ..., \p n - 1 +void EltwiseMultMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus, + uint64_t input_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/eltwise/eltwise-reduce-mod.hpp b/hexl_v3/include/hexl/eltwise/eltwise-reduce-mod.hpp new file mode 100644 index 00000000..c23abde2 --- /dev/null +++ b/hexl_v3/include/hexl/eltwise/eltwise-reduce-mod.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Performs elementwise modular reduction +/// @param[out] result Stores the result +/// @param[in] operand Data on which to compute the elementwise modular +/// reduction +/// @param[in] n Number of elements in operand +/// @param[in] modulus Modulus with which to perform modular reduction +/// @param[in] input_mod_factor Assumes input elements are in [0, +/// input_mod_factor * p) Must be modulus, 1, 2 or 4. input_mod_factor=modulus +/// means, input range is [0, p * p]. Barrett reduction will be used in this +/// case. input_mod_factor > output_mod_factor +/// @param[in] output_mod_factor output elements will be in [0, +/// output_mod_factor * modulus) Must be 1 or 2. For input_mod_factor=0, +/// output_mod_factor will be set to 1. +void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n, + uint64_t modulus, uint64_t input_mod_factor, + uint64_t output_mod_factor); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/eltwise/eltwise-sub-mod.hpp b/hexl_v3/include/hexl/eltwise/eltwise-sub-mod.hpp new file mode 100644 index 00000000..bd286e47 --- /dev/null +++ b/hexl_v3/include/hexl/eltwise/eltwise-sub-mod.hpp @@ -0,0 +1,40 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Subtracts two vectors elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Vector of elements to subtract. 
Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2[i]) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, uint64_t modulus); + +/// @brief Subtracts a scalar from a vector elementwise with modular reduction +/// @param[out] result Stores result +/// @param[in] operand1 Vector of elements to subtract from. Each element must +/// be less than the modulus +/// @param[in] operand2 Elements to subtract. Each element must be +/// less than the modulus +/// @param[in] n Number of elements in each vector +/// @param[in] modulus Modulus with which to perform modular reduction. Must be +/// in the range \f$[2, 2^{63} - 1]\f$ +/// @details Computes \f$ operand1[i] = (operand1[i] - operand2) \mod modulus +/// \f$ for \f$ i=0, ..., n-1\f$. +void EltwiseSubMod(uint64_t* result, const uint64_t* operand1, + uint64_t operand2, uint64_t n, uint64_t modulus); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp b/hexl_v3/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp new file mode 100644 index 00000000..28a2dddf --- /dev/null +++ b/hexl_v3/include/hexl/experimental/fft-like/fft-like-avx512-util.hpp @@ -0,0 +1,402 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "util/avx512-util.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +// ************************************ T1 ************************************ + +// ComplexLoadFwdInterleavedT1: +// Assumes ComplexLoadFwdInterleavedT2 was used before. +// Given input: 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0 +// Returns +// *out1 = (14, 12, 10, 8, 6, 4, 2, 0); +// *out2 = (15, 13, 11, 9, 7, 5, 3, 1); +// +// Given output: 15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0 +inline void ComplexLoadFwdInterleavedT1(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512i vperm_idx = _mm512_set_epi64(6, 7, 4, 5, 2, 3, 0, 1); + + const __m512d* arg_512 = reinterpret_cast(arg); + + // 13 12 9 8 5 4 1 0 + __m512d v_7to0 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15 14 11 10 7 6 3 2 + __m512d v_15to8 = _mm512_loadu_pd(arg_512); + + // 12, 13, 8, 9, 4, 5, 0, 1 + __m512d perm_1 = _mm512_permutexvar_pd(vperm_idx, v_7to0); + // 14, 15, 10, 11, 6, 7, 2, 3 + __m512d perm_2 = _mm512_permutexvar_pd(vperm_idx, v_15to8); + + // 14, 12, 10, 8, 6, 4, 2, 0 + *out1 = _mm512_mask_blend_pd(0xaa, v_7to0, perm_2); + // 15, 13, 11, 9, 7, 5, 3, 1 + *out2 = _mm512_mask_blend_pd(0x55, v_15to8, perm_1); +} + +// ComplexWriteFwdInterleavedT1: +// Assumes ComplexLoadFwdInterleavedT1 was used before. 
+// Given inputs: +// 15i, 13i, 11i, 9i, 7i, 5i, 3i, 1i, 15r, 13r, 11r, 9r, 7r, 5r, 3r, 1r, +// 14i, 12i, 10i, 8i, 6i, 4i, 2i, 0i, 14r, 12r, 10r, 8r, 6r, 4r, 2r, 0r +// As seen with internal indexes: +// @param arg_yr = (15r, 14r, 13r, 12r, 11r, 10r, 9r, 8r); +// @param arg_xr = ( 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r); +// @param arg_yi = (15i, 14i, 13i, 12i, 11i, 10i, 9i, 8i); +// @param arg_xi = ( 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i); +// Writes out = +// {15i, 15r, 7i, 7r, 14i, 14r, 6i, 6r, 13i, 13r, 5i, 5r, 12i, 12r, 4i, 4r, +// 11i, 11r, 3i, 3r, 10i, 10r, 2i, 2r, 9i, 9r, 1i, 1r, 8i, 8r, 0i, 0r} +// +// Given output: +// 15i, 15r, 14i, 14r, 13i, 13r, 12i, 12r, 11i, 11r, 10i, 10r, 9i, 9r, 8i, 8r, +// 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +inline void ComplexWriteFwdInterleavedT1(__m512d arg_xr, __m512d arg_yr, + __m512d arg_xi, __m512d arg_yi, + __m512d* out) { + const __m512i vperm_4hi_4lo_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i v_X_out_idx = _mm512_set_epi64(3, 1, 7, 5, 2, 0, 6, 4); + const __m512i v_Y_out_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + // Real part + // in: 14r, 12r, 10r, 8r, 6r, 4r, 2r, 0r + // -> 6r, 4r, 2r, 0r, 14r, 12r, 10r, 8r + arg_xr = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, arg_xr); + + // arg_yr: 15r, 13r, 11r, 9r, 7r, 5r, 3r, 1r + // -> 6r, 4r, 2r, 0r, 7r, 5r, 3r, 1r + __m512d perm_1 = _mm512_mask_blend_pd(0x0f, arg_xr, arg_yr); + // -> 15r, 13r, 11r, 9r, 14r, 12r, 10r, 8r + __m512d perm_2 = _mm512_mask_blend_pd(0xf0, arg_xr, arg_yr); + + // 7r, 3r, 6r, 2r, 5r, 1r, 4r, 0r + arg_xr = _mm512_permutexvar_pd(v_X_out_idx, perm_1); + // 15r, 11r, 14r, 10r, 13r, 9r, 12r, 8r + arg_yr = _mm512_permutexvar_pd(v_Y_out_idx, perm_2); + + // Imaginary part + // in: 14i, 12i, 10i, 8i, 6i, 4i, 2i, 0i + // -> 6i, 4i, 2i, 0i, 14i, 12i, 10i, 8i + arg_xi = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, arg_xi); + + // arg_yr: 15i, 13i, 11i, 9i, 7i, 5i, 3i, 1i + // -> 6i, 4i, 2i, 0i, 7i, 5i, 3i, 1i + perm_1 = _mm512_mask_blend_pd(0x0f, arg_xi, arg_yi); + // -> 15i, 13i, 11i, 9i, 14i, 12i, 10i, 8i + perm_2 = _mm512_mask_blend_pd(0xf0, arg_xi, arg_yi); + + // 7i, 3i, 6i, 2i, 5i, 1i, 4i, 0i + arg_xi = _mm512_permutexvar_pd(v_X_out_idx, perm_1); + // 15i, 11i, 14i, 10i, 13i, 9i, 12i, 8i + arg_yi = _mm512_permutexvar_pd(v_Y_out_idx, perm_2); + + // Merge + // 00000000 > 3i 3r 2i 2r 1i 1r 0i 0r + __m512d out1 = _mm512_shuffle_pd(arg_xr, arg_xi, 0x00); + // 11111111 > 7i 7r 6i 6r 5i 5r 4i 4r + __m512d out2 = _mm512_shuffle_pd(arg_xr, arg_xi, 0xff); + + // 00000000 > 11i 11r 10i 10r 9i 9r 8i 8r + __m512d out3 = _mm512_shuffle_pd(arg_yr, arg_yi, 0x00); + // 11111111 > 15i 15r 14i 14r 13i 13r 12i 12r + __m512d out4 = _mm512_shuffle_pd(arg_yr, arg_yi, 0xff); + + _mm512_storeu_pd(out++, out1); + _mm512_storeu_pd(out++, out2); + _mm512_storeu_pd(out++, out3); + _mm512_storeu_pd(out++, out4); +} + +// ComplexLoadInvInterleavedT1: +// Given input: 15i 15r 14i 14r 13i 13r 12i 12r 11i 11r 10i 10r 9i 9r 8i 8r +// 7i 7r 6i 6r 5i 5r 4i 4r 3i 3r 2i 2r 1i 1r 0i 0r +// Returns +// *out1_r = (14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r); +// *out1_i = (14i, 10i, 6i, 2i, 12i, 8i, 4i, 0i); +// *out2_r = (15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r); +// *out2_i = (15i, 11i, 7i, 3i, 13i, 9i, 5i, 1i); +// +// Given output: +// 15i, 11i, 7i, 3i, 13i, 9i, 5i, 1i, 15r, 11r, 7r, 3r, 13r, 9r, 5r, 1r, +// 14i, 10i, 6i, 2i, 12i, 8i, 4i, 0i, 14r, 10r, 6r, 2r, 12r, 8r, 4r, 0r +inline void ComplexLoadInvInterleavedT1(const double* arg, __m512d* out1_r, + __m512d* out1_i, __m512d* out2_r, + __m512d* 
out2_i) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 3i 3r 2i 2r 1i 1r 0i 0r + __m512d v_3to0 = _mm512_loadu_pd(arg_512++); + // 7i 7r 6i 6r 5i 5r 4i 4r + __m512d v_7to4 = _mm512_loadu_pd(arg_512++); + // 11i 11r 10i 10r 9i 9r 8i 8r + __m512d v_11to8 = _mm512_loadu_pd(arg_512++); + // 15i 15r 14i 14r 13i 13r 12i 12r + __m512d v_15to12 = _mm512_loadu_pd(arg_512++); + + // 00000000 > 7r 3r 6r 2r 5r 1r 4r 0r + __m512d v_7to0_r = _mm512_shuffle_pd(v_3to0, v_7to4, 0x00); + // 11111111 > 7i 3i 6i 2i 5i 1i 4i 0i + __m512d v_7to0_i = _mm512_shuffle_pd(v_3to0, v_7to4, 0xff); + // 00000000 > 15r 11r 14r 10r 13r 9r 12r 8r + __m512d v_15to8_r = _mm512_shuffle_pd(v_11to8, v_15to12, 0x00); + // 11111111 > 15i 11i 14i 10i 13i 9i 12i 8i + __m512d v_15to8_i = _mm512_shuffle_pd(v_11to8, v_15to12, 0xff); + + // real + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + // 6 2 7 3 4 0 5 1 + __m512d v1r = _mm512_permutexvar_pd(v1_perm_idx, v_7to0_r); + // 14 10 15 11 12 8 13 9 + __m512d v2r = _mm512_permutexvar_pd(v1_perm_idx, v_15to8_r); + // 11001100 > 14 10 6 2 12 8 4 0 + *out1_r = _mm512_mask_blend_pd(0xcc, v_7to0_r, v2r); + // 11001100 > 15 11 7 3 13 9 5 1 + *out2_r = _mm512_mask_blend_pd(0xcc, v1r, v_15to8_r); + + // imag + // 6 2 7 3 4 0 5 1 + __m512d v1i = _mm512_permutexvar_pd(v1_perm_idx, v_7to0_i); + // 14 10 15 11 12 8 13 9 + __m512d v2i = _mm512_permutexvar_pd(v1_perm_idx, v_15to8_i); + // 11001100 > 14 10 6 2 12 8 4 0 + *out1_i = _mm512_mask_blend_pd(0xcc, v_7to0_i, v2i); + // 11001100 > 15 11 7 3 13 9 5 1 + *out2_i = _mm512_mask_blend_pd(0xcc, v1i, v_15to8_i); +} + +// ************************************ T2 ************************************ + +// ComplexLoadFwdInterleavedT2: +// Assumes ComplexLoadFwdInterleavedT4 was used before. +// Given input: 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0 +// Returns +// *out1 = (13, 12, 9, 8, 5, 4, 1, 0) +// *out2 = (15, 14, 11, 10, 7, 6, 3, 2) +// +// Given output: 15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0 +inline void ComplexLoadFwdInterleavedT2(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // Values were swapped in T4 + // 11, 10, 9, 8, 3, 2, 1, 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15, 14, 13, 12, 7, 6, 5, 4 + __m512d v2 = _mm512_loadu_pd(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(5, 4, 7, 6, 1, 0, 3, 2); + + __m512d v1_perm = _mm512_permutexvar_pd(v1_perm_idx, v1); + __m512d v2_perm = _mm512_permutexvar_pd(v1_perm_idx, v2); + + *out1 = _mm512_mask_blend_pd(0xcc, v1, v2_perm); + *out2 = _mm512_mask_blend_pd(0xcc, v1_perm, v2); +} + +// ComplexLoadInvInterleavedT2: +// Assumes ComplexLoadInvInterleavedT1 was used before. 
+// Given input: 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0 +// Returns +// *out1 = (13, 9, 5, 1, 12, 8, 4, 0) +// *out2 = (15, 11, 7, 3, 14, 10, 6, 2) +// +// Given output: 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0 +inline void ComplexLoadInvInterleavedT2(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 14 10 6 2 12 8 4 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15 11 7 3 13 9 5 1 + __m512d v2 = _mm512_loadu_pd(arg_512); + + const __m512i v1_perm_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + + // 12 8 4 0 14 10 6 2 + __m512d v1_perm = _mm512_permutexvar_pd(v1_perm_idx, v1); + // 13 9 5 1 15 11 7 3 + __m512d v2_perm = _mm512_permutexvar_pd(v1_perm_idx, v2); + + // 11110000 > 13 9 5 1 12 8 4 0 + *out1 = _mm512_mask_blend_pd(0xf0, v1, v2_perm); + // 11110000 > 15 11 7 3 14 10 6 2 + *out2 = _mm512_mask_blend_pd(0xf0, v1_perm, v2); +} + +// ************************************ T4 ************************************ + +// Complex LoadFwdInterleavedT4: +// Given input: 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +// Returns +// *out1 = (11, 10, 9, 8, 3, 2, 1, 0) +// *out2 = (15, 14, 13, 12, 7, 6, 5, 4) +// +// Given output: 15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0 +inline void ComplexLoadFwdInterleavedT4(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + const __m512i vperm2_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + __m512d v_7to0 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + __m512d v_15to8 = _mm512_loadu_pd(arg_512); + __m512d perm_hi = _mm512_permutexvar_pd(vperm2_idx, v_15to8); + *out1 = _mm512_mask_blend_pd(0x0f, perm_hi, v_7to0); + *out2 = _mm512_mask_blend_pd(0xf0, perm_hi, v_7to0); + *out2 = _mm512_permutexvar_pd(vperm2_idx, *out2); +} + +// ComplexLoadInvInterleavedT4: +// Assumes ComplexLoadInvInterleavedT2 was used before. +// Given input: 15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0 +// Returns +// *out1 = (11, 9, 3, 1, 10, 8, 2, 0) +// *out2 = (15, 13, 7, 5, 14, 12, 6, 4) +// +// Given output: 15, 13, 7, 5, 14, 12, 6, 4, 11, 9, 3, 1, 10, 8, 2, 0 + +inline void ComplexLoadInvInterleavedT4(const double* arg, __m512d* out1, + __m512d* out2) { + const __m512d* arg_512 = reinterpret_cast(arg); + + // 13, 9, 5, 1, 12, 8, 4, 0 + __m512d v1 = _mm512_loadu_pd(arg_512); + arg_512 += 2; + // 15, 11, 7, 3, 14, 10, 6, 2 + __m512d v2 = _mm512_loadu_pd(arg_512); + + // 00000000 > 11 9 3 1 10 8 2 0 + *out1 = _mm512_shuffle_pd(v1, v2, 0x00); + // 11111111 > 15 13 7 5 14 12 6 4 + *out2 = _mm512_shuffle_pd(v1, v2, 0xff); +} + +// ComplexWriteInvInterleavedT4: +// Assuming ComplexLoadInvInterleavedT4 was used before. 
+// Given inputs: 15, 13, 7, 5, 14, 12, 6, 4, 11, 9, 3, 1, 10, 8, 2, 0 +// Seen Internally: +// @param arg1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// @param arg2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 11, 14, 10, 7, 3, 6, 2, +// 13, 9, 12, 8, 5, 1, 4, 0} +// +// Given output: 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +inline void ComplexWriteInvInterleavedT4(__m512d arg1, __m512d arg2, + __m512d* out) { + const __m512i vperm_4hi_4lo_idx = _mm512_set_epi64(3, 2, 1, 0, 7, 6, 5, 4); + const __m512i vperm1 = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + const __m512i vperm2 = _mm512_set_epi64(5, 1, 4, 0, 7, 3, 6, 2); + + // in: 11 9 3 1 10 8 2 0 + // -> 11 10 9 8 3 2 1 0 + arg1 = _mm512_permutexvar_pd(vperm1, arg1); + // in: 15 13 7 5 14 12 6 4 + // -> 7 6 5 4 15 14 13 12 + arg2 = _mm512_permutexvar_pd(vperm2, arg2); + + // 7 6 5 4 3 2 1 0 + __m512d out1 = _mm512_mask_blend_pd(0xf0, arg1, arg2); + // 11 10 9 8 15 14 13 12 + __m512d out2 = _mm512_mask_blend_pd(0x0f, arg1, arg2); + // 15 14 13 12 11 10 9 8 + out2 = _mm512_permutexvar_pd(vperm_4hi_4lo_idx, out2); + + _mm512_storeu_pd(out, out1); + out += 2; + _mm512_storeu_pd(out, out2); +} + +// ************************************ T8 ************************************ + +// ComplexLoadFwdInterleavedT8: +// Given inputs: 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +// Seen Internally: +// v_X1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// v_X2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 13, 11, 9, 7, 5, 3, 1, +// 14, 12, 10, 8, 6, 4, 2, 0} +// +// Given output: 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i, 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r +inline void ComplexLoadFwdInterleavedT8(const __m512d* arg_x, + const __m512d* arg_y, __m512d* out1_r, + __m512d* out1_i, __m512d* out2_r, + __m512d* out2_i) { + const __m512i v_perm_idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + + // 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r + __m512d v_X1 = _mm512_loadu_pd(arg_x++); + // 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r + __m512d v_X2 = _mm512_loadu_pd(arg_x); + // 7r, 3r, 6r, 2r, 5r, 1r, 4r, 0r + *out1_r = _mm512_shuffle_pd(v_X1, v_X2, 0x00); + // 7i, 3i, 6i, 2i, 5i, 1i, 4i, 0i + *out1_i = _mm512_shuffle_pd(v_X1, v_X2, 0xff); + // 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r + *out1_r = _mm512_permutexvar_pd(v_perm_idx, *out1_r); + // 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i + *out1_i = _mm512_permutexvar_pd(v_perm_idx, *out1_i); + + __m512d v_Y1 = _mm512_loadu_pd(arg_y++); + __m512d v_Y2 = _mm512_loadu_pd(arg_y); + *out2_r = _mm512_shuffle_pd(v_Y1, v_Y2, 0x00); + *out2_i = _mm512_shuffle_pd(v_Y1, v_Y2, 0xff); + *out2_r = _mm512_permutexvar_pd(v_perm_idx, *out2_r); + *out2_i = _mm512_permutexvar_pd(v_perm_idx, *out2_i); +} + +// ComplexWriteInvInterleavedT8: +// Assuming ComplexLoadInvInterleavedT4 was used before. 
+// Given inputs: 7i, 6i, 5i, 4i, 3i, 2i, 1i, 0i, 7r, 6r, 5r, 4r, 3r, 2r, 1r, 0r +// Seen Internally: +// @param arg1 = ( 7, 6, 5, 4, 3, 2, 1, 0); +// @param arg2 = (15, 14, 13, 12, 11, 10, 9, 8); +// Writes out = {15, 7, 14, 6, 13, 5, 12, 4, +// 11, 3, 10, 2, 9, 1, 8, 0} +// +// Given output: 7i, 7r, 6i, 6r, 5i, 5r, 4i, 4r, 3i, 3r, 2i, 2r, 1i, 1r, 0i, 0r +inline void ComplexWriteInvInterleavedT8(__m512d* v_X_real, __m512d* v_X_imag, + __m512d* v_Y_real, __m512d* v_Y_imag, + __m512d* v_X_pt, __m512d* v_Y_pt) { + const __m512i vperm = _mm512_set_epi64(7, 3, 6, 2, 5, 1, 4, 0); + // in: 7r 6r 5r 4r 3r 2r 1r 0r + // -> 7r 3r 6r 2r 5r 1r 4r 0r + *v_X_real = _mm512_permutexvar_pd(vperm, *v_X_real); + // in: 7i 6i 5i 4i 3i 2i 1i 0i + // -> 7i 3i 6i 2i 5i 1i 4i 0i + *v_X_imag = _mm512_permutexvar_pd(vperm, *v_X_imag); + // in: 15r 14r 13r 12r 11r 10r 9r 8r + // -> 15r 11r 14r 10r 13r 9r 12r 8r + *v_Y_real = _mm512_permutexvar_pd(vperm, *v_Y_real); + // in: 15i 14i 13i 12i 11i 10i 9i 8i + // -> 15i 11i 14i 10i 13i 9i 12i 8i + *v_Y_imag = _mm512_permutexvar_pd(vperm, *v_Y_imag); + + // 00000000 > 3i 3r 2i 2r 1i 1r 0i 0r + __m512d v_X1 = _mm512_shuffle_pd(*v_X_real, *v_X_imag, 0x00); + // 11111111 > 7i 7r 6i 6r 5i 5r 4i 4r + __m512d v_X2 = _mm512_shuffle_pd(*v_X_real, *v_X_imag, 0xff); + // 00000000 > 11i 11r 10i 10r 9i 9r 8i 8r + __m512d v_Y1 = _mm512_shuffle_pd(*v_Y_real, *v_Y_imag, 0x00); + // 11111111 > 15i 15r 14i 14r 13i 13r 12i 12r + __m512d v_Y2 = _mm512_shuffle_pd(*v_Y_real, *v_Y_imag, 0xff); + + _mm512_storeu_pd(v_X_pt++, v_X1); + _mm512_storeu_pd(v_X_pt, v_X2); + _mm512_storeu_pd(v_Y_pt++, v_Y1); + _mm512_storeu_pd(v_Y_pt, v_Y2); +} +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/experimental/fft-like/fft-like-native.hpp b/hexl_v3/include/hexl/experimental/fft-like/fft-like-native.hpp new file mode 100644 index 00000000..7e02492d --- /dev/null +++ b/hexl_v3/include/hexl/experimental/fft-like/fft-like-native.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Radix-2 native C++ FFT like implementation of the forward FFT like +/// @param[out] result Output data. Overwritten with FFT like output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] root_of_unity_powers Powers of 2n'th root of unity. In +/// bit-reversed order +/// @param[in] scale Scale applied to output data +void Forward_FFTLike_ToBitReverseRadix2( + std::complex* result, const std::complex* operand, + const std::complex* root_of_unity_powers, const uint64_t n, + const double* scale = nullptr); + +/// @brief Radix-2 native C++ FFT like implementation of the inverse FFT like +/// @param[out] result Output data. Overwritten with FFT like output +/// @param[in] operand Input data. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] inv_root_of_unity_powers Powers of inverse 2n'th root of unity. +/// In bit-reversed order. 
+/// @param[in] scale Scale applied to output data
+void Inverse_FFTLike_FromBitReverseRadix2(
+    std::complex<double>* result, const std::complex<double>* operand,
+    const std::complex<double>* inv_root_of_unity_powers, const uint64_t n,
+    const double* scale = nullptr);
+
+}  // namespace hexl
+}  // namespace intel
diff --git a/hexl_v3/include/hexl/experimental/fft-like/fft-like.hpp b/hexl_v3/include/hexl/experimental/fft-like/fft-like.hpp
new file mode 100644
index 00000000..334de246
--- /dev/null
+++ b/hexl_v3/include/hexl/experimental/fft-like/fft-like.hpp
@@ -0,0 +1,147 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <complex>
+
+#include "hexl/experimental/fft-like/fwd-fft-like-avx512.hpp"
+#include "hexl/experimental/fft-like/inv-fft-like-avx512.hpp"
+#include "hexl/util/aligned-allocator.hpp"
+#include "hexl/util/allocator.hpp"
+
+namespace intel {
+namespace hexl {
+
+/// @brief Performs linear forward and inverse FFT like transform
+/// for CKKS encoding and decoding.
+class FFTLike {
+ public:
+  /// @brief Helper class for custom memory allocation
+  template <class Adaptee, class... Args>
+  struct AllocatorAdapter
+      : public AllocatorInterface<AllocatorAdapter<Adaptee, Args...>> {
+    explicit AllocatorAdapter(Adaptee&& _a, Args&&... args);
+    AllocatorAdapter(const Adaptee& _a, Args&... args);
+
+    // interface implementation
+    void* allocate_impl(size_t bytes_count);
+    void deallocate_impl(void* p, size_t n);
+
+   private:
+    Adaptee alloc;
+  };
+
+  /// @brief Initializes an empty CKKS_FTT object
+  FFTLike() = default;
+
+  /// @brief Destructs the CKKS_FTT object
+  ~FFTLike() = default;
+
+  /// @brief Initializes an FFTLike object with degree \p degree and scalar
+  /// \p in_scalar.
+  /// @param[in] degree also known as N. Size of the FFT like transform. Must be
+  /// a power of 2
+  /// @param[in] in_scalar Scalar value to calculate scale and inv scale
+  /// @param[in] alloc_ptr Custom memory allocator used for intermediate
+  /// calculations
+  /// @details Performs pre-computation necessary for forward and inverse
+  /// transforms
+  FFTLike(uint64_t degree, double* in_scalar,
+          std::shared_ptr<AllocatorBase> alloc_ptr = {});
+
+  template <class Allocator, class... AllocatorArgs>
+  FFTLike(uint64_t degree, double* in_scalar, Allocator&& a,
+          AllocatorArgs&&... args)
+      : FFTLike(
+            degree, in_scalar,
+            std::static_pointer_cast<AllocatorBase>(
+                std::make_shared<AllocatorAdapter<Allocator, AllocatorArgs...>>(
+                    std::move(a), std::forward<AllocatorArgs>(args)...))) {}
+
+  /// @brief Compute forward FFT like. Results are bit-reversed.
+  /// @param[out] result Stores the result
+  /// @param[in] operand Data on which to compute the FFT like
+  /// @param[in] in_scale Scale applied to output values
+  void ComputeForwardFFTLike(std::complex<double>* result,
+                             const std::complex<double>* operand,
+                             const double* in_scale = nullptr);
+
+  /// @brief Compute inverse FFT like. Results are bit-reversed.
+  /// @param[out] result Stores the result
+  /// @param[in] operand Data on which to compute the FFT like
+  /// @param[in] in_scale Scale applied to output values
+  void ComputeInverseFFTLike(std::complex<double>* result,
+                             const std::complex<double>* operand,
+                             const double* in_scale = nullptr);
+
+  /// @brief Construct floating-point values from CRT-composed polynomial with
+  /// integer coefficients.
+ /// @param[out] res Stores the result + /// @param[in] plain Plaintext + /// @param[in] threshold Upper half threshold with respect to the total + /// coefficient modulus + /// @param[in] decryption_modulus Product of all primes in the coefficient + /// modulus + /// @param[in] inv_scale Scale applied to output values + /// @param[in] mod_size Size of coefficient modulus parameter + /// @param[in] coeff_count Degree of the polynomial modulus parameter + void BuildFloatingPoints(std::complex* res, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, size_t mod_size, + size_t coeff_count); + + /// @brief Returns the root of unity power at bit-reversed index i. + /// @param[in] i Index + std::complex GetComplexRootOfUnity(size_t i) { + return GetComplexRootsOfUnity()[i]; + } + + /// @brief Returns the root of unity in bit-reversed order + const AlignedVector64>& GetComplexRootsOfUnity() const { + return m_complex_roots_of_unity; + } + + /// @brief Returns the root of unity power at bit-reversed index i. + /// @param[in] i Index + std::complex GetInvComplexRootOfUnity(size_t i) { + return GetInvComplexRootsOfUnity()[i]; + } + + /// @brief Returns the inverse root of unity in bit-reversed order + const AlignedVector64>& GetInvComplexRootsOfUnity() + const { + return m_inv_complex_roots_of_unity; + } + + /// @brief Returns the degree N + uint64_t GetDegree() const { return m_degree; } + + private: + // Computes 1~(n-1)-th powers and inv powers of the primitive 2n-th root + void ComputeComplexRootsOfUnity(); + + uint64_t m_degree; // N: size of FFT like transform, should be power of 2 + + double* scalar; // Pointer to scalar used for scale/inv_scale calculation + + double scale; // Scale value use for encoding (inv fft-like) + + double inv_scale; // Scale value use in decoding (fwd fft-like) + + std::shared_ptr m_alloc; + + AlignedAllocator m_aligned_alloc; + + uint64_t m_degree_bits; // log_2(m_degree) + + // Contains 0~(n-1)-th powers of the 2n-th primitive root. + AlignedVector64> m_complex_roots_of_unity; + + // Contains 0~(n-1)-th inv powers of the 2n-th primitive inv root. + AlignedVector64> m_inv_complex_roots_of_unity; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp b/hexl_v3/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp new file mode 100644 index 00000000..aba4ca4d --- /dev/null +++ b/hexl_v3/include/hexl/experimental/fft-like/fwd-fft-like-avx512.hpp @@ -0,0 +1,59 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/fft-like/fft-like.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the forward FFT like +/// @param[out] result_cmplx_intrlvd Output data. Overwritten with FFT like +/// output. Result is a vector of double with interleaved real and imaginary +/// numbers. +/// @param[in] operand_cmplx_intrlvd Input data. A vector of double with +/// interleaved real and imaginary numbers. +/// @param[in] roots_of_unity_cmplx_intrlvd Powers of 2n'th root of unity. In +/// bit-reversed order. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. 
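The `FFTLike` class above is the experimental CKKS-style encode/decode transform. A minimal usage sketch, not part of the patch itself: the degree, scalar, and input values are illustrative assumptions, and the hexl_v3 headers and library are assumed to be on the include and link paths.

```cpp
#include <complex>
#include <cstdint>
#include <vector>

#include "hexl/experimental/fft-like/fft-like.hpp"

int main() {
  const uint64_t degree = 64;  // assumed transform size; must be a power of 2
  double scalar = 1.0;         // assumed scalar used to derive scale/inv_scale

  intel::hexl::FFTLike fft_like(degree, &scalar);

  std::vector<std::complex<double>> input(degree, {1.0, 0.0});
  std::vector<std::complex<double>> transformed(degree);
  std::vector<std::complex<double>> round_trip(degree);

  // Forward transform (output is in bit-reversed order), then invert it again
  fft_like.ComputeForwardFFTLike(transformed.data(), input.data());
  fft_like.ComputeInverseFFTLike(round_trip.data(), transformed.data());
  return 0;
}
```

Both calls also accept an optional per-call scale through the `in_scale` pointer argument.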
+/// @param[in] scale Scale applied to output values +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// FFT like, where all the butterflies in a given stage are processed before +/// any butterflies in the next stage. The base case is small enough to fit in +/// the smallest cache. Larger FFTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. This reduces the number of cache misses, improving +/// performance on larger transform sizes. +void Forward_FFTLike_ToBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* roots_of_unity_cmplx_intrlvd, const uint64_t n, + const double* scale = nullptr, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +/// @brief Construct floating-point values from CRT-composed polynomial with +/// integer coefficients in AVX512. +/// @param[out] res_cmplx_intrlvd Stores the result +/// @param[in] plain Plaintext +/// @param[in] threshold Upper half threshold with respect to the total +/// coefficient modulus +/// @param[in] decryption_modulus Product of all primes in the coefficient +/// modulus +/// @param[in] inv_scale Scale applied to output values +/// @param[in] mod_size Size of coefficient modulus parameter +/// @param[in] coeff_count Degree of the polynomial modulus parameter +void BuildFloatingPointsAVX512(double* res_cmplx_intrlvd, const uint64_t* plain, + const uint64_t* threshold, + const uint64_t* decryption_modulus, + const double inv_scale, const size_t mod_size, + const size_t coeff_count); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp b/hexl_v3/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp new file mode 100644 index 00000000..487e2828 --- /dev/null +++ b/hexl_v3/include/hexl/experimental/fft-like/inv-fft-like-avx512.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/experimental/fft-like/fft-like.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_HAS_AVX512DQ + +/// @brief AVX512 implementation of the inverse FFT like +/// @param[out] result_cmplx_intrlvd Output data. Overwritten with FFT like +/// output. Result is a vector of double with interleaved real and imaginary +/// numbers. +/// @param[in] operand_cmplx_intrlvd Input data. A vector of double with +/// interleaved real and imaginary numbers. +/// @param[in] inv_roots_of_unity_cmplx_intrlvd Powers of 2n'th root of unity. +/// In bit-reversed order. +/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a +/// power of two. +/// @param[in] scale Scale applied to output values +/// @param[in] recursion_depth Depth of recursive call +/// @param[in] recursion_half Helper for indexing roots of unity +/// @details The implementation is recursive. The base case is a breadth-first +/// FFT like, where all the butterflies in a given stage are processed before +/// any butterflies in the next stage. The base case is small enough to fit in +/// the smallest cache. Larger FFTs are processed recursively in a depth-first +/// manner, such that an entire subtransform is completed before moving to the +/// next subtransform. 
This reduces the number of cache misses, improving +/// performance on larger transform sizes. +void Inverse_FFTLike_FromBitReverseAVX512( + double* result_cmplx_intrlvd, const double* operand_cmplx_intrlvd, + const double* inv_root_of_unity_cmplxintrlvd, const uint64_t n, + const double* scale = nullptr, uint64_t recursion_depth = 0, + uint64_t recursion_half = 0); + +#endif // HEXL_HAS_AVX512DQ + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/experimental/misc/lr-mat-vec-mult.hpp b/hexl_v3/include/hexl/experimental/misc/lr-mat-vec-mult.hpp new file mode 100644 index 00000000..df03df92 --- /dev/null +++ b/hexl_v3/include/hexl/experimental/misc/lr-mat-vec-mult.hpp @@ -0,0 +1,31 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes transposed linear regression +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (3 * n * num_moduli) elements +/// @param[in] operand1 Vector of ciphertext representing a matrix that encodes +/// a transposed logistic regression model. Has (num_weights * 2 * n * +/// num_moduli) elements. +/// @param[in] operand2 Vector of ciphertext representing a matrix that encodes +/// at most n/2 input samples with feature size num_weights. Has (num_weights * +/// 2 * n * num_moduli) elements. +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +/// @param[in] num_weights Feature size of the linear/logistic regression model +void LinRegMatrixVectorMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli, + uint64_t num_weights); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/experimental/seal/dyadic-multiply-internal.hpp b/hexl_v3/include/hexl/experimental/seal/dyadic-multiply-internal.hpp new file mode 100644 index 00000000..310a46b0 --- /dev/null +++ b/hexl_v3/include/hexl/experimental/seal/dyadic-multiply-internal.hpp @@ -0,0 +1,29 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { +namespace internal { + +/// @brief Computes dyadic multiplication +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (2 * n * num_moduli) elements +/// @param[in] operand1 First ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] operand2 Second ciphertext argument. Has (2 * n * num_moduli) +/// elements. 
+/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli); + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/experimental/seal/dyadic-multiply.hpp b/hexl_v3/include/hexl/experimental/seal/dyadic-multiply.hpp new file mode 100644 index 00000000..f7eacfdf --- /dev/null +++ b/hexl_v3/include/hexl/experimental/seal/dyadic-multiply.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes dyadic multiplication +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (2 * n * num_moduli) elements +/// @param[in] operand1 First ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] operand2 Second ciphertext argument. Has (2 * n * num_moduli) +/// elements. +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] moduli Pointer to contiguous array of num_moduli word-sized +/// coefficient moduli +/// @param[in] num_moduli Number of word-sized coefficient moduli +void DyadicMultiply(uint64_t* result, const uint64_t* operand1, + const uint64_t* operand2, uint64_t n, + const uint64_t* moduli, uint64_t num_moduli); + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/experimental/seal/key-switch-internal.hpp b/hexl_v3/include/hexl/experimental/seal/key-switch-internal.hpp new file mode 100644 index 00000000..8fc9d53e --- /dev/null +++ b/hexl_v3/include/hexl/experimental/seal/key-switch-internal.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { +namespace internal { + +/// @brief Computes key switching in-place +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (n * decomp_modulus_size * key_component_count) elements +/// @param[in] t_target_iter_ptr Pointer to the last component of the input +/// ciphertext +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] decomp_modulus_size Number of moduli in the ciphertext at its +/// current level, excluding one auxiliary prime. +/// @param[in] key_modulus_size Number of moduli in the ciphertext at its top +/// level, including one auxiliary prime. +/// @param[in] rns_modulus_size Number of moduli in the ciphertext at its +/// current level, including one auxiliary prime. rns_modulus_size == +/// decomp_modulus_size + 1 +/// @param[in] key_component_count Number of components in the resulting +/// ciphertext, e.g. key_component_count == 2. +/// @param[in] moduli Array of word-sized coefficient moduli. There must be +/// key_modulus_size moduli in the array +/// @param[in] k_switch_keys Array of evaluation key data. 
Has +/// decomp_modulus_size entries, each with +/// coeff_count * ((key_modulus_size - 1)+ (key_component_count - 1) * +/// (key_modulus_size) + 1) entries +/// @param[in] modswitch_factors Array of modulus switch factors +/// @param[in] root_of_unity_powers_ptr Array of root of unity powers +void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n, + uint64_t decomp_modulus_size, uint64_t key_modulus_size, + uint64_t rns_modulus_size, uint64_t key_component_count, + const uint64_t* moduli, const uint64_t** k_switch_keys, + const uint64_t* modswitch_factors, + const uint64_t* root_of_unity_powers_ptr = nullptr); + +} // namespace internal +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/experimental/seal/key-switch.hpp b/hexl_v3/include/hexl/experimental/seal/key-switch.hpp new file mode 100644 index 00000000..9eda159c --- /dev/null +++ b/hexl_v3/include/hexl/experimental/seal/key-switch.hpp @@ -0,0 +1,42 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Computes key switching in-place +/// @param[in,out] result Ciphertext data. Will be over-written with result. Has +/// (n * decomp_modulus_size * key_component_count) elements +/// @param[in] t_target_iter_ptr Pointer to the last component of the input +/// ciphertext +/// @param[in] n Number of coefficients in each polynomial +/// @param[in] decomp_modulus_size Number of moduli in the ciphertext at its +/// current level, excluding one auxiliary prime. +/// @param[in] key_modulus_size Number of moduli in the ciphertext at its top +/// level, including one auxiliary prime. +/// @param[in] rns_modulus_size Number of moduli in the ciphertext at its +/// current level, including one auxiliary prime. rns_modulus_size == +/// decomp_modulus_size + 1 +/// @param[in] key_component_count Number of components in the resulting +/// ciphertext, e.g. key_component_count == 2. +/// @param[in] moduli Array of word-sized coefficient moduli. There must be +/// key_modulus_size moduli in the array +/// @param[in] k_switch_keys Array of evaluation key data. 
Has
+/// decomp_modulus_size entries, each with
+/// coeff_count * ((key_modulus_size - 1)+ (key_component_count - 1) *
+/// (key_modulus_size) + 1) entries
+/// @param[in] modswitch_factors Array of modulus switch factors
+/// @param[in] root_of_unity_powers_ptr Array of root of unity powers
+void KeySwitch(uint64_t* result, const uint64_t* t_target_iter_ptr, uint64_t n,
+               uint64_t decomp_modulus_size, uint64_t key_modulus_size,
+               uint64_t rns_modulus_size, uint64_t key_component_count,
+               const uint64_t* moduli, const uint64_t** k_switch_keys,
+               const uint64_t* modswitch_factors,
+               const uint64_t* root_of_unity_powers_ptr = nullptr);
+
+}  // namespace hexl
+}  // namespace intel
diff --git a/hexl_v3/include/hexl/experimental/seal/locks.hpp b/hexl_v3/include/hexl/experimental/seal/locks.hpp
new file mode 100644
index 00000000..4595f4e5
--- /dev/null
+++ b/hexl_v3/include/hexl/experimental/seal/locks.hpp
@@ -0,0 +1,35 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include <memory>
+#include <shared_mutex>
+
+namespace intel {
+namespace hexl {
+
+using Lock = std::shared_mutex;
+using WriteLock = std::unique_lock<Lock>;
+using ReadLock = std::shared_lock<Lock>;
+
+class RWLock {
+ public:
+  RWLock() = default;
+  inline ReadLock AcquireRead() { return ReadLock(rw_mutex); }
+  inline WriteLock AcquireWrite() { return WriteLock(rw_mutex); }
+  inline ReadLock TryAcquireRead() noexcept {
+    return ReadLock(rw_mutex, std::try_to_lock);
+  }
+  inline WriteLock TryAcquireWrite() noexcept {
+    return WriteLock(rw_mutex, std::try_to_lock);
+  }
+
+ private:
+  RWLock(const RWLock& copy) = delete;
+  RWLock& operator=(const RWLock& assign) = delete;
+  Lock rw_mutex{};
+};
+
+}  // namespace hexl
+}  // namespace intel
diff --git a/hexl_v3/include/hexl/experimental/seal/ntt-cache.hpp b/hexl_v3/include/hexl/experimental/seal/ntt-cache.hpp
new file mode 100644
index 00000000..8f6c1046
--- /dev/null
+++ b/hexl_v3/include/hexl/experimental/seal/ntt-cache.hpp
@@ -0,0 +1,56 @@
+// Copyright (C) 2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#pragma once
+
+#include "hexl/experimental/seal/locks.hpp"
+#include "ntt/ntt-internal.hpp"
+
+namespace intel {
+namespace hexl {
+
+struct HashPair {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& p) const {
+    auto hash1 = std::hash<T1>{}(p.first);
+    auto hash2 = std::hash<T2>{}(p.second);
+    return hash_combine(hash1, hash2);
+  }
+
+  // Golden Ratio Hashing with seeds
+  static std::size_t hash_combine(std::size_t lhs, std::size_t rhs) {
+    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
+    return lhs;
+  }
+};
+
+NTT& GetNTT(size_t N, uint64_t modulus) {
+  static std::unordered_map<std::pair<uint64_t, uint64_t>, NTT, HashPair>
+      ntt_cache;
+  static RWLock ntt_cache_locker;
+
+  std::pair<uint64_t, uint64_t> key{N, modulus};
+
+  // Enable shared access to NTT already present
+  {
+    ReadLock reader_lock(ntt_cache_locker.AcquireRead());
+    auto ntt_it = ntt_cache.find(key);
+    if (ntt_it != ntt_cache.end()) {
+      return ntt_it->second;
+    }
+  }
+
+  // Deal with NTT not yet present
+  WriteLock write_lock(ntt_cache_locker.AcquireWrite());
+
+  // Check ntt_cache for value (may be added by another thread)
+  auto ntt_it = ntt_cache.find(key);
+  if (ntt_it == ntt_cache.end()) {
+    NTT ntt(N, modulus);
+    ntt_it = ntt_cache.emplace(std::move(key), std::move(ntt)).first;
+  }
+  return ntt_it->second;
+}
+
+}  // namespace hexl
+}  // namespace intel
diff --git a/hexl_v3/include/hexl/hexl.hpp b/hexl_v3/include/hexl/hexl.hpp
new file mode 100644
index 00000000..6f07ae57
--- /dev/null
+++ 
b/hexl_v3/include/hexl/hexl.hpp @@ -0,0 +1,26 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/eltwise/eltwise-add-mod.hpp" +#include "hexl/eltwise/eltwise-cmp-add.hpp" +#include "hexl/eltwise/eltwise-cmp-sub-mod.hpp" +#include "hexl/eltwise/eltwise-fma-mod.hpp" +#include "hexl/eltwise/eltwise-mult-mod.hpp" +#include "hexl/eltwise/eltwise-reduce-mod.hpp" +#include "hexl/eltwise/eltwise-sub-mod.hpp" +#include "hexl/experimental/fft-like/fft-like.hpp" +#include "hexl/experimental/misc/lr-mat-vec-mult.hpp" +#include "hexl/experimental/seal/dyadic-multiply-internal.hpp" +#include "hexl/experimental/seal/dyadic-multiply.hpp" +#include "hexl/experimental/seal/key-switch-internal.hpp" +#include "hexl/experimental/seal/key-switch.hpp" +#include "hexl/logging/logging.hpp" +#include "hexl/ntt/ntt.hpp" +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" +#include "hexl/util/defines.hpp" +#include "hexl/util/types.hpp" +#include "hexl/util/util.hpp" diff --git a/hexl_v3/include/hexl/logging/logging.hpp b/hexl_v3/include/hexl/logging/logging.hpp new file mode 100644 index 00000000..af5bfcd8 --- /dev/null +++ b/hexl_v3/include/hexl/logging/logging.hpp @@ -0,0 +1,43 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "hexl/util/defines.hpp" + +// Wrap HEXL_VLOG with HEXL_DEBUG; this ensures no logging overhead in +// release mode +#ifdef HEXL_DEBUG + +// TODO(fboemer) Enable if needed +// #define ELPP_THREAD_SAFE +#define ELPP_CUSTOM_COUT std::cerr +#define ELPP_STL_LOGGING +#define ELPP_LOG_STD_ARRAY +#define ELPP_LOG_UNORDERED_MAP +#define ELPP_LOG_UNORDERED_SET +#define ELPP_NO_LOG_TO_FILE +#define ELPP_DISABLE_DEFAULT_CRASH_HANDLING +#define ELPP_WINSOCK2 + +#include + +#define HEXL_VLOG(N, rest) \ + do { \ + if (VLOG_IS_ON(N)) { \ + VLOG(N) << rest; \ + } \ + } while (0); + +#else + +#define HEXL_VLOG(N, rest) \ + {} + +#define START_EASYLOGGINGPP(X, Y) \ + {} + +#endif diff --git a/hexl_v3/include/hexl/ntt/ntt.hpp b/hexl_v3/include/hexl/ntt/ntt.hpp new file mode 100644 index 00000000..93ccba72 --- /dev/null +++ b/hexl_v3/include/hexl/ntt/ntt.hpp @@ -0,0 +1,296 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/util/aligned-allocator.hpp" +#include "hexl/util/allocator.hpp" + +namespace intel { +namespace hexl { + +/// @brief Performs negacyclic forward and inverse number-theoretic transform +/// (NTT), commonly used in RLWE cryptography. +/// @details The number-theoretic transform (NTT) specializes the discrete +/// Fourier transform (DFT) to the finite field \f$ \mathbb{Z}_q[X] / (X^N + 1) +/// \f$. +class NTT { + public: + /// @brief Helper class for custom memory allocation + template + struct AllocatorAdapter + : public AllocatorInterface> { + explicit AllocatorAdapter(Adaptee&& _a, Args&&... args); + AllocatorAdapter(const Adaptee& _a, Args&... args); + + // interface implementation + void* allocate_impl(size_t bytes_count); + void deallocate_impl(void* p, size_t n); + + private: + Adaptee alloc; + }; + + /// @brief Initializes an empty NTT object + NTT() = default; + + /// @brief Destructs the NTT object + ~NTT() = default; + + /// @brief Initializes an NTT object with degree \p degree and modulus \p q. + /// @param[in] degree also known as N. Size of the NTT transform. 
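The `ntt-cache.hpp` helper added above memoizes one `NTT` object per `(N, modulus)` pair behind the reader/writer lock from `locks.hpp`, so repeated transforms with the same parameters skip the root-of-unity precomputation. A hedged sketch of a caller, assuming the internal `ntt/ntt-internal.hpp` include path is visible (as it is inside the library build), that the vector length is a supported power of two, and that all coefficients are already reduced modulo 65537 (a prime with q ≡ 1 mod 2N for any power-of-two N up to 32768):

```cpp
#include <cstdint>
#include <vector>

#include "hexl/experimental/seal/ntt-cache.hpp"

void NegacyclicForward(std::vector<uint64_t>* data) {
  // First call constructs and caches the NTT; later calls reuse it.
  intel::hexl::NTT& ntt = intel::hexl::GetNTT(data->size(), 65537);
  ntt.ComputeForward(data->data(), data->data(), 1, 1);
}
```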
Must be a + /// power of + /// 2 + /// @param[in] q Prime modulus. Must satisfy \f$ q == 1 \mod 2N \f$ + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @brief Performs pre-computation necessary for forward and inverse + /// transforms + NTT(uint64_t degree, uint64_t q, + std::shared_ptr alloc_ptr = {}); + + template + NTT(uint64_t degree, uint64_t q, Allocator&& a, AllocatorArgs&&... args) + : NTT(degree, q, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Initializes an NTT object with degree \p degree and modulus + /// \p q. + /// @param[in] degree also known as N. Size of the NTT transform. Must be a + /// power of + /// 2 + /// @param[in] q Prime modulus. Must satisfy \f$ q == 1 \mod 2N \f$ + /// @param[in] root_of_unity 2N'th root of unity in \f$ \mathbb{Z_q} \f$. + /// @param[in] alloc_ptr Custom memory allocator used for intermediate + /// calculations + /// @details Performs pre-computation necessary for forward and inverse + /// transforms + NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, + std::shared_ptr alloc_ptr = {}); + + template + NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity, Allocator&& a, + AllocatorArgs&&... args) + : NTT(degree, q, root_of_unity, + std::static_pointer_cast( + std::make_shared>( + std::move(a), std::forward(args)...))) {} + + /// @brief Returns true if arguments satisfy constraints for negacyclic NTT + /// @param[in] degree N. Size of the transform, i.e. the polynomial degree. + /// Must be a power of two. + /// @param[in] modulus Prime modulus q. Must satisfy q mod 2N = 1 + static bool CheckArguments(uint64_t degree, uint64_t modulus); + + /// @brief Compute forward NTT. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the NTT + /// @param[in] input_mod_factor Assume input \p operand are in [0, + /// input_mod_factor * q). Must be 1, 2 or 4. + /// @param[in] output_mod_factor Returns output \p result in [0, + /// output_mod_factor * q). Must be 1 or 4. + void ComputeForward(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, uint64_t output_mod_factor); + + /// Compute inverse NTT. Results are bit-reversed. + /// @param[out] result Stores the result + /// @param[in] operand Data on which to compute the NTT + /// @param[in] input_mod_factor Assume input \p operand are in [0, + /// input_mod_factor * q). Must be 1 or 2. + /// @param[in] output_mod_factor Returns output \p result in [0, + /// output_mod_factor * q). Must be 1 or 2. + void ComputeInverse(uint64_t* result, const uint64_t* operand, + uint64_t input_mod_factor, uint64_t output_mod_factor); + + /// @brief Returns the minimal 2N'th root of unity + uint64_t GetMinimalRootOfUnity() const { return m_w; } + + /// @brief Returns the degree N + uint64_t GetDegree() const { return m_degree; } + + /// @brief Returns the word-sized prime modulus + uint64_t GetModulus() const { return m_q; } + + /// @brief Returns the root of unity powers in bit-reversed order + const AlignedVector64& GetRootOfUnityPowers() const { + return m_root_of_unity_powers; + } + + /// @brief Returns the root of unity power at bit-reversed index i. 
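`ComputeForward` and `ComputeInverse` take explicit input and output modulus factors, so a basic round trip looks like the sketch below. This is not taken from the patch; the degree and modulus are illustrative (65537 is prime and satisfies q ≡ 1 mod 2N for N = 4096).

```cpp
#include <cstdint>
#include <vector>

#include "hexl/ntt/ntt.hpp"

int main() {
  const uint64_t n = 4096;
  const uint64_t modulus = 65537;  // prime with modulus % (2 * n) == 1

  intel::hexl::NTT ntt(n, modulus);

  std::vector<uint64_t> data(n, 1);  // coefficients must lie in [0, modulus)
  std::vector<uint64_t> fwd(n);
  std::vector<uint64_t> inv(n);

  // input_mod_factor = 1 and output_mod_factor = 1 keep values in [0, modulus)
  ntt.ComputeForward(fwd.data(), data.data(), 1, 1);
  ntt.ComputeInverse(inv.data(), fwd.data(), 1, 1);  // recovers data
  return 0;
}
```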
+ uint64_t GetRootOfUnityPower(size_t i) { return GetRootOfUnityPowers()[i]; } + + /// @brief Returns 32-bit pre-conditioned root of unity powers in + /// bit-reversed order + const AlignedVector64& GetPrecon32RootOfUnityPowers() const { + return m_precon32_root_of_unity_powers; + } + + /// @brief Returns 64-bit pre-conditioned root of unity powers in + /// bit-reversed order + const AlignedVector64& GetPrecon64RootOfUnityPowers() const { + return m_precon64_root_of_unity_powers; + } + + /// @brief Returns the root of unity powers in bit-reversed order with + /// modifications for use by AVX512 implementation + const AlignedVector64& GetAVX512RootOfUnityPowers() const { + return m_avx512_root_of_unity_powers; + } + + /// @brief Returns 32-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon32RootOfUnityPowers() const { + return m_avx512_precon32_root_of_unity_powers; + } + + /// @brief Returns 52-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon52RootOfUnityPowers() const { + return m_avx512_precon52_root_of_unity_powers; + } + + /// @brief Returns 64-bit pre-conditioned AVX512 root of unity powers in + /// bit-reversed order + const AlignedVector64& GetAVX512Precon64RootOfUnityPowers() const { + return m_avx512_precon64_root_of_unity_powers; + } + + /// @brief Returns the inverse root of unity powers in bit-reversed order + const AlignedVector64& GetInvRootOfUnityPowers() const { + return m_inv_root_of_unity_powers; + } + + /// @brief Returns the inverse root of unity power at bit-reversed index i. + uint64_t GetInvRootOfUnityPower(size_t i) { + return GetInvRootOfUnityPowers()[i]; + } + + /// @brief Returns the vector of 32-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. + const AlignedVector64& GetPrecon32InvRootOfUnityPowers() const { + return m_precon32_inv_root_of_unity_powers; + } + + /// @brief Returns the vector of 52-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. + const AlignedVector64& GetPrecon52InvRootOfUnityPowers() const { + return m_precon52_inv_root_of_unity_powers; + } + + /// @brief Returns the vector of 64-bit pre-conditioned pre-computed root of + /// unity + // powers for the modulus and root of unity. 
+ const AlignedVector64& GetPrecon64InvRootOfUnityPowers() const { + return m_precon64_inv_root_of_unity_powers; + } + + /// @brief Maximum power of 2 in degree + static size_t MaxDegreeBits() { return 20; } + + /// @brief Maximum number of bits in modulus; + static size_t MaxModulusBits() { return 62; } + + /// @brief Default bit shift used in Barrett precomputation + static const size_t s_default_shift_bits{64}; + + /// @brief Bit shift used in Barrett precomputation when AVX512-IFMA + /// acceleration is enabled + static const size_t s_ifma_shift_bits{52}; + + /// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the + /// forward transform + static const size_t s_max_fwd_32_modulus{1ULL << (32 - 2)}; + + /// @brief Maximum modulus to use 32-bit AVX512-DQ acceleration for the + /// inverse transform + static const size_t s_max_inv_32_modulus{1ULL << (32 - 2)}; + + /// @brief Maximum modulus to use AVX512-IFMA acceleration for the forward + /// transform + static const size_t s_max_fwd_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)}; + + /// @brief Maximum modulus to use AVX512-IFMA acceleration for the inverse + /// transform + static const size_t s_max_inv_ifma_modulus{1ULL << (s_ifma_shift_bits - 2)}; + + /// @brief Maximum modulus to use AVX512-DQ acceleration for the inverse + /// transform + static const size_t s_max_inv_dq_modulus{1ULL << (s_default_shift_bits - 2)}; + + static size_t s_max_fwd_modulus(int bit_shift) { + if (bit_shift == 32) { + return s_max_fwd_32_modulus; + } else if (bit_shift == 52) { + return s_max_fwd_ifma_modulus; + } else if (bit_shift == 64) { + return 1ULL << MaxModulusBits(); + } + HEXL_CHECK(false, "Invalid bit_shift " << bit_shift); + return 0; + } + + static size_t s_max_inv_modulus(int bit_shift) { + if (bit_shift == 32) { + return s_max_inv_32_modulus; + } else if (bit_shift == 52) { + return s_max_inv_ifma_modulus; + } else if (bit_shift == 64) { + return 1ULL << MaxModulusBits(); + } + HEXL_CHECK(false, "Invalid bit_shift " << bit_shift); + return 0; + } + + private: + void ComputeRootOfUnityPowers(); + + uint64_t m_degree; // N: size of NTT transform, should be power of 2 + uint64_t m_q; // prime modulus. 
Must satisfy q == 1 mod 2n + + uint64_t m_degree_bits; // log_2(m_degree) + + uint64_t m_w_inv; // Inverse of minimal root of unity + uint64_t m_w; // A 2N'th root of unity + + std::shared_ptr m_alloc; + + AlignedAllocator m_aligned_alloc; + + // powers of the minimal root of unity + AlignedVector64 m_root_of_unity_powers; + // vector of floor(W * 2**32 / m_q), with W the root of unity powers + AlignedVector64 m_precon32_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the root of unity powers + AlignedVector64 m_precon64_root_of_unity_powers; + + // powers of the minimal root of unity adjusted for use in AVX512 + // implementations + AlignedVector64 m_avx512_root_of_unity_powers; + // vector of floor(W * 2**32 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon32_root_of_unity_powers; + // vector of floor(W * 2**52 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon52_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the AVX512 root of unity powers + AlignedVector64 m_avx512_precon64_root_of_unity_powers; + + // vector of floor(W * 2**32 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon32_inv_root_of_unity_powers; + // vector of floor(W * 2**52 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon52_inv_root_of_unity_powers; + // vector of floor(W * 2**64 / m_q), with W the inverse root of unity powers + AlignedVector64 m_precon64_inv_root_of_unity_powers; + + AlignedVector64 m_inv_root_of_unity_powers; +}; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/number-theory/number-theory.hpp b/hexl_v3/include/hexl/number-theory/number-theory.hpp new file mode 100644 index 00000000..da8d1d2a --- /dev/null +++ b/hexl_v3/include/hexl/number-theory/number-theory.hpp @@ -0,0 +1,342 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/compiler.hpp" + +namespace intel { +namespace hexl { + +/// @brief Pre-computes a Barrett factor with which modular multiplication can +/// be performed more efficiently +class MultiplyFactor { + public: + MultiplyFactor() = default; + + /// @brief Computes and stores the Barrett factor floor((operand << bit_shift) + /// / modulus). This is useful when modular multiplication of the form + /// (x * operand) mod modulus is performed with same modulus and operand + /// several times. Note, passing operand=1 can be used to pre-compute a + /// Barrett factor for multiplications of the form (x * y) mod modulus, where + /// only the modulus is re-used across calls to modular multiplication. + MultiplyFactor(uint64_t operand, uint64_t bit_shift, uint64_t modulus) + : m_operand(operand) { + HEXL_CHECK(operand <= modulus, "operand " << operand + << " must be less than modulus " + << modulus); + HEXL_CHECK(bit_shift == 32 || bit_shift == 52 || bit_shift == 64, + "Unsupported BitShift " << bit_shift); + uint64_t op_hi = operand >> (64 - bit_shift); + uint64_t op_lo = (bit_shift == 64) ? 
0 : (operand << bit_shift); + + m_barrett_factor = DivideUInt128UInt64Lo(op_hi, op_lo, modulus); + } + + /// @brief Returns the pre-computed Barrett factor + inline uint64_t BarrettFactor() const { return m_barrett_factor; } + + /// @brief Returns the operand corresponding to the Barrett factor + inline uint64_t Operand() const { return m_operand; } + + private: + uint64_t m_operand; + uint64_t m_barrett_factor; +}; + +/// @brief Returns whether or not num is a power of two +inline bool IsPowerOfTwo(uint64_t num) { return num && !(num & (num - 1)); } + +/// @brief Returns floor(log2(x)) +inline uint64_t Log2(uint64_t x) { return MSB(x); } + +inline bool IsPowerOfFour(uint64_t num) { + return IsPowerOfTwo(num) && (Log2(num) % 2 == 0); +} + +/// @brief Returns the maximum value that can be represented using \p bits bits +inline uint64_t MaximumValue(uint64_t bits) { + HEXL_CHECK(bits <= 64, "MaximumValue requires bits <= 64; got " << bits); + if (bits == 64) { + return (std::numeric_limits::max)(); + } + return (1ULL << bits) - 1; +} + +/// @brief Reverses the bits +/// @param[in] x Input to reverse +/// @param[in] bit_width Number of bits in the input; must be >= MSB(x) +/// @return The bit-reversed representation of \p x using \p bit_width bits +uint64_t ReverseBits(uint64_t x, uint64_t bit_width); + +/// @brief Returns x^{-1} mod modulus +/// @details Requires x % modulus != 0 +uint64_t InverseMod(uint64_t x, uint64_t modulus); + +/// @brief Returns (x * y) mod modulus +/// @details Assumes x, y < modulus +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns (x * y) mod modulus +/// @param[in] x +/// @param[in] y +/// @param[in] y_precon 64-bit precondition factor floor(2**64 / modulus) +/// @param[in] modulus +uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon, + uint64_t modulus); + +/// @brief Returns (x + y) mod modulus +/// @details Assumes x, y < modulus +uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns (x - y) mod modulus +/// @details Assumes x, y < modulus +uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus); + +/// @brief Returns base^exp mod modulus +uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus); + +/// @brief Returns whether or not root is a degree-th root of unity mod modulus +/// @param[in] root Root of unity to check +/// @param[in] degree Degree of root of unity; must be a power of two +/// @param[in] modulus Modulus of finite field +bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus); + +/// @brief Tries to return a primitive degree-th root of unity +/// @details Returns 0 or throws an error if no root is found +uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus); + +/// @brief Returns whether or not root is a degree-th root of unity +/// @param[in] degree Must be a power of two +/// @param[in] modulus Modulus of finite field +uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus); + +/// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 * +/// modulus] +/// @param[in] x +/// @param[in] y_operand also denoted y +/// @param[in] modulus +/// @param[in] y_barrett_factor Pre-computed Barrett reduction factor floor((y +/// << BitShift) / modulus) +template +inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y_operand, + uint64_t y_barrett_factor, uint64_t modulus) { + HEXL_CHECK(y_operand < modulus, "y_operand " << y_operand + << " must be less than modulus " + << modulus); + HEXL_CHECK( + modulus <= 
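The scalar number-theory helpers above are the building blocks for the element-wise and NTT kernels. A small sketch exercising a few of them; the modulus and operands are illustrative assumptions:

```cpp
#include <cstdint>
#include <vector>

#include "hexl/number-theory/number-theory.hpp"

int main() {
  const uint64_t q = 65537;  // illustrative prime modulus

  uint64_t prod = intel::hexl::MultiplyMod(123, 456, q);  // (123 * 456) mod q
  uint64_t pw = intel::hexl::PowMod(3, 10, q);            // 3^10 mod q
  uint64_t inv = intel::hexl::InverseMod(123, q);         // 123 * inv == 1 mod q

  // One 40-bit prime p with p % (2 * 4096) == 1, i.e. NTT-friendly for N = 4096
  std::vector<uint64_t> primes = intel::hexl::GeneratePrimes(1, 40, true, 4096);

  (void)prod;
  (void)pw;
  (void)inv;
  (void)primes;
  return 0;
}
```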
MaximumValue(BitShift), + "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift)); + HEXL_CHECK(x <= MaximumValue(BitShift), + "Operand " << x << " exceeds bound " << MaximumValue(BitShift)); + + uint64_t Q = MultiplyUInt64Hi(x, y_barrett_factor); + return y_operand * x - Q * modulus; +} + +/// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 * +/// modulus] +/// @param[in] x +/// @param[in] y +/// @param[in] modulus +template +inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y, uint64_t modulus) { + HEXL_CHECK(BitShift == 64 || BitShift == 52, + "Unsupported BitShift " << BitShift); + HEXL_CHECK(x <= MaximumValue(BitShift), + "Operand " << x << " exceeds bound " << MaximumValue(BitShift)); + HEXL_CHECK(y < modulus, + "y " << y << " must be less than modulus " << modulus); + HEXL_CHECK( + modulus <= MaximumValue(BitShift), + "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift)); + + uint64_t y_barrett = MultiplyFactor(y, BitShift, modulus).BarrettFactor(); + return MultiplyModLazy(x, y, y_barrett, modulus); +} + +/// @brief Adds two unsigned 64-bit integers +/// @param operand1 Number to add +/// @param operand2 Number to add +/// @param result Stores the sum +/// @return The carry bit +inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2, + uint64_t* result) { + *result = operand1 + operand2; + return static_cast(*result < operand1); +} + +/// @brief Returns whether or not the input is prime +bool IsPrime(uint64_t n); + +/// @brief Generates a list of num_primes primes in the range [2^(bit_size), +// 2^(bit_size+1)]. Ensures each prime q satisfies +// q % (2*ntt_size+1)) == 1 +/// @param[in] num_primes Number of primes to generate +/// @param[in] bit_size Bit size of each prime +/// @param[in] prefer_small_primes When true, returns primes starting from +/// 2^(bit_size); when false, returns primes starting from 2^(bit_size+1) +/// @param[in] ntt_size N such that each prime q satisfies q % (2N) == 1. N must +/// be a power of two less than 2^bit_size. +std::vector GeneratePrimes(size_t num_primes, size_t bit_size, + bool prefer_small_primes, + size_t ntt_size = 1); + +/// @brief Returns input mod modulus, computed via 64-bit Barrett reduction +/// @param[in] input +/// @param[in] modulus +/// @param[in] q_barr floor(2^64 / modulus) +template +uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr) { + HEXL_CHECK(modulus != 0, "modulus == 0"); + uint64_t q = MultiplyUInt64Hi<64>(input, q_barr); + uint64_t q_times_input = input - q * modulus; + if (OutputModFactor == 2) { + return q_times_input; + } else { + return (q_times_input >= modulus) ? 
q_times_input - modulus : q_times_input; + } +} + +/// @brief Returns x mod modulus, assuming x < InputModFactor * modulus +/// @param[in] x +/// @param[in] modulus also denoted q +/// @param[in] twice_modulus 2 * q; must not be nullptr if InputModFactor == 4 +/// or 8 +/// @param[in] four_times_modulus 4 * q; must not be nullptr if InputModFactor +/// == 8 +template +uint64_t ReduceMod(uint64_t x, uint64_t modulus, + const uint64_t* twice_modulus = nullptr, + const uint64_t* four_times_modulus = nullptr) { + HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 || + InputModFactor == 4 || InputModFactor == 8, + "InputModFactor should be 1, 2, 4, or 8"); + if (InputModFactor == 1) { + return x; + } + if (InputModFactor == 2) { + if (x >= modulus) { + x -= modulus; + } + return x; + } + if (InputModFactor == 4) { + HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr"); + if (x >= *twice_modulus) { + x -= *twice_modulus; + } + if (x >= modulus) { + x -= modulus; + } + return x; + } + if (InputModFactor == 8) { + HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr"); + HEXL_CHECK(four_times_modulus != nullptr, + "four_times_modulus should not be nullptr"); + + if (x >= *four_times_modulus) { + x -= *four_times_modulus; + } + if (x >= *twice_modulus) { + x -= *twice_modulus; + } + if (x >= modulus) { + x -= modulus; + } + return x; + } + HEXL_CHECK(false, "Should be unreachable"); + return x; +} + +/// @brief Returns Montgomery form of ab mod q, computed via the REDC algorithm, +/// also known as Montgomery reduction. +/// @param[in] r +/// @param[in] q with R = 2^r such that gcd(R, q) = 1. R > q. +/// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R. +/// @param[in] mod_R_msk take r last bits to apply mod R. +/// @param[in] T_hi of T = ab in the range [0, Rq − 1]. +/// @param[in] T_lo of T. +/// @return Unsigned long int in the range [0, q − 1] such that S ≡ TR^−1 mod q +template +inline uint64_t MontgomeryReduce(uint64_t T_hi, uint64_t T_lo, uint64_t q, + int r, uint64_t mod_R_msk, uint64_t inv_mod) { + HEXL_CHECK(BitShift == 64 || BitShift == 52, + "Unsupported BitShift " << BitShift); + HEXL_CHECK((1ULL << r) > static_cast(q), + "R value should be greater than q = " << static_cast(q)); + + uint64_t mq_hi; + uint64_t mq_lo; + + uint64_t m = ((T_lo & mod_R_msk) * inv_mod) & mod_R_msk; + MultiplyUInt64(m, q, &mq_hi, &mq_lo); + + if (BitShift == 52) { + mq_hi = (mq_hi << 12) | (mq_lo >> 52); + mq_lo &= (1ULL << 52) - 1; + } + + uint64_t t_hi; + uint64_t t_lo; + + // first 64bit block + t_lo = T_lo + mq_lo; + unsigned int carry = static_cast(t_lo < T_lo); + t_hi = T_hi + mq_hi + carry; + + t_hi = t_hi << (BitShift - r); + t_lo = t_lo >> r; + t_lo = t_hi + t_lo; + + return (t_lo >= q) ? 
(t_lo - q) : t_lo; +} + +/// @brief Hensel's Lemma for 2-adic numbers +/// Find solution for qX + 1 = 0 mod 2^r +/// @param[in] r +/// @param[in] q such that gcd(2, q) = 1 +/// @return Unsigned long int in [0, 2^r − 1] such that q*x ≡ −1 mod 2^r +inline uint64_t HenselLemma2adicRoot(uint32_t r, uint64_t q) { + uint64_t a_prev = 1; + uint64_t c = 2; + uint64_t mod_mask = 3; + + // Root: + // f(x) = qX + 1 and a_(0) = 1 then f(1) ≡ 0 mod 2 + // General Case: + // - a_(n) ≡ a_(n-1) mod 2^(n) + // => a_(n) = a_(n-1) + 2^(n)*t + // - Find 't' such that f(a_(n)) = 0 mod 2^(n+1) + // First case in for: + // - a_(1) ≡ 1 mod 2 or a_(1) = 1 + 2t + // - Find 't' so f(a_(1)) ≡ 0 mod 4 => q(1 + 2t) + 1 ≡ 0 mod 4 + for (uint64_t k = 2; k <= r; k++) { + uint64_t f = 0; + uint64_t t = 0; + uint64_t a = 0; + + do { + a = a_prev + c * t++; + f = q * a + 1ULL; + } while (f & mod_mask); // f(a) ≡ 0 mod 2^(k) + + // Update vars + mod_mask = mod_mask * 2 + 1ULL; + c *= 2; + a_prev = a; + } + + return a_prev; +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/util/aligned-allocator.hpp b/hexl_v3/include/hexl/util/aligned-allocator.hpp new file mode 100644 index 00000000..d175c734 --- /dev/null +++ b/hexl_v3/include/hexl/util/aligned-allocator.hpp @@ -0,0 +1,110 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include "hexl/number-theory/number-theory.hpp" +#include "hexl/util/allocator.hpp" +#include "hexl/util/defines.hpp" + +namespace intel { +namespace hexl { + +/// @brief Allocater implementation using malloc and free +struct MallocStrategy : AllocatorBase { + void* allocate(size_t bytes_count) final { return std::malloc(bytes_count); } + + void deallocate(void* p, size_t n) final { + HEXL_UNUSED(n); + std::free(p); + } +}; + +using AllocatorStrategyPtr = std::shared_ptr; +extern AllocatorStrategyPtr mallocStrategy; + +/// @brief Allocates memory aligned to Alignment-byte sized boundaries +/// @details Alignment must be a power of two +template +class AlignedAllocator { + public: + template + friend class AlignedAllocator; + + using value_type = T; + + explicit AlignedAllocator(AllocatorStrategyPtr strategy = nullptr) noexcept + : m_alloc_impl((strategy != nullptr) ? 
strategy : mallocStrategy) {} + + AlignedAllocator(const AlignedAllocator& src) = default; + AlignedAllocator& operator=(const AlignedAllocator& src) = default; + + template + AlignedAllocator(const AlignedAllocator& src) + : m_alloc_impl(src.m_alloc_impl) {} + + ~AlignedAllocator() {} + + template + struct rebind { + using other = AlignedAllocator; + }; + + bool operator==(const AlignedAllocator&) { return true; } + + bool operator!=(const AlignedAllocator&) { return false; } + + /// @brief Allocates \p n elements aligned to Alignment-byte boundaries + /// @return Pointer to the aligned allocated memory + T* allocate(size_t n) { + if (!IsPowerOfTwo(Alignment)) { + return nullptr; + } + // Allocate enough space to ensure the alignment can be satisfied + size_t buffer_size = sizeof(T) * n + Alignment; + // Additionally, allocate a prefix to store the memory location of the + // unaligned buffer + size_t alloc_size = buffer_size + sizeof(void*); + void* buffer = m_alloc_impl->allocate(alloc_size); + if (!buffer) { + return nullptr; + } + + // Reserve first location for pointer to originally-allocated space + void* aligned_buffer = static_cast(buffer) + sizeof(void*); + std::align(Alignment, sizeof(T) * n, aligned_buffer, buffer_size); + if (!aligned_buffer) { + return nullptr; + } + + // Store allocated buffer address at aligned_buffer - sizeof(void*). + void* store_buffer_addr = + static_cast(aligned_buffer) - sizeof(void*); + *(static_cast(store_buffer_addr)) = buffer; + + return static_cast(aligned_buffer); + } + + void deallocate(T* p, size_t n) { + if (!p) { + return; + } + void* store_buffer_addr = (reinterpret_cast(p) - sizeof(void*)); + void* free_address = *(static_cast(store_buffer_addr)); + m_alloc_impl->deallocate(free_address, n); + } + + private: + AllocatorStrategyPtr m_alloc_impl; +}; + +/// @brief 64-byte aligned memory allocator +template +using AlignedVector64 = std::vector >; + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/util/allocator.hpp b/hexl_v3/include/hexl/util/allocator.hpp new file mode 100644 index 00000000..5f4a7a31 --- /dev/null +++ b/hexl_v3/include/hexl/util/allocator.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +/// @brief Base class for custom memory allocator +struct AllocatorBase { + virtual ~AllocatorBase() noexcept {} + + /// @brief Allocates byte_count bytes of memory + /// @param[in] bytes_count Number of bytes to allocate + /// @return A pointer to the allocated memory + virtual void* allocate(size_t bytes_count) = 0; + + /// @brief Deallocate memory + /// @param[in] p Pointer to memory to deallocate + /// @param[in] n Number of bytes to deallocate + virtual void deallocate(void* p, size_t n) = 0; +}; + +/// @brief Helper memory allocation struct which delegates implementation to +/// AllocatorImpl +template +struct AllocatorInterface : public AllocatorBase { + /// @brief Override interface and delegate implementation to AllocatorImpl + void* allocate(size_t bytes_count) override { + return static_cast(this)->allocate_impl(bytes_count); + } + + /// @brief Override interface and delegate implementation to AllocatorImpl + void deallocate(void* p, size_t n) override { + static_cast(this)->deallocate_impl(p, n); + } + + private: + // in case AllocatorImpl doesn't provide implementations, use default null + // behavior + void* allocate_impl(size_t bytes_count) { + HEXL_UNUSED(bytes_count); + return 
nullptr; + } + void deallocate_impl(void* p, size_t n) { + HEXL_UNUSED(p); + HEXL_UNUSED(n); + } +}; +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/util/check.hpp b/hexl_v3/include/hexl/util/check.hpp new file mode 100644 index 00000000..386eba89 --- /dev/null +++ b/hexl_v3/include/hexl/util/check.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/types.hpp" + +// Create logging/debug macros with no run-time overhead unless HEXL_DEBUG is +// enabled +#ifdef HEXL_DEBUG +#include "hexl/logging/logging.hpp" + +/// @brief If input condition is not true, logs the expression and throws an +/// error +/// @param[in] cond A boolean indication the condition +/// @param[in] expr The expression to be logged +#define HEXL_CHECK(cond, expr) \ + if (!(cond)) { \ + LOG(ERROR) << expr << " in function: " << __FUNCTION__ \ + << " in file: " __FILE__ << ":" << __LINE__; \ + throw std::runtime_error("Error. Check log output"); \ + } + +/// @brief If input has an element >= bound, logs the expression and throws an +/// error +/// @param[in] arg Input container which supports the [] operator. +/// @param[in] n Size of input +/// @param[in] bound Upper bound on the input +/// @param[in] expr The expression to be logged +#define HEXL_CHECK_BOUNDS(arg, n, bound, expr) \ + for (size_t hexl_check_idx = 0; hexl_check_idx < n; ++hexl_check_idx) { \ + HEXL_CHECK((arg)[hexl_check_idx] < bound, expr); \ + } + +#else // HEXL_DEBUG=OFF + +#define HEXL_CHECK(cond, expr) \ + {} +#define HEXL_CHECK_BOUNDS(...) \ + {} + +#endif // HEXL_DEBUG diff --git a/hexl_v3/include/hexl/util/clang.hpp b/hexl_v3/include/hexl/util/clang.hpp new file mode 100644 index 00000000..958bea7b --- /dev/null +++ b/hexl_v3/include/hexl/util/clang.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/types.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_USE_CLANG +// Return x * y as 128-bit integer +// Correctness if x * y < 128 bits +inline uint128_t MultiplyUInt64(uint64_t x, uint64_t y) { + return uint128_t(x) * uint128_t(y); +} + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint128_t n = (static_cast(input_hi) << 64) | + (static_cast(input_lo)); + + return n % modulus; + // TODO(fboemer): actually use barrett reduction if performance-critical +} + +// Returns low 64bit of 128b/64b where x1=high 64b, x0=low 64b +inline uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y) { + uint128_t n = + (static_cast(x1) << 64) | (static_cast(x0)); + uint128_t q = n / y; + + return static_cast(q); +} + +// Multiplies x * y as 128-bit integer. 
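`AlignedVector64` is the 64-byte-aligned `std::vector` alias the library uses internally, so AVX-512 loads and stores on its buffers hit aligned addresses; a custom `AllocatorBase` implementation can be substituted through the adapter machinery above. A minimal sketch, assuming the built hexl library is linked (it provides the default `mallocStrategy`):

```cpp
#include <cstdint>

#include "hexl/util/aligned-allocator.hpp"

int main() {
  // Backed by AlignedAllocator<uint64_t, 64>; coeffs.data() is 64-byte aligned
  intel::hexl::AlignedVector64<uint64_t> coeffs(4096, 0);
  coeffs[0] = 42;
  return (coeffs[0] == 42) ? 0 : 1;
}
```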
+// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + uint128_t prod = MultiplyUInt64(x, y); + *prod_hi = static_cast(prod >> 64); + *prod_lo = static_cast(prod); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + uint128_t product = static_cast(x) * y; + return static_cast(product >> BitShift); +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + return static_cast(std::log2l(input)); +} + +#define HEXL_LOOP_UNROLL_4 _Pragma("clang loop unroll_count(4)") +#define HEXL_LOOP_UNROLL_8 _Pragma("clang loop unroll_count(8)") + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/util/compiler.hpp b/hexl_v3/include/hexl/util/compiler.hpp new file mode 100644 index 00000000..7dd077df --- /dev/null +++ b/hexl_v3/include/hexl/util/compiler.hpp @@ -0,0 +1,14 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "hexl/util/defines.hpp" + +#ifdef HEXL_USE_MSVC +#include "hexl/util/msvc.hpp" +#elif defined HEXL_USE_GNU +#include "hexl/util/gcc.hpp" +#elif defined HEXL_USE_CLANG +#include "hexl/util/clang.hpp" +#endif diff --git a/hexl_v3/include/hexl/util/defines.hpp b/hexl_v3/include/hexl/util/defines.hpp new file mode 100644 index 00000000..93db376e --- /dev/null +++ b/hexl_v3/include/hexl/util/defines.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +/* #undef HEXL_USE_MSVC */ +#define HEXL_USE_GNU +/* #undef HEXL_USE_CLANG */ + +/* #undef HEXL_DEBUG */ + +// Avoid unused variable warnings +#define HEXL_UNUSED(x) (void)(x) diff --git a/hexl_v3/include/hexl/util/gcc.hpp b/hexl_v3/include/hexl/util/gcc.hpp new file mode 100644 index 00000000..828e3836 --- /dev/null +++ b/hexl_v3/include/hexl/util/gcc.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include + +#include "hexl/util/check.hpp" +#include "hexl/util/types.hpp" + +namespace intel { +namespace hexl { + +#ifdef HEXL_USE_GNU +// Return x * y as 128-bit integer +// Correctness if x * y < 128 bits +inline uint128_t MultiplyUInt64(uint64_t x, uint64_t y) { + return uint128_t(x) * uint128_t(y); +} + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint128_t n = (static_cast(input_hi) << 64) | + (static_cast(input_lo)); + + return static_cast(n % modulus); + // TODO(fboemer): actually use barrett reduction if performance-critical +} + +// Returns low 64bit of 128b/64b where x1=high 64b, x0=low 64b +inline uint64_t DivideUInt128UInt64Lo(uint64_t x1, uint64_t x0, uint64_t y) { + uint128_t n = + (static_cast(x1) << 64) | (static_cast(x0)); + uint128_t q = n / y; + + return static_cast(q); +} + +// Multiplies x * y as 128-bit integer. 
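The compiler-specific headers (`clang.hpp`, `gcc.hpp`, `msvc.hpp`) expose the same 128-bit multiply primitives that the Barrett and Montgomery routines build on. A sketch with illustrative operands, assuming `hexl/util/compiler.hpp` is used as the entry point so the right backend is selected:

```cpp
#include <cstdint>

#include "hexl/util/compiler.hpp"

int main() {
  uint64_t hi = 0;
  uint64_t lo = 0;
  // Full 128-bit product of two 64-bit operands, split into high/low words
  intel::hexl::MultiplyUInt64(0x123456789ABCDEF0ULL, 0x0FEDCBA987654321ULL,
                              &hi, &lo);

  // High 64 bits only, i.e. floor((x * y) / 2^64); zero for small inputs
  uint64_t hi_only = intel::hexl::MultiplyUInt64Hi<64>(3, 5);

  (void)hi;
  (void)lo;
  (void)hi_only;
  return 0;
}
```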
+// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + uint128_t prod = MultiplyUInt64(x, y); + *prod_hi = static_cast(prod >> 64); + *prod_lo = static_cast(prod); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + uint128_t product = MultiplyUInt64(x, y); + return static_cast(product >> BitShift); +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + return static_cast(std::log2l(input)); +} + +#define HEXL_LOOP_UNROLL_4 _Pragma("GCC unroll 4") +#define HEXL_LOOP_UNROLL_8 _Pragma("GCC unroll 8") + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/util/msvc.hpp b/hexl_v3/include/hexl/util/msvc.hpp new file mode 100644 index 00000000..0ada2d45 --- /dev/null +++ b/hexl_v3/include/hexl/util/msvc.hpp @@ -0,0 +1,289 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#ifdef HEXL_USE_MSVC + +#define NOMINMAX // Avoid errors with std::min/std::max +#undef min +#undef max + +#include +#include +#include + +#include + +#include "hexl/util/check.hpp" + +#pragma intrinsic(_addcarry_u64, _BitScanReverse64, _subborrow_u64, _udiv128, \ + _umul128) + +#undef TRUE +#undef FALSE + +namespace intel { +namespace hexl { + +inline uint64_t BarrettReduce128(uint64_t input_hi, uint64_t input_lo, + uint64_t modulus) { + HEXL_CHECK(modulus != 0, "modulus == 0") + uint64_t remainder; + _udiv128(input_hi, input_lo, modulus, &remainder); + + return remainder; +} + +// Multiplies x * y as 128-bit integer. +// @param prod_hi Stores high 64 bits of product +// @param prod_lo Stores low 64 bits of product +inline void MultiplyUInt64(uint64_t x, uint64_t y, uint64_t* prod_hi, + uint64_t* prod_lo) { + *prod_lo = _umul128(x, y, prod_hi); +} + +// Return the high 128 minus BitShift bits of the 128-bit product x * y +template +inline uint64_t MultiplyUInt64Hi(uint64_t x, uint64_t y) { + HEXL_CHECK(BitShift == 52 || BitShift == 64, + "Invalid BitShift " << BitShift << "; expected 52 or 64"); + uint64_t prod_hi; + uint64_t prod_lo = _umul128(x, y, &prod_hi); + uint64_t result_hi; + uint64_t result_lo; + RightShift128(&result_hi, &result_lo, prod_hi, prod_lo, BitShift); + return result_lo; +} + +/// @brief Computes Left Shift op as 128-bit unsigned integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +/// @param[in] op_hi Stores high 64 bits of input +/// @param[in] op_lo Stores low 64 bits of input +inline void LeftShift128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op_hi, const uint64_t op_lo, + const uint64_t shift_value) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + HEXL_CHECK(shift_value <= 128, + "shift_value cannot be greater than 128 " << shift_value); + + if (shift_value == 0) { + *result_hi = op_hi; + *result_lo = op_lo; + } else if (shift_value == 64) { + *result_hi = op_lo; + *result_lo = 0ULL; + } else if (shift_value == 128) { + *result_hi = 0ULL; + *result_lo = 0ULL; + } else if (shift_value >= 1 && shift_value <= 63) { + *result_hi = (op_hi << shift_value) | (op_lo >> (64 - shift_value)); + *result_lo = op_lo << shift_value; + } else if (shift_value >= 65 && shift_value < 128) { + 
*result_hi = op_lo << (shift_value - 64); + *result_lo = 0ULL; + } +} + +/// @brief Computes Right Shift op as 128-bit unsigned integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +/// @param[in] op_hi Stores high 64 bits of input +/// @param[in] op_lo Stores low 64 bits of input +inline void RightShift128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op_hi, const uint64_t op_lo, + const uint64_t shift_value) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + HEXL_CHECK(shift_value <= 128, + "shift_value cannot be greater than 128 " << shift_value); + + if (shift_value == 0) { + *result_hi = op_hi; + *result_lo = op_lo; + } else if (shift_value == 64) { + *result_hi = 0ULL; + *result_lo = op_hi; + } else if (shift_value == 128) { + *result_hi = 0ULL; + *result_lo = 0ULL; + } else if (shift_value >= 1 && shift_value <= 63) { + *result_hi = op_hi >> shift_value; + *result_lo = (op_hi << (64 - shift_value)) | (op_lo >> shift_value); + } else if (shift_value >= 65 && shift_value < 128) { + *result_hi = 0ULL; + *result_lo = op_hi >> (shift_value - 64); + } +} + +/// @brief Adds op1 + op2 as 128-bit integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +inline void AddWithCarry128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op1_hi, const uint64_t op1_lo, + const uint64_t op2_hi, const uint64_t op2_lo) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + + // first 64bit block + *result_lo = op1_lo + op2_lo; + unsigned char carry = static_cast(*result_lo < op1_lo); + + // second 64bit block + _addcarry_u64(carry, op1_hi, op2_hi, result_hi); +} + +/// @brief Subtracts op1 - op2 as 128-bit integer +/// @param[out] result_hi Stores high 64 bits of result +/// @param[out] result_lo Stores low 64 bits of result +inline void SubWithCarry128(uint64_t* result_hi, uint64_t* result_lo, + const uint64_t op1_hi, const uint64_t op1_lo, + const uint64_t op2_hi, const uint64_t op2_lo) { + HEXL_CHECK(result_hi != nullptr, "Require result_hi != nullptr"); + HEXL_CHECK(result_lo != nullptr, "Require result_lo != nullptr"); + + unsigned char borrow; + + // first 64bit block + *result_lo = op1_lo - op2_lo; + borrow = static_cast(op2_lo > op1_lo); + + // second 64bit block + _subborrow_u64(borrow, op1_hi, op2_hi, result_hi); +} + +/// @brief Computes and returns significant bit count +/// @param[in] value Input element at most 128 bits long +inline uint64_t SignificantBitLength(const uint64_t* value) { + HEXL_CHECK(value != nullptr, "Require value != nullptr"); + + unsigned long count = 0; // NOLINT(runtime/int) + + // second 64bit block + _BitScanReverse64(&count, *(value + 1)); + if (count >= 0 && *(value + 1) > 0) { + return static_cast(count) + 1 + 64; + } + + // first 64bit block + _BitScanReverse64(&count, *value); + if (count >= 0 && *(value) > 0) { + return static_cast(count) + 1; + } + return 0; +} + +/// @brief Checks if input is negative number +/// @param[in] input Input element to check for sign +inline bool CheckSign(const uint64_t* input) { + HEXL_CHECK(input != nullptr, "Require input != nullptr"); + + uint64_t input_temp[2]{0, 0}; + RightShift128(&input_temp[1], &input_temp[0], input[1], input[0], 127); + return (input_temp[0] == 1); +} + +/// @brief Divides 
numerator by denominator +/// @param[out] quotient Stores quotient as two 64-bit blocks after division +/// @param[in] numerator +/// @param[in] denominator +inline void DivideUInt128UInt64(uint64_t* quotient, const uint64_t* numerator, + const uint64_t denominator) { + HEXL_CHECK(quotient != nullptr, "Require quotient != nullptr"); + HEXL_CHECK(numerator != nullptr, "Require numerator != nullptr"); + HEXL_CHECK(denominator != 0, "denominator cannot be 0 " << denominator); + + // get bit count of divisor + uint64_t numerator_bits = SignificantBitLength(numerator); + const uint64_t numerator_bits_const = numerator_bits; + const uint64_t uint_128_bit = 128ULL; + + uint64_t MASK[2]{0x0000000000000001, 0x0000000000000000}; + uint64_t remainder[2]{0, 0}; + uint64_t quotient_temp[2]{0, 0}; + uint64_t denominator_temp[2]{denominator, 0}; + + quotient[0] = numerator[0]; + quotient[1] = numerator[1]; + + // align numerator + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], + (uint_128_bit - numerator_bits_const)); + + while (numerator_bits) { + // if remainder is negative + if (CheckSign(remainder)) { + LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1); + RightShift128("ient_temp[1], "ient_temp[0], quotient[1], + quotient[0], (uint_128_bit - 1)); + remainder[0] = remainder[0] | quotient_temp[0]; + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1); + // remainder=remainder+denominator_temp + AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } else { // if remainder is positive + LeftShift128(&remainder[1], &remainder[0], remainder[1], remainder[0], 1); + RightShift128("ient_temp[1], "ient_temp[0], quotient[1], + quotient[0], (uint_128_bit - 1)); + remainder[0] = remainder[0] | quotient_temp[0]; + LeftShift128("ient[1], "ient[0], quotient[1], quotient[0], 1); + // remainder=remainder-denominator_temp + SubWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } + + // if remainder is positive set MSB of quotient[0]=1 + if (!CheckSign(remainder)) { + MASK[0] = 0x0000000000000001; + MASK[1] = 0x0000000000000000; + LeftShift128(&MASK[1], &MASK[0], MASK[1], MASK[0], + (uint_128_bit - numerator_bits_const)); + quotient[0] = quotient[0] | MASK[0]; + quotient[1] = quotient[1] | MASK[1]; + } + quotient_temp[0] = 0; + quotient_temp[1] = 0; + numerator_bits--; + } + + if (CheckSign(remainder)) { + // remainder=remainder+denominator_temp + AddWithCarry128(&remainder[1], &remainder[0], remainder[1], remainder[0], + denominator_temp[1], denominator_temp[0]); + } + RightShift128("ient[1], "ient[0], quotient[1], quotient[0], + (uint_128_bit - numerator_bits_const)); +} + +/// @brief Returns low of dividing numerator by denominator +/// @param[in] numerator_hi Stores high 64 bit of numerator +/// @param[in] numerator_lo Stores low 64 bit of numerator +/// @param[in] denominator Stores denominator +inline uint64_t DivideUInt128UInt64Lo(const uint64_t numerator_hi, + const uint64_t numerator_lo, + const uint64_t denominator) { + uint64_t numerator[2]{numerator_lo, numerator_hi}; + uint64_t quotient[2]{0, 0}; + + DivideUInt128UInt64(quotient, numerator, denominator); + return quotient[0]; +} + +// Returns most-significant bit of the input +inline uint64_t MSB(uint64_t input) { + unsigned long index{0}; // NOLINT(runtime/int) + _BitScanReverse64(&index, input); + return index; +} + +#define HEXL_LOOP_UNROLL_4 \ + {} +#define HEXL_LOOP_UNROLL_8 
\ + {} + +#endif + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/include/hexl/util/types.hpp b/hexl_v3/include/hexl/util/types.hpp new file mode 100644 index 00000000..2d2d8551 --- /dev/null +++ b/hexl_v3/include/hexl/util/types.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +#include "hexl/util/defines.hpp" + +#if defined(HEXL_USE_GNU) || defined(HEXL_USE_CLANG) +__extension__ typedef __int128 int128_t; +__extension__ typedef unsigned __int128 uint128_t; +#endif diff --git a/hexl_v3/include/hexl/util/util.hpp b/hexl_v3/include/hexl/util/util.hpp new file mode 100644 index 00000000..bf878a98 --- /dev/null +++ b/hexl_v3/include/hexl/util/util.hpp @@ -0,0 +1,53 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include + +namespace intel { +namespace hexl { + +#undef TRUE // MSVC defines TRUE +#undef FALSE // MSVC defines FALSE + +/// @enum CMPINT +/// @brief Represents binary operations between two boolean values +enum class CMPINT { + EQ = 0, ///< Equal + LT = 1, ///< Less than + LE = 2, ///< Less than or equal + FALSE = 3, ///< False + NE = 4, ///< Not equal + NLT = 5, ///< Not less than + NLE = 6, ///< Not less than or equal + TRUE = 7 ///< True +}; + +/// @brief Returns the logical negation of a binary operation +/// @param[in] cmp The binary operation to negate +inline CMPINT Not(CMPINT cmp) { + switch (cmp) { + case CMPINT::EQ: + return CMPINT::NE; + case CMPINT::LT: + return CMPINT::NLT; + case CMPINT::LE: + return CMPINT::NLE; + case CMPINT::FALSE: + return CMPINT::TRUE; + case CMPINT::NE: + return CMPINT::EQ; + case CMPINT::NLT: + return CMPINT::LT; + case CMPINT::NLE: + return CMPINT::LE; + case CMPINT::TRUE: + return CMPINT::FALSE; + default: + return CMPINT::FALSE; + } +} + +} // namespace hexl +} // namespace intel diff --git a/hexl_v3/lib/cmake/hexl-1.2.5/HEXLConfig.cmake b/hexl_v3/lib/cmake/hexl-1.2.5/HEXLConfig.cmake new file mode 100644 index 00000000..d3c012b5 --- /dev/null +++ b/hexl_v3/lib/cmake/hexl-1.2.5/HEXLConfig.cmake @@ -0,0 +1,59 @@ +# Copyright (C) 2020 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# This will define the following variables: +# +# HEXL_FOUND - True if the system has the Intel HEXL library +# HEXL_VERSION - The full major.minor.patch version number +# HEXL_VERSION_MAJOR - The major version number +# HEXL_VERSION_MINOR - The minor version number +# HEXL_VERSION_PATCH - The patch version number + + +####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() ####### +####### Any changes to this file will be overwritten by the next CMake run #### +####### The input file was HEXLConfig.cmake.in ######## + +get_filename_component(PACKAGE_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) + +macro(set_and_check _var _file) + set(${_var} "${_file}") + if(NOT EXISTS "${_file}") + message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !") + endif() +endmacro() + +macro(check_required_components _NAME) + foreach(comp ${${_NAME}_FIND_COMPONENTS}) + if(NOT ${_NAME}_${comp}_FOUND) + if(${_NAME}_FIND_REQUIRED_${comp}) + set(${_NAME}_FOUND FALSE) + endif() + endif() + endforeach() +endmacro() + +#################################################################################### + +include(CMakeFindDependencyMacro) +find_package(CpuFeatures CONFIG) +if(NOT CpuFeatures_FOUND) + message(WARNING "Could not find pre-installed CpuFeatures; 
using CpuFeatures packaged with HEXL") +endif() + +include(${CMAKE_CURRENT_LIST_DIR}/HEXLTargets.cmake) + +# Defines HEXL_FOUND: If Intel HEXL library was found +if(TARGET HEXL::hexl) + set(HEXL_FOUND TRUE) + message(STATUS "Intel HEXL found") +else() + message(STATUS "Intel HEXL not found") +endif() + +set(HEXL_VERSION "1.2.5") +set(HEXL_VERSION_MAJOR "1") +set(HEXL_VERSION_MINOR "2") +set(HEXL_VERSION_PATCH "5") + +set(HEXL_DEBUG "OFF") diff --git a/hexl_v3/lib/cmake/hexl-1.2.5/HEXLConfigVersion.cmake b/hexl_v3/lib/cmake/hexl-1.2.5/HEXLConfigVersion.cmake new file mode 100644 index 00000000..98b46110 --- /dev/null +++ b/hexl_v3/lib/cmake/hexl-1.2.5/HEXLConfigVersion.cmake @@ -0,0 +1,88 @@ +# This is a basic version file for the Config-mode of find_package(). +# It is used by write_basic_package_version_file() as input file for configure_file() +# to create a version-file which can be installed along a config.cmake file. +# +# The created file sets PACKAGE_VERSION_EXACT if the current version string and +# the requested version string are exactly the same and it sets +# PACKAGE_VERSION_COMPATIBLE if the current version is equal to the requested version. +# The tweak version component is ignored. +# The variable CVF_VERSION must be set before calling configure_file(). + + +if (PACKAGE_FIND_VERSION_RANGE) + message(AUTHOR_WARNING + "`find_package()` specify a version range but the version strategy " + "(ExactVersion) of the module `${PACKAGE_FIND_NAME}` is incompatible " + "with this request. Only the lower endpoint of the range will be used.") +endif() + +set(PACKAGE_VERSION "1.2.5") + +if("1.2.5" MATCHES "^([0-9]+)\\.([0-9]+)\\.([0-9]+)") # strip the tweak version + set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(CVF_VERSION_MINOR "${CMAKE_MATCH_2}") + set(CVF_VERSION_PATCH "${CMAKE_MATCH_3}") + + if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}") + endif() + if(NOT CVF_VERSION_MINOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" CVF_VERSION_MINOR "${CVF_VERSION_MINOR}") + endif() + if(NOT CVF_VERSION_PATCH VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" CVF_VERSION_PATCH "${CVF_VERSION_PATCH}") + endif() + + set(CVF_VERSION_NO_TWEAK "${CVF_VERSION_MAJOR}.${CVF_VERSION_MINOR}.${CVF_VERSION_PATCH}") +else() + set(CVF_VERSION_NO_TWEAK "1.2.5") +endif() + +if(PACKAGE_FIND_VERSION MATCHES "^([0-9]+)\\.([0-9]+)\\.([0-9]+)") # strip the tweak version + set(REQUESTED_VERSION_MAJOR "${CMAKE_MATCH_1}") + set(REQUESTED_VERSION_MINOR "${CMAKE_MATCH_2}") + set(REQUESTED_VERSION_PATCH "${CMAKE_MATCH_3}") + + if(NOT REQUESTED_VERSION_MAJOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" REQUESTED_VERSION_MAJOR "${REQUESTED_VERSION_MAJOR}") + endif() + if(NOT REQUESTED_VERSION_MINOR VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" REQUESTED_VERSION_MINOR "${REQUESTED_VERSION_MINOR}") + endif() + if(NOT REQUESTED_VERSION_PATCH VERSION_EQUAL 0) + string(REGEX REPLACE "^0+" "" REQUESTED_VERSION_PATCH "${REQUESTED_VERSION_PATCH}") + endif() + + set(REQUESTED_VERSION_NO_TWEAK + "${REQUESTED_VERSION_MAJOR}.${REQUESTED_VERSION_MINOR}.${REQUESTED_VERSION_PATCH}") +else() + set(REQUESTED_VERSION_NO_TWEAK "${PACKAGE_FIND_VERSION}") +endif() + +if(REQUESTED_VERSION_NO_TWEAK STREQUAL CVF_VERSION_NO_TWEAK) + set(PACKAGE_VERSION_COMPATIBLE TRUE) +else() + set(PACKAGE_VERSION_COMPATIBLE FALSE) +endif() + +if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION) + set(PACKAGE_VERSION_EXACT TRUE) +endif() + + +# if the installed project requested no 
architecture check, don't perform the check +if("FALSE") + return() +endif() + +# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it: +if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "") + return() +endif() + +# check that the installed version has the same 32/64bit-ness as the one which is currently searching: +if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8") + math(EXPR installedBits "8 * 8") + set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)") + set(PACKAGE_VERSION_UNSUITABLE TRUE) +endif() diff --git a/hexl_v3/lib/cmake/hexl-1.2.5/HEXLTargets-release.cmake b/hexl_v3/lib/cmake/hexl-1.2.5/HEXLTargets-release.cmake new file mode 100644 index 00000000..c736f1c1 --- /dev/null +++ b/hexl_v3/lib/cmake/hexl-1.2.5/HEXLTargets-release.cmake @@ -0,0 +1,19 @@ +#---------------------------------------------------------------- +# Generated CMake target import file for configuration "Release". +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Import target "HEXL::hexl" for configuration "Release" +set_property(TARGET HEXL::hexl APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE) +set_target_properties(HEXL::hexl PROPERTIES + IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libhexl.so.1.2.5" + IMPORTED_SONAME_RELEASE "libhexl.so.1.2.5" + ) + +list(APPEND _IMPORT_CHECK_TARGETS HEXL::hexl ) +list(APPEND _IMPORT_CHECK_FILES_FOR_HEXL::hexl "${_IMPORT_PREFIX}/lib/libhexl.so.1.2.5" ) + +# Commands beyond this point should not need to know the version. +set(CMAKE_IMPORT_FILE_VERSION) diff --git a/hexl_v3/lib/cmake/hexl-1.2.5/HEXLTargets.cmake b/hexl_v3/lib/cmake/hexl-1.2.5/HEXLTargets.cmake new file mode 100644 index 00000000..7132d40a --- /dev/null +++ b/hexl_v3/lib/cmake/hexl-1.2.5/HEXLTargets.cmake @@ -0,0 +1,100 @@ +# Generated by CMake + +if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.6) + message(FATAL_ERROR "CMake >= 2.6.0 required") +endif() +cmake_policy(PUSH) +cmake_policy(VERSION 2.6...3.20) +#---------------------------------------------------------------- +# Generated CMake target import file. +#---------------------------------------------------------------- + +# Commands may need to know the format version. +set(CMAKE_IMPORT_FILE_VERSION 1) + +# Protect against multiple inclusion, which would fail when already imported targets are added once more. +set(_targetsDefined) +set(_targetsNotDefined) +set(_expectedTargets) +foreach(_expectedTarget HEXL::hexl) + list(APPEND _expectedTargets ${_expectedTarget}) + if(NOT TARGET ${_expectedTarget}) + list(APPEND _targetsNotDefined ${_expectedTarget}) + endif() + if(TARGET ${_expectedTarget}) + list(APPEND _targetsDefined ${_expectedTarget}) + endif() +endforeach() +if("${_targetsDefined}" STREQUAL "${_expectedTargets}") + unset(_targetsDefined) + unset(_targetsNotDefined) + unset(_expectedTargets) + set(CMAKE_IMPORT_FILE_VERSION) + cmake_policy(POP) + return() +endif() +if(NOT "${_targetsDefined}" STREQUAL "") + message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_targetsDefined}\nTargets not yet defined: ${_targetsNotDefined}\n") +endif() +unset(_targetsDefined) +unset(_targetsNotDefined) +unset(_expectedTargets) + + +# Compute the installation prefix relative to this file. 
+get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH) +if(_IMPORT_PREFIX STREQUAL "/") + set(_IMPORT_PREFIX "") +endif() + +# Create imported target HEXL::hexl +add_library(HEXL::hexl SHARED IMPORTED) + +set_target_properties(HEXL::hexl PROPERTIES + INTERFACE_COMPILE_OPTIONS "-Wno-unknown-warning;-Wno-unknown-warning-option" + INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include" + INTERFACE_LINK_LIBRARIES "OpenMP::OpenMP_CXX" +) + +if(CMAKE_VERSION VERSION_LESS 2.8.12) + message(FATAL_ERROR "This file relies on consumers using CMake 2.8.12 or greater.") +endif() + +# Load information for each installed configuration. +get_filename_component(_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) +file(GLOB CONFIG_FILES "${_DIR}/HEXLTargets-*.cmake") +foreach(f ${CONFIG_FILES}) + include(${f}) +endforeach() + +# Cleanup temporary variables. +set(_IMPORT_PREFIX) + +# Loop over all imported files and verify that they actually exist +foreach(target ${_IMPORT_CHECK_TARGETS} ) + foreach(file ${_IMPORT_CHECK_FILES_FOR_${target}} ) + if(NOT EXISTS "${file}" ) + message(FATAL_ERROR "The imported target \"${target}\" references the file + \"${file}\" +but this file does not exist. Possible reasons include: +* The file was deleted, renamed, or moved to another location. +* An install or uninstall procedure did not complete successfully. +* The installation package was faulty and contained + \"${CMAKE_CURRENT_LIST_FILE}\" +but not all the files it references. +") + endif() + endforeach() + unset(_IMPORT_CHECK_FILES_FOR_${target}) +endforeach() +unset(_IMPORT_CHECK_TARGETS) + +# This file does not depend on other imported targets which have +# been exported from the same project but in a separate export set. + +# Commands beyond this point should not need to know the version. +set(CMAKE_IMPORT_FILE_VERSION) +cmake_policy(POP) diff --git a/hexl_v3/lib/libhexl.so b/hexl_v3/lib/libhexl.so new file mode 120000 index 00000000..af5173f3 --- /dev/null +++ b/hexl_v3/lib/libhexl.so @@ -0,0 +1 @@ +libhexl.so.1.2.5 \ No newline at end of file diff --git a/hexl_v3/lib/libhexl.so.1.2.5 b/hexl_v3/lib/libhexl.so.1.2.5 new file mode 100644 index 00000000..11a60620 Binary files /dev/null and b/hexl_v3/lib/libhexl.so.1.2.5 differ diff --git a/hexl_v3/lib/pkgconfig/hexl.pc b/hexl_v3/lib/pkgconfig/hexl.pc new file mode 100644 index 00000000..63164b5d --- /dev/null +++ b/hexl_v3/lib/pkgconfig/hexl.pc @@ -0,0 +1,13 @@ +# Copyright (C) 2020 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +prefix=/home/eidf018/eidf018/s1820742psd/hexl/hexl_v3 +libdir=${prefix}/lib +includedir=${prefix}/include + +Name: Intel HEXL +Version: 1.2.5 +Description: Intel® HEXL is an open-source library which provides efficient implementations of integer arithmetic on Galois fields. 
+ +Libs: -L${libdir} -lhexl +Cflags: -I${includedir} diff --git a/omp_example/cmake/CMakeLists.txt b/omp_example/cmake/CMakeLists.txt new file mode 100644 index 00000000..141e9cbd --- /dev/null +++ b/omp_example/cmake/CMakeLists.txt @@ -0,0 +1,34 @@ +# Copyright (C) 2020 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +project(hexl_omp_example LANGUAGES C CXX) +cmake_minimum_required(VERSION 3.13) +set(CMAKE_CXX_STANDARD 17) + +# Define the directory containing HEXLConfig.cmake +set(HEXL_HINT_DIR "/home/eidf018/eidf018/s1820742psd/hexl/hexl_v2") + +# Example using source +find_package(HEXL 1.2.5 + HINTS ${HEXL_HINT_DIR} + REQUIRED) +if (NOT TARGET HEXL::hexl) + message(FATAL_ERROR "TARGET HEXL::hexl not found") +endif() + +find_package(OpenMP) +if (OpenMP_FOUND) + message(STATUS "OpenMP_CXX_INCLUDE_DIRS: ${OpenMP_CXX_INCLUDE_DIRS}") + message(STATUS "OpenMP_CXX_LIBRARIES: ${OpenMP_CXX_LIBRARIES}") +endif() + +# Add the directory containing libeasyloggingpp.a to the link directories +link_directories(/home/eidf018/eidf018/s1820742psd/hexl/hexl_v2/lib) + + +add_executable(omp_example ../omp_example.cpp) +add_executable(function_test ../test_example.cpp) +target_link_libraries(omp_example PRIVATE HEXL::hexl) +target_link_libraries(function_test PRIVATE HEXL::hexl) + + diff --git a/omp_example/omp_example.cpp b/omp_example/omp_example.cpp new file mode 100644 index 00000000..623e1cf1 --- /dev/null +++ b/omp_example/omp_example.cpp @@ -0,0 +1,321 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../hexl/include/hexl/hexl.hpp" +#include "../hexl/include/hexl/util/util.hpp" +#include "../hexl/util/util-internal.hpp" +// #include "../hexl/include/hexl/experimental/fft-like/fft-like.hpp" +// #include "../hexl/include/hexl/experimental/fft-like/fft-like-native.hpp" + + +template +double TimeFunction(Func&& f) { + auto start_time = std::chrono::high_resolution_clock::now(); + f(); + auto end_time = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end_time - start_time; + return duration.count(); +} + +std::vector split(const std::string& s, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(s); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(std::stoi(token)); + } + return tokens; +} + +double BM_EltwiseVectorVectorAddMod(size_t input_size) { + uint64_t modulus = 0xffffffffffc0001ULL; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + auto input2 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + intel::hexl::AlignedVector64 output(input_size, 0); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseAddMod(output.data(), input1.data(), input2.data(), + input_size, modulus); + }); + + return time_taken; +} + +double BM_EltwiseVectorScalarAddMod(size_t input_size) { + uint64_t modulus = 0xffffffffffc0001ULL; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + uint64_t input2 = + intel::hexl::GenerateInsecureUniformIntRandomValue(0, modulus); + intel::hexl::AlignedVector64 output(input_size, 0); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseAddMod(output.data(), input1.data(), input2, input_size, + modulus); + }); + + return time_taken; +} + +double BM_EltwiseCmpAdd(size_t input_size, intel::hexl::CMPINT chosenCMP) { + uint64_t modulus = 100; + + uint64_t 
bound = + intel::hexl::GenerateInsecureUniformIntRandomValue(0, modulus); + uint64_t diff = + intel::hexl::GenerateInsecureUniformIntRandomValue(1, modulus - 1); + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + // intel::hexl::AlignedVector64 output(input_size, 0); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseCmpAdd(input1.data(), input1.data(), input_size, + chosenCMP, bound, diff); + }); + + return time_taken; +} + +double BM_EltwiseCmpSubMod( + size_t input_size, intel::hexl::CMPINT chosenCMP) { + uint64_t modulus = 100; + + uint64_t bound = + intel::hexl::GenerateInsecureUniformIntRandomValue(1, modulus); + uint64_t diff = + intel::hexl::GenerateInsecureUniformIntRandomValue(1, modulus); + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseCmpSubMod(input1.data(), input1.data(), input_size, + modulus, chosenCMP, bound, diff); + }); + + return time_taken; +} + +double BM_EltwiseFMAModAdd(size_t input_size, bool add) { + uint64_t modulus = 0xffffffffffc0001ULL; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + uint64_t input2 = + intel::hexl::GenerateInsecureUniformIntRandomValue(0, modulus); + intel::hexl::AlignedVector64 input3 = + intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, 0, + modulus); + uint64_t* arg3 = add ? input3.data() : nullptr; + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseFMAMod(input1.data(), input1.data(), input2, arg3, + input1.size(), modulus, 1); + }); + return time_taken; +} + +double BM_EltwiseMultMod(size_t input_size, size_t bit_width, + size_t input_mod_factor) { + // uint64_t modulus = (1ULL << bit_width) + 7; + uint64_t modulus = 0xffffffffffc0001ULL; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + auto input2 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + intel::hexl::AlignedVector64 output(input_size, 2); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseMultMod(output.data(), input1.data(), input2.data(), + input_size, modulus, input_mod_factor); + }); + return time_taken; +} + +// double BM_EltwiseReduceModInPlace(size_t input_size) { +// uint64_t modulus = 0xffffffffffc0001ULL; + +// auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues( +// input_size, 0, 100 * modulus); + +// const uint64_t input_mod_factor = modulus; +// const uint64_t output_mod_factor = 1; + +// double time_taken = TimeFunction([&]() { +// intel::hexl::EltwiseReduceMod(input1.data(), input1.data(), input_size, +// modulus, input_mod_factor, output_mod_factor); +// }); +// return time_taken; +// } +double BM_EltwiseReduceMod(size_t input_size) { + uint64_t modulus = 0xffffffffffc0001ULL; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues( + input_size, 0, 100 * modulus); + intel::hexl::AlignedVector64 output(input_size, 0); + + const uint64_t input_mod_factor = modulus; + const uint64_t output_mod_factor = 1; + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseReduceMod(output.data(), input1.data(), input_size, + modulus, input_mod_factor, output_mod_factor); + }); + return time_taken; +} + +double BM_EltwiseVectorVectorSubMod(size_t input_size) { + uint64_t modulus = 0xffffffffffc0001ULL; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, 
modulus); + auto input2 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + intel::hexl::AlignedVector64 output(input_size, 0); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseSubMod(output.data(), input1.data(), input2.data(), + input_size, modulus); + }); + return time_taken; +} + +double BM_NTTInPlace(size_t ntt_size) { + size_t modulus = intel::hexl::GeneratePrimes(1, 45, true, ntt_size)[0]; + + auto input = + intel::hexl::GenerateInsecureUniformIntRandomValues(ntt_size, 0, modulus); + intel::hexl::NTT ntt(ntt_size, modulus); + + double time_taken = TimeFunction([&]() { + ntt.ComputeForward(input.data(), input.data(), 1, 1); + }) + + TimeFunction([&]() { + ntt.ComputeInverse(input.data(), input.data(), 2, 1); + }); + return time_taken; +} + +int main(int argc, char** argv) { + if (argc != 4) { + std::cerr << "Usage: " << argv[0] + << " " + << std::endl; + return 1; + } + + int num_iterations = std::stoi(argv[1]); + std::vector thread_nums = split(argv[2], ','); + int input_size = std::stoi(argv[3]); + + + // Using a map to store the results + std::map> results; + + // Initialize the results map + results["BM_EltwiseVectorVectorAddMod"] = + std::vector(thread_nums.size(), 0.0); + results["BM_EltwiseVectorScalarAddMod"] = + std::vector(thread_nums.size(), 0.0); + results["BM_EltwiseCmpAdd"] = + std::vector(thread_nums.size(), 0.0); + results["BM_EltwiseCmpSubMod"] = std::vector(thread_nums.size(), 0.0); + results["BM_EltwiseFMAModAdd"] = + std::vector(thread_nums.size(), 0.0); + results["BM_EltwiseMultMod"] = std::vector(thread_nums.size(), 0.0); + results["BM_EltwiseReduceModInPlace"] = + std::vector(thread_nums.size(), 0.0); + results["BM_EltwiseVectorVectorSubMod"] = + std::vector(thread_nums.size(), 0.0); + results["BM_NTTInPlace"] = std::vector(thread_nums.size(), 0.0); + + // Execute each method for all thread numbers + for (size_t j = 0; j < thread_nums.size(); ++j) { + int num_threads = thread_nums[j]; + omp_set_num_threads(num_threads); + + results["BM_EltwiseVectorVectorAddMod"][j] = 0; + results["BM_EltwiseVectorScalarAddMod"][j] = 0; + results["BM_EltwiseCmpAdd"][j] = 0; + results["BM_EltwiseCmpSubMod"][j] = 0; + results["BM_EltwiseFMAModAdd"][j] = 0; + results["BM_EltwiseMultMod"][j] = 0; + results["BM_EltwiseReduceModInPlace"][j] = 0; + results["BM_EltwiseVectorVectorSubMod"][j] = 0; + results["BM_NTTInPlace"][j] = 0; + + bool add_choices[] = {false,true}; + int bit_width_choices[] = {48, 60}; + int mod_factor_choices[] = {1, 2, 4}; + + for (int i = 0; i < num_iterations; i++) { + // There are CMPINT possibilities, should be chosen randomly for testing + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis_cmpint(0, 7); // 8 enum values, from 0 to 7 + std::uniform_int_distribution<> dis_add(0, 1); + std::uniform_int_distribution<> dis_factor(0, 2); + + intel::hexl::CMPINT chosenCMP = + static_cast(dis_cmpint(gen)); + + bool add = add_choices[dis_add(gen)]; + + size_t bit_width = bit_width_choices[dis_add(gen)]; + + size_t input_mod_factor = mod_factor_choices[dis_factor(gen)]; + + results["BM_EltwiseVectorVectorAddMod"][j] += + BM_EltwiseVectorVectorAddMod(input_size); + results["BM_EltwiseVectorScalarAddMod"][j] += + BM_EltwiseVectorScalarAddMod(input_size); + results["BM_EltwiseCmpAdd"][j] += + BM_EltwiseCmpAdd(input_size, chosenCMP); + results["BM_EltwiseCmpSubMod"][j] += + BM_EltwiseCmpSubMod(input_size, chosenCMP); + results["BM_EltwiseFMAModAdd"][j] += + BM_EltwiseFMAModAdd(input_size, 
add); + results["BM_EltwiseMultMod"][j] += + BM_EltwiseMultMod(input_size, bit_width, input_mod_factor); + results["BM_EltwiseReduceModInPlace"][j] += + BM_EltwiseReduceMod(input_size); + results["BM_EltwiseVectorVectorSubMod"][j] += + BM_EltwiseVectorVectorSubMod(input_size); + results["BM_NTTInPlace"][j] += BM_NTTInPlace(input_size/4096); + } + } + + // Print the table + + // Print headers + std::cout << std::left << std::setw(40) << "Method"; + for (int num_threads : thread_nums) { + std::cout << std::setw(20) << ("Threads=" + std::to_string(num_threads)); + } + std::cout << std::endl; + + // Print results + for (auto& [method, times] : results) { + std::cout << std::left << std::setw(40) << method; + for (double time : times) { + std::cout << std::setw(20) << time/num_iterations; + } + std::cout << std::endl; + } + + return 0; +} diff --git a/omp_example/test_example.cpp b/omp_example/test_example.cpp new file mode 100644 index 00000000..89380764 --- /dev/null +++ b/omp_example/test_example.cpp @@ -0,0 +1,312 @@ +// Copyright (C) 2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include + +#include "hexl/hexl.hpp" +#include "../hexl/util/util-internal.hpp" + +#include +#include +#include +#include +#include + + +template +double TimeFunction(Func&& f) { + auto start_time = std::chrono::high_resolution_clock::now(); + f(); + auto end_time = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end_time - start_time; + return duration.count(); +} + +std::vector split(const std::string& s, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(s); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(std::stoi(token)); + } + return tokens; +} + +bool CheckEqual(const std::vector& x, + const std::vector& y) { + if (x.size() != y.size()) { + std::cout << "Not equal in size\n"; + return false; + } + uint64_t N = x.size(); + bool is_match = true; + for (size_t i = 0; i < N; ++i) { + if (x[i] != y[i]) { + std::cout << "Not equal at index " << i << "\n"; + is_match = false; + } + } + return is_match; +} + +double ExampleEltwiseVectorVectorAddMod() { +// std::cout << "Running ExampleEltwiseVectorVectorAddMod...\n"; + + std::vector op1{1, 2, 3, 4, 5, 6, 7, 8}; + std::vector op2{1, 3, 5, 7, 2, 4, 6, 8}; + uint64_t modulus = 10; + std::vector exp_out{2, 5, 8, 1, 7, 0, 3, 6}; + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseAddMod(op1.data(), op1.data(), op2.data(), op1.size(), modulus); + }); + + CheckEqual(op1, exp_out); +// std::cout << "Done running ExampleEltwiseVectorVectorAddMod\n"; + return time_taken; +} + +double ExampleEltwiseVectorScalarAddMod() { +// std::cout << "Running ExampleEltwiseVectorScalarAddMod...\n"; + + std::vector op1{1, 2, 3, 4, 5, 6, 7, 8}; + uint64_t op2{3}; + uint64_t modulus = 10; + std::vector exp_out{4, 5, 6, 7, 8, 9, 0, 1}; + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseAddMod(op1.data(), op1.data(), op2, op1.size(), modulus); + }); + + CheckEqual(op1, exp_out); +// std::cout << "Done running ExampleEltwiseVectorScalarAddMod\n"; + return time_taken; +} + +double ExampleEltwiseCmpAdd() { +// std::cout << "Running ExampleEltwiseCmpAdd...\n"; + + std::vector op1{1, 2, 3, 4, 5, 6, 7, 8}; + uint64_t cmp = 3; + uint64_t diff = 5; + std::vector exp_out{1, 2, 3, 9, 10, 11, 12, 13}; + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseCmpAdd(op1.data(), op1.data(), op1.size(), intel::hexl::CMPINT::NLE, 
cmp, diff); + }); + + + CheckEqual(op1, exp_out); +// std::cout << "Done running ExampleEltwiseCmpAdd\n"; + return time_taken; +} + +double ExampleEltwiseCmpSubMod() { +// std::cout << "Running ExampleEltwiseCmpSubMod...\n"; + + std::vector op1{1, 2, 3, 4, 5, 6, 7}; + uint64_t bound = 4; + uint64_t diff = 5; + std::vector exp_out{1, 2, 3, 4, 0, 1, 2}; + + uint64_t modulus = 10; + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseCmpSubMod(op1.data(), op1.data(), op1.size(), modulus, intel::hexl::CMPINT::NLE, bound, diff); + }); + + CheckEqual(op1, exp_out); +// std::cout << "Done running ExampleEltwiseCmpSubMod\n"; + return time_taken; +} + +double ExampleEltwiseFMAMod() { +// std::cout << "Running ExampleEltwiseFMAMod...\n"; + + std::vector arg1{1, 2, 3, 4, 5, 6, 7, 8, 9}; + uint64_t arg2 = 1; + std::vector exp_out{1, 2, 3, 4, 5, 6, 7, 8, 9}; + uint64_t modulus = 769; + + double time_taken; + time_taken = TimeFunction([&]() { + intel::hexl::EltwiseFMAMod(arg1.data(), arg1.data(), arg2, nullptr, arg1.size(), modulus, 1); + }); + + CheckEqual(arg1, exp_out); +// std::cout << "Done running ExampleEltwiseFMAMod\n"; + return time_taken; +} + +double ExampleEltwiseMultMod() { +// std::cout << "Running ExampleEltwiseMultMod...\n"; + + std::vector op1{2, 4, 3, 2}; + std::vector op2{2, 1, 2, 0}; + std::vector exp_out{4, 4, 6, 0}; + + uint64_t modulus = 769; + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseMultMod(op1.data(), op1.data(), op2.data(), op1.size(), modulus, 1); + }); + + CheckEqual(op1, exp_out); +// std::cout << "Done running ExampleEltwiseMultMod\n"; + return time_taken; +} + +double ExampleNTT() { +// std::cout << "Running ExampleNTT...\n"; + + uint64_t N = 8; + uint64_t modulus = 769; + std::vector arg{1, 2, 3, 4, 5, 6, 7, 8}; + auto exp_out = arg; + intel::hexl::NTT ntt(N, modulus); + + double time_taken = TimeFunction([&]() { + ntt.ComputeForward(arg.data(), arg.data(), 1, 1); + }); + time_taken += TimeFunction([&]() { + ntt.ComputeInverse(arg.data(), arg.data(), 1, 1); + }); + + CheckEqual(arg, exp_out); +// std::cout << "Done running ExampleNTT\n"; + return time_taken; +} + +double ExampleEltwiseReduceMod() { +// std::cout << "Running ExampleReduceMod...\n"; + + uint64_t modulus = 5; + std::vector arg{1, 2, 3, 4, 5, 6, 7, 8}; + std::vector exp_out{1, 2, 3, 4, 0, 1, 2, 3}; + std::vector result{0, 0, 0, 0, 0, 0, 0, 0}; + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseReduceMod(result.data(), arg.data(), arg.size(), modulus,2, 1); + }); + + CheckEqual(result, exp_out); +// std::cout << "Done running ExampleReduceMod\n"; + return time_taken; +} + +// double BM_EltwiseVectorVectorAddMod(size_t input_size) { // NOLINT +// uint64_t modulus = 0xffffffffffc0001ULL; + +// auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, +// 0, modulus); +// auto input2 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, +// 0, modulus); +// intel::hexl::AlignedVector64 output(input_size, 0); + +// double time_taken = TimeFunction([&]() { +// intel::hexl::EltwiseAddMod(output.data(), input1.data(), input2.data(), input_size, modulus); +// }); + +// return time_taken; +// } + +// double BM_EltwiseVectorScalarAddMod(size_t input_size) { +// uint64_t modulus = 0xffffffffffc0001ULL; + +// auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, 0, modulus); +// uint64_t input2 = intel::hexl::GenerateInsecureUniformIntRandomValue(0, modulus); +// intel::hexl::AlignedVector64 output(input_size, 0); 
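
The driver `main` below sweeps thread counts by calling `omp_set_num_threads` before timing each operation. A minimal standalone sketch of that mechanism (assuming compilation with `-fopenmp`; not part of the test file itself) shows that the setting only takes effect for parallel regions created after the call:

```cpp
#include <omp.h>

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> thread_nums = {1, 2, 4, 8};
  for (int n : thread_nums) {
    omp_set_num_threads(n);  // applies to subsequently created parallel regions
    int observed = 0;
#pragma omp parallel
    {
#pragma omp single
      observed = omp_get_num_threads();  // threads actually granted by the runtime
    }
    std::printf("requested %d threads, got %d\n", n, observed);
  }
  return 0;
}
```
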
+ +// double time_taken = TimeFunction([&]() { +// intel::hexl::EltwiseAddMod(output.data(), input1.data(), input2, input_size, modulus); +// }); + +// return time_taken; +// } + + +int main(int argc, char** argv) { + if (argc != 4) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return 1; + } + + int num_iterations = std::stoi(argv[1]); + std::vector thread_nums = split(argv[2], ','); + int input_size = std::stoi(argv[3]); + + for (size_t j = 0; j < thread_nums.size(); ++j) { + int num_threads = thread_nums[j]; + omp_set_num_threads(num_threads); + + double t_EltwiseVectorVectorAddMod = 0.0; + double t_EltwiseVectorScalarAddMod = 0.0; + double t_EltwiseCmpAdd = 0.0; + double t_EltwiseCmpSubMod = 0.0; + double t_EltwiseFMAMod = 0.0; + double t_EltwiseMultMod = 0.0; + double t_NTT = 0.0; + double t_EltwiseReduceMod = 0.0; + // double t_BM_EltwiseVectorVectorAddMod = 0.0; + // double t_BM_EltwiseVectorScalarAddMod = 0.0; + + // Rest of your code + for (int i = 0; i < num_iterations; i++) { + + // computation code for each iteration + t_EltwiseVectorVectorAddMod += ExampleEltwiseVectorVectorAddMod(); + t_EltwiseVectorScalarAddMod += ExampleEltwiseVectorScalarAddMod(); + t_EltwiseCmpAdd += ExampleEltwiseCmpAdd(); + t_EltwiseCmpSubMod += ExampleEltwiseCmpSubMod(); + t_EltwiseFMAMod += ExampleEltwiseFMAMod(); + t_EltwiseMultMod += ExampleEltwiseMultMod(); + t_NTT += ExampleNTT(); + t_EltwiseReduceMod += ExampleEltwiseReduceMod(); + // t_BM_EltwiseVectorVectorAddMod += + // BM_EltwiseVectorVectorAddMod(input_size); + // t_BM_EltwiseVectorScalarAddMod += + // BM_EltwiseVectorScalarAddMod(input_size); + } + + std::cout << "Thread Number: " << num_threads + << " after itertating: " << num_iterations << " times "<< std::endl; + + // Calculate the width for columns based on the longest label + int labelWidth = 40; + int timeWidth = 20; + + std::cout << std::left << std::setw(labelWidth) << "Operation" + << std::setw(timeWidth) << "Time (ms)" << std::endl; + std::cout << std::setfill('-') << std::setw(labelWidth + timeWidth) + << "-" << std::setfill(' ') << std::endl; + std::cout << std::left << std::setw(labelWidth) + << "EltwiseVectorVectorAddMod" << std::setw(timeWidth) + << t_EltwiseVectorVectorAddMod << std::endl; + std::cout << std::left << std::setw(labelWidth) + << "EltwiseVectorScalarAddMod" << std::setw(timeWidth) + << t_EltwiseVectorScalarAddMod << std::endl; + std::cout << std::left << std::setw(labelWidth) << "EltwiseCmpAdd" + << std::setw(timeWidth) << t_EltwiseCmpAdd << std::endl; + std::cout << std::left << std::setw(labelWidth) << "EltwiseCmpSubMod" + << std::setw(timeWidth) << t_EltwiseCmpSubMod << std::endl; + std::cout << std::left << std::setw(labelWidth) << "EltwiseFMAMod" + << std::setw(timeWidth) << t_EltwiseFMAMod << std::endl; + std::cout << std::left << std::setw(labelWidth) << "EltwiseMultMod" + << std::setw(timeWidth) << t_EltwiseMultMod << std::endl; + std::cout << std::left << std::setw(labelWidth) << "EltwiseReduceMod" + << std::setw(timeWidth) << t_EltwiseReduceMod << std::endl; + std::cout << std::left << std::setw(labelWidth) << "NTT" + << std::setw(timeWidth) << t_NTT << std::endl; + // std::cout << std::left << std::setw(labelWidth) << + // "BM_EltwiseVectorVectorAddMod" << std::setw(timeWidth) << + // t_BM_EltwiseVectorVectorAddMod << std::endl; std::cout << std::left + // << std::setw(labelWidth) << "BM_EltwiseVectorScalarAddMod" << + // std::setw(timeWidth) << t_BM_EltwiseVectorScalarAddMod << std::endl; + std::cout << 
"******************************************************" << std::endl; + } + + return 0; + } diff --git a/omp_time.csv b/omp_time.csv new file mode 100644 index 00000000..19bbffec --- /dev/null +++ b/omp_time.csv @@ -0,0 +1,21 @@ +Method_Number,Thread_Count,Input_Size,Average_Elapsed_Time +4,1,4096,.00002207000000000000 +4,2,4096,.00001539000000000000 +4,4,4096,.00006322000000000000 +4,6,4096,.00001319000000000000 +4,8,4096,.00001400000000000000 +4,1,65536,.00029930000000000000 +4,2,65536,.00022425000000000000 +4,4,65536,.00016551000000000000 +4,6,65536,.00018119000000000000 +4,8,65536,.00021523000000000000 +4,1,1048576,.00358277000000000000 +4,2,1048576,.00289902000000000000 +4,4,1048576,.00224447000000000000 +4,6,1048576,.00172955000000000000 +4,8,1048576,.00148815000000000000 +4,1,16777216,.05797288000000000000 +4,2,16777216,.05529537000000000000 +4,4,16777216,.03163143000000000000 +4,6,16777216,.02106166000000000000 +4,8,16777216,.01790957000000000000 diff --git a/time_example/cmake/CMakeLists.txt b/time_example/cmake/CMakeLists.txt new file mode 100644 index 00000000..4d5d9118 --- /dev/null +++ b/time_example/cmake/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (C) 2020 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +project(hexl_omp_example LANGUAGES C CXX) +cmake_minimum_required(VERSION 3.13) +set(CMAKE_CXX_STANDARD 17) + +# Define the directory containing HEXLConfig.cmake +set(HEXL_HINT_DIR "/home/eidf018/eidf018/s1820742psd/hexl/hexl_v3") + +# Example using source +find_package(HEXL 1.2.5 + HINTS ${HEXL_HINT_DIR} + REQUIRED) +if (NOT TARGET HEXL::hexl) + message(FATAL_ERROR "TARGET HEXL::hexl not found") +endif() + +find_package(OpenMP) +if (OpenMP_FOUND) + message(STATUS "OpenMP_CXX_INCLUDE_DIRS: ${OpenMP_CXX_INCLUDE_DIRS}") + message(STATUS "OpenMP_CXX_LIBRARIES: ${OpenMP_CXX_LIBRARIES}") +endif() + +# Add the directory containing libeasyloggingpp.a to the link directories +link_directories(/home/eidf018/eidf018/s1820742psd/hexl/hexl_v3/lib) + + +add_executable(time_example ../time_example.cpp) +target_link_libraries(time_example PRIVATE HEXL::hexl) + + diff --git a/time_example/time_example.cpp b/time_example/time_example.cpp new file mode 100644 index 00000000..24bf3465 --- /dev/null +++ b/time_example/time_example.cpp @@ -0,0 +1,392 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../hexl/include/hexl/hexl.hpp" +#include "../hexl/include/hexl/util/util.hpp" +#include "../hexl/util/util-internal.hpp" +// #include "../hexl/include/hexl/experimental/fft-like/fft-like.hpp" +// #include "../hexl/include/hexl/experimental/fft-like/fft-like-native.hpp" + + +template +double TimeFunction(Func&& f) { + auto start_time = std::chrono::high_resolution_clock::now(); + f(); + auto end_time = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end_time - start_time; + return duration.count(); +} + +std::vector split(const std::string& s, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(s); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(std::stoi(token)); + } + return tokens; +} + +double BM_EltwiseVectorVectorAddMod(size_t input_size) { + uint64_t modulus = 0xffffffffffc0001ULL; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + auto input2 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + 
intel::hexl::AlignedVector64 output(input_size, 0); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseAddMod(output.data(), input1.data(), input2.data(), + input_size, modulus); + }); + + return time_taken; +} + +double BM_EltwiseVectorScalarAddMod(size_t input_size) { + uint64_t modulus = 0xffffffffffc0001ULL; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + uint64_t input2 = + intel::hexl::GenerateInsecureUniformIntRandomValue(0, modulus); + intel::hexl::AlignedVector64 output(input_size, 0); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseAddMod(output.data(), input1.data(), input2, input_size, + modulus); + }); + + return time_taken; +} + +double BM_EltwiseCmpAdd(size_t input_size, intel::hexl::CMPINT chosenCMP) { + uint64_t modulus = 100; + + uint64_t bound = + intel::hexl::GenerateInsecureUniformIntRandomValue(0, modulus); + uint64_t diff = + intel::hexl::GenerateInsecureUniformIntRandomValue(1, modulus - 1); + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseCmpAdd(input1.data(), input1.data(), input_size, + chosenCMP, bound, diff); + }); + + return time_taken; +} + +double BM_EltwiseCmpSubMod( + size_t input_size, intel::hexl::CMPINT chosenCMP) { + uint64_t modulus = 100; + + uint64_t bound = + intel::hexl::GenerateInsecureUniformIntRandomValue(1, modulus); + uint64_t diff = + intel::hexl::GenerateInsecureUniformIntRandomValue(1, modulus); + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseCmpSubMod(input1.data(), input1.data(), input_size, + modulus, chosenCMP, bound, diff); + }); + + return time_taken; +} + +double BM_EltwiseFMAModAdd(size_t input_size, bool add) { + uint64_t modulus = 0xffffffffffc0001ULL; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + uint64_t input2 = + intel::hexl::GenerateInsecureUniformIntRandomValue(0, modulus); + intel::hexl::AlignedVector64 input3 = + intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, 0, + modulus); + uint64_t* arg3 = add ? 
input3.data() : nullptr; + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseFMAMod(input1.data(), input1.data(), input2, arg3, + input1.size(), modulus, 1); + }); + return time_taken; +} + +double BM_EltwiseMultMod(size_t input_size, size_t bit_width, + size_t input_mod_factor) { + uint64_t modulus = (1ULL << bit_width) + 7; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + auto input2 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + intel::hexl::AlignedVector64 output(input_size, 2); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseMultMod(output.data(), input1.data(), input2.data(), + input_size, modulus, input_mod_factor); + }); + return time_taken; +} + +// double BM_EltwiseReduceModInPlace(size_t input_size) { +// uint64_t modulus = 0xffffffffffc0001ULL; + +// auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues( +// input_size, 0, 100 * modulus); + +// const uint64_t input_mod_factor = modulus; +// const uint64_t output_mod_factor = 1; + +// double time_taken = TimeFunction([&]() { +// intel::hexl::EltwiseReduceMod(input1.data(), input1.data(), input_size, +// modulus, input_mod_factor, output_mod_factor); +// }); +// return time_taken; +// } +double BM_EltwiseReduceModInPlace(size_t input_size) { + uint64_t modulus = 0xffffffffffc0001ULL; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues( + input_size, 0, 100 * modulus); + intel::hexl::AlignedVector64 output(input_size, 0); + + const uint64_t input_mod_factor = modulus; + const uint64_t output_mod_factor = 1; + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseReduceMod(output.data(), input1.data(), input_size, + modulus, input_mod_factor, output_mod_factor); + }); + return time_taken; +} + +double BM_EltwiseVectorVectorSubMod(size_t input_size) { + uint64_t modulus = 0xffffffffffc0001ULL; + + auto input1 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + auto input2 = intel::hexl::GenerateInsecureUniformIntRandomValues(input_size, + 0, modulus); + intel::hexl::AlignedVector64 output(input_size, 0); + + double time_taken = TimeFunction([&]() { + intel::hexl::EltwiseSubMod(output.data(), input1.data(), input2.data(), + input_size, modulus); + }); + return time_taken; +} + +double BM_NTTInPlace(size_t ntt_size) { + size_t modulus = intel::hexl::GeneratePrimes(1, 45, true, ntt_size)[0]; + + auto input = + intel::hexl::GenerateInsecureUniformIntRandomValues(ntt_size, 0, modulus); + intel::hexl::NTT ntt(ntt_size, modulus); + + double time_taken = TimeFunction([&]() { + ntt.ComputeForward(input.data(), input.data(), 1, 1); + }) + + TimeFunction([&]() { + ntt.ComputeInverse(input.data(), input.data(), 2, 1); + }); + return time_taken; +} + +int main(int argc, char** argv) { + if (argc != 5) { + std::cerr << "Usage: " << argv[0] + << " " + << std::endl; + return 1; + } + + int num_iterations = std::stoi(argv[1]); + std::vector thread_nums = split(argv[2], ','); + int input_size = std::stoi(argv[3]); + int method_number = std::stoi(argv[4]); + std::string method; + if(method_number>=9 || method_number<0){ + std::cerr + << "Method Number range from 0 to 8" << "0--EltwiseVectorVectorAddMod, " + << "1 --EltwiseVectorScalarAddMod, " << "2 --EltwiseCmpAdd, " + << "3 --EltwiseCmpSubMod, " << "4 --EltwiseFMAModAdd, " + << "5 --EltwiseMultMod, " << "6 --EltwiseReduceModInPlace, " + << "7 --EltwiseVectorVectorSubMod, " + << "8 --NTTInPlace, " + << std::endl; 
+ return 1; + } + + // Using a map to store the results + std::map results; + // std::function operation; + // Initialize the results map + + switch (method_number) { + case 0: + method = "BM_EltwiseVectorVectorAddMod"; + break; + + case 1: + method = "BM_EltwiseVectorScalarAddMod"; + break; + + case 2: + method = "BM_EltwiseCmpAdd"; + break; + + case 3: + method = "BM_EltwiseCmpSubMod"; + break; + + case 4: + method = "BM_EltwiseFMAModAdd"; + break; + + case 5: + method = "BM_EltwiseMultMod"; + // std::cout << "BM_EltwiseMultMod" << std::endl; + break; + + case 6: + method = "BM_EltwiseReduceModInPlace"; + // std::cout << "BM_EltwiseReduceModInPlace" << std::endl; + break; + + case 7: + method = "BM_EltwiseVectorVectorSubMod"; + // std::cout << "BM_EltwiseVectorVectorSubMod" << std::endl; + break; + + case 8: + method = "BM_NTTInPlace"; + // std::cout << "BM_NTTInPlace" << std::endl; + break; + + default: + method = "BM_EltwiseVectorVectorAddMod"; + // std::cout << "Default" << std::endl; + break; + } + + // Execute each method for all thread numbers + for (size_t j = 0; j < thread_nums.size(); ++j) { + int num_threads = thread_nums[j]; + omp_set_num_threads(num_threads); + + // Initialize the results map + results[num_threads] = 0.0; + + bool add_choices[] = {false,true}; + int bit_width_choices[] = {48, 60}; + int mod_factor_choices[] = {1, 2, 4}; + + for (int i = 0; i < num_iterations; i++) { + // There are CMPINT possibilities, should be chosen randomly for testing + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis_cmpint(0, 7); // 8 enum values, from 0 to 7 + std::uniform_int_distribution<> dis_add(0, 1); + std::uniform_int_distribution<> dis_factor(0, 2); + + intel::hexl::CMPINT chosenCMP = + static_cast(dis_cmpint(gen)); + + bool add = add_choices[dis_add(gen)]; + + size_t bit_width = bit_width_choices[dis_add(gen)]; + + size_t input_mod_factor = mod_factor_choices[dis_factor(gen)]; + + if(method_number == 0 || method_number == 1 || method_number == 6 || method_number == 7 || method_number == 8){ + std::function operation; + switch (method_number) { + case 0: + operation = BM_EltwiseVectorVectorAddMod; + // std::cout << "BM_EltwiseVectorVectorAddMod" << std::endl; + break; + + case 1: + operation = BM_EltwiseVectorScalarAddMod; + // std::cout << "BM_EltwiseVectorScalarAddMod" << std::endl; + break; + + case 6: + operation = BM_EltwiseReduceModInPlace; + // std::cout << "BM_EltwiseReduceModInPlace" << std::endl; + break; + + case 7: + operation = BM_EltwiseVectorVectorSubMod; + // std::cout << "BM_EltwiseVectorVectorSubMod" << std::endl; + break; + + case 8: + operation = BM_NTTInPlace; + input_size = input_size/4096; + // std::cout << "BM_NTTInPlace" << std::endl; + break; + + default: + operation = BM_EltwiseVectorVectorAddMod; + // std::cout << "Default" << std::endl; + break; + } + results[num_threads] += operation(input_size); + } + + else if(method_number == 2 || method_number == 3 ){ + std::function operation; + switch (method_number) { + case 2: + operation = BM_EltwiseCmpAdd; + // std::cout << "BM_EltwiseCmpAdd" << std::endl; + break; + + case 3: + operation = BM_EltwiseCmpSubMod; + // std::cout << "BM_EltwiseCmpSubMod" << std::endl; + break; + + default: + operation = BM_EltwiseCmpAdd; + // std::cout << "Default" << std::endl; + break; + } + + results[num_threads] += operation(input_size, chosenCMP); + } + else if(method_number == 4){ + std::function operation = BM_EltwiseFMAModAdd; + // std::cout << "BM_EltwiseFMAModAdd" << std::endl; 
+        results[num_threads] += operation(input_size, add);
+      } else if (method_number == 5) {
+        std::function<double(size_t, size_t, size_t)> operation = BM_EltwiseMultMod;
+        results[num_threads] +=
+            operation(input_size, bit_width, input_mod_factor);
+      }
+    }
+  }
+
+  return 0;
+}
diff --git a/visualisor_omp.py b/visualisor_omp.py
new file mode 100644
index 00000000..70db1aad
--- /dev/null
+++ b/visualisor_omp.py
@@ -0,0 +1,150 @@
+import matplotlib.pyplot as plt
+import csv
+import numpy as np
+import pandas as pd
+
+filename = "hexl_omp_out_0824_23_59.csv"
+
+data = {}
+
+with open(filename, 'r') as file:
+    reader = csv.reader(file)
+
+    input_size = None
+
+    for row_index, row in enumerate(reader):
+        if len(row) == 0:
+            continue
+
+        print(f"Debug: Processing row {row_index}: {row}")  # Debug print
+
+        if row[0].startswith("Input Size"):
+            input_size = row[0].split('=')[1].strip()
+            data[input_size] = {}
+            next(reader)  # Skip the header row that follows each "Input Size" line
+        else:
+            row_split = row[0].split()
+            method = row_split[0]
+            print(f"Debug: method is: {method}")  # Debug print
+
+            for i, threads in enumerate(["Threads=1", "Threads=2", "Threads=4", "Threads=6", "Threads=8"]):
+
+                print(f"Debug: i = {i}, row length = {len(row_split)}")  # Debug print
+
+                if i + 1 >= len(row_split):
+                    print(f"Warning: Skipping index {i + 1} as it's out of range for row {row}")
+                    continue
+
+                if method not in data[input_size]:
+                    data[input_size][method] = {}
+
+                print(f"row_split[i+1] = {row_split[i+1]}")  # Debug print
+                data[input_size][method][threads] = float(row_split[i + 1])
+
+
+serial_times = {}
+
+# Read the file with serial execution times
+with open("hexl_ser_out_0824_1431.csv", 'r') as file:
+    lines = file.readlines()
+    # Extract the header to get input sizes
+    header = lines[0].strip().split()
+    input_sizes = [int(input_size.split('=')[1]) for input_size in header[1:]]
+
+    # Loop through each row (skipping the header)
+    for line in lines[1:]:
+        elements = line.strip().split()
+        method = elements[0]
+        times = [float(time) for time in elements[1:]]
+        serial_times[method] = {}
+        for input_size, time in zip(input_sizes, times):
+            serial_times[method][input_size] = time
+
+# print(serial_times)
+
+
+for method in data[list(data.keys())[0]].keys():  # Assuming each input size has the same methods
+
+    # Create a figure and layout for subplots
+    fig, axs = plt.subplots(2, 3, figsize=(15, 10), gridspec_kw={'width_ratios': [1, 1, 1], 'height_ratios': [1, 1]})
+    fig.suptitle(f'{method} - Execution Time vs Thread Count')
+
+    axs = axs.flatten()
+
+    # First set of graphs for each method: execution time vs thread count
+    for ax, (input_size, input_data) in zip(axs, data.items()):
+        thread_counts = []
+        times = []
+        for thread, time in input_data[method].items():
+            thread_counts.append(int(thread.split("=")[1]))
+            times.append(time)
+
+        ax.plot(thread_counts, times, label=f"Input Size {input_size}")
+
+        ax.axhline(y=serial_times[method][int(input_size)], color='r', linestyle='--', label=f'Serial input size {input_size}')
+
+        ax.set_xlabel('Thread Count')
+        ax.set_ylabel('Execution Time (s)')
+        ax.set_title(f'Input Size = {input_size}')
+        ax.grid(True)
+
+    axs[-1].set_visible(False)
+
+    # Add a single legend for all subplots
+    fig.legend(loc="lower right", bbox_to_anchor=(0.9, 0.1))
+
+    # Add a text caption below the figure
+    # fig.text(0.5, -0.05, 'Caption: Description of the figure.', ha='center')
+
+    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
+
+    # Save the figure
+    plt.savefig(f"Dissertation/{method}_Execution_Time_vs_Thread_Count.png")
+
+    # plt.show()
+
+
+num_methods = len(data[list(data.keys())[0]].keys())
+fig, axs = plt.subplots(num_methods, 1, figsize=(12, 20))
+# fig.suptitle('Execution Time vs Input Size')
+
+# Flatten the axes array if it's multidimensional
+if num_methods > 1:
+    axs = axs.flatten()
+
+# Loop through each method
+for ax, method in zip(axs, data[list(data.keys())[0]].keys()):
+    for thread in ["Threads=1", "Threads=2", "Threads=4", "Threads=6", "Threads=8"]:
+        input_sizes = []
+        times = []
+        for input_size in data.keys():
+            input_sizes.append(int(input_size))
+            times.append(data[input_size][method][thread])
+
+        ax.plot(input_sizes, times, label=f"{thread}")
+
+    ax.set_xlabel('Input Size')
+    ax.set_ylabel('Execution Time (s)')
+    ax.set_title(f'{method}')
+    # Only show the legend on the first subplot
+    if ax is axs[0]:
+        ax.legend(loc="lower right")
+
+    ax.set_xscale('log', base=2)
+    ax.set_xticks([2**12, 2**16, 2**20, 2**24, 2**28])
+    ax.set_xticklabels(['2^12', '2^16', '2^20', '2^24', '2^28'])
+    ax.set_yscale('log', base=2)
+    ax.grid(True)
+
+
+# Tighten the layout
+plt.tight_layout(rect=[0, 0.03, 1, 0.95])
+# Save the figure
+plt.savefig("Dissertation/Execution_Time_vs_Input_size.png")
+# plt.show()
diff --git a/visualisor_ser.py b/visualisor_ser.py
new file mode 100644
index 00000000..7a7cf8be
--- /dev/null
+++ b/visualisor_ser.py
@@ -0,0 +1,153 @@
+# import pandas as pd
+# import numpy as np
+# import matplotlib.pyplot as plt
+# from sklearn.linear_model import LinearRegression
+
+# # Read the CSV file into a pandas DataFrame
+# df = pd.read_csv('hexl_ser_out_0824_1431.csv', delim_whitespace=True)
+
+# # Strip leading and trailing whitespaces from column names
+# df.columns = df.columns.str.strip()
+
+# # Inspect the DataFrame
+# print(df.head())
+
+# # Check if 'Method' column exists
+# try:
+#     print(df['Method'])
+# except KeyError:
+#     print("Column 'Method' not found!")
+
+# # Input sizes from the DataFrame columns
+# input_sizes = [int(col.split('=')[1]) for col in df.columns if 'Input_size' in col]
+# x_vals = np.log2(input_sizes).reshape(-1, 1)  # Log-transform the x-values
+
+# # Enlarge the figure size
+# plt.figure(figsize=(15, 21))
+
+# # Loop over each method
+# for idx, method in enumerate(df['Method']):
+#     y_vals = np.log(df.loc[idx, 'Input_size=4096':'Input_size=268435456'].to_numpy(dtype='float64'))  # Log-transform the y-values
+
+#     # Linear Regression
+#     model = LinearRegression()
+#     model.fit(x_vals, y_vals)
+#     x_test = np.linspace(min(x_vals), max(x_vals), 300).reshape(-1, 1)
+#     y_pred = model.predict(x_test)
+
+#     # Calculate residuals
+#     residuals = y_vals - model.predict(x_vals)
+
+#     # Calculate the standard deviation of the residuals
+#     residual_std = np.std(residuals)
+
+#     # Define the color for this particular method
+#     current_color = next(plt.gca()._get_lines.prop_cycler)['color']
+
+#     # Plotting
+#     plt.scatter(x_vals, y_vals, label=f"{method} (slope: {model.coef_[0]:.2f})", color=current_color)
+#     plt.plot(x_test, y_pred, color=current_color)
+
+#     plt.errorbar(x_vals.flatten(), y_vals, yerr=residual_std, fmt='o', capsize=5, color=current_color)
+
+
+# plt.xticks(x_vals.flatten(), [f'$2^{{{int(x)}}}$' for x in x_vals.flatten()], fontsize='20')
+
+# # Move the legend to the top left within the graph
+# plt.legend(loc='upper left', fontsize='20')
+
+# # Add labels and grid
+# plt.xlabel('Input Size', fontsize='20')
+# plt.ylabel('Logarithm to base 2 of Execution Time (ms)', fontsize='20')
+# plt.grid(True)
+
+# # Adjust layout and show the plot
+# plt.tight_layout()
+# # Save the figure as a high-resolution jpg file
+# plt.savefig('hexl_ser_plot.jpg', format='jpg', dpi=300, bbox_inches='tight', pad_inches=0.1)
+
+# plt.show()
+
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.linear_model import LinearRegression
+from sklearn.preprocessing import PolynomialFeatures
+from sklearn.metrics import mean_squared_error
+import math
+
+# Read the CSV file into a pandas DataFrame
+df = pd.read_csv('hexl_ser_out_0824_1431.csv', delim_whitespace=True)
+df.columns = df.columns.str.strip()
+
+# Input sizes from the DataFrame columns
+input_sizes = [int(col.split('=')[1]) for col in df.columns if 'Input_size' in col]
+x_vals = np.log2(input_sizes).reshape(-1, 1)
+
+# Create polynomial features
+poly = PolynomialFeatures(degree=2)
+x_poly = poly.fit_transform(x_vals)
+
+# Enlarge the figure size
+plt.figure(figsize=(15, 21))
+
+# Loop over each method
+for idx, method in enumerate(df['Method']):
+    # Base-2 logarithm of the timings, to match the base-2 axis label below
+    y_vals = np.log2(df.loc[idx, 'Input_size=4096':'Input_size=268435456'].to_numpy(dtype='float64'))
+
+    color = next(plt.gca()._get_lines.prop_cycler)['color']
+
+    if method == "BM_NTTInPlace":
+        # Polynomial regression for the NTT, whose cost grows faster than linearly in log-log space
+        model = LinearRegression()
+        model.fit(x_poly, y_vals)
+        poly_y_pred = model.predict(x_poly)
+
+        # Calculate the standard error of the residuals
+        mse = mean_squared_error(y_vals, poly_y_pred)
+        std_error = math.sqrt(mse)
+
+        # Generate test data and predict
+        x_test = np.linspace(min(x_vals), max(x_vals), 300).reshape(-1, 1)
+        x_test_poly = poly.transform(x_test)
+        y_pred_test = model.predict(x_test_poly)
+
+        plt.scatter(x_vals, y_vals, color=color, label=f"{method} (coef: {model.coef_[1]:.2f}, {model.coef_[2]:.2f})")
+        plt.errorbar(x_vals.flatten(), y_vals, yerr=std_error, fmt='o', color=color)
+        plt.plot(x_test, y_pred_test, color=color)
+
+    else:
+        # Linear regression for the element-wise kernels
+        model = LinearRegression()
+        model.fit(x_vals, y_vals)
+        x_test = np.linspace(min(x_vals), max(x_vals), 300).reshape(-1, 1)
+        y_pred = model.predict(x_test)
+
+        # Calculate residuals
+        residuals = y_vals - model.predict(x_vals)
+
+        # Calculate the standard deviation of the residuals
+        residual_std = np.std(residuals)
+
+        # Plotting
+        plt.scatter(x_vals, y_vals, label=f"{method} (slope: {model.coef_[0]:.2f})", color=color)
+        plt.plot(x_test, y_pred, color=color)
+        plt.errorbar(x_vals.flatten(), y_vals, yerr=residual_std, fmt='o', capsize=5, color=color)
+
+# Move the legend to the top left within the graph
+plt.legend(loc='upper left', fontsize='20')
+
+# Set custom x-ticks
+plt.xticks(x_vals.flatten(), [f'$2^{{{int(x)}}}$' for x in x_vals.flatten()], fontsize='20')
+
+# Add labels and grid
+plt.xlabel('Input Size', fontsize='20')
+plt.ylabel('Logarithm to base 2 of Execution Time (ms)', fontsize='20')
+plt.grid(True)
+
+# Adjust layout and save the plot
+plt.tight_layout()
+plt.savefig('hexl_ser_plot_poly_with_error.jpg', format='jpg', dpi=300, bbox_inches='tight', pad_inches=0.1)
+plt.show()