Skip to content

Commit

Permalink
Merge pull request #20 from iden3/fix-msm
Browse files: browse the repository at this point in the history
MSM optimizations
  • Loading branch information
olomix authored Sep 5, 2024
2 parents: 35dfe13 + 76825e5; commit 74dd391
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 114 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,25 @@ jobs:
- name: Build prover Android ARM64
run: |
mkdir -p build_prover_android && cd build_prover_android
cmake .. -DTARGET_PLATFORM=ANDROID -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android -DUSE_OPENMP=OFF
cmake .. -DTARGET_PLATFORM=ANDROID -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android -DBUILD_TESTS=OFF -DUSE_OPENMP=OFF
make -j4 && make install
- name: Build prover Android ARM64 with OpenMP
run: |
mkdir -p build_prover_android_openmp && cd build_prover_android_openmp
cmake .. -DTARGET_PLATFORM=ANDROID -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_openmp -DUSE_OPENMP=ON
cmake .. -DTARGET_PLATFORM=ANDROID -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_openmp -DBUILD_TESTS=OFF -DUSE_OPENMP=ON
make -j4 && make install
- name: Build prover Android x86_64
run: |
mkdir -p build_prover_android_x86_64 && cd build_prover_android_x86_64
cmake .. -DTARGET_PLATFORM=ANDROID_x86_64 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_x86_64 -DUSE_OPENMP=OFF
cmake .. -DTARGET_PLATFORM=ANDROID_x86_64 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_x86_64 -DBUILD_TESTS=OFF -DUSE_OPENMP=OFF
make -j4 && make install
- name: Build prover Android x86_64 with OpenMP
run: |
mkdir -p build_prover_android_openmp_x86_64 && cd build_prover_android_openmp_x86_64
cmake .. -DTARGET_PLATFORM=ANDROID_x86_64 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_openmp_x86_64 -DUSE_OPENMP=ON
cmake .. -DTARGET_PLATFORM=ANDROID_x86_64 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../package_android_openmp_x86_64 -DBUILD_TESTS=OFF -DUSE_OPENMP=ON
make -j4 && make install
- name: Build prover Linux
Expand Down Expand Up @@ -184,13 +184,13 @@ jobs:
if [[ ! -d "depends/gmp/package_macos_arm64" ]]; then ./build_gmp.sh macos_arm64; fi
mkdir -p build_prover_ios && cd build_prover_ios
cmake .. -GXcode -DTARGET_PLATFORM=IOS -DCMAKE_INSTALL_PREFIX=../package_ios
cmake .. -GXcode -DTARGET_PLATFORM=IOS -DCMAKE_INSTALL_PREFIX=../package_ios -DBUILD_TESTS=OFF
xcodebuild -destination 'generic/platform=iOS' -scheme rapidsnarkStatic -project rapidsnark.xcodeproj -configuration Release
cp ../depends/gmp/package_ios_arm64/lib/libgmp.a src/Release-iphoneos
cd ../
mkdir -p build_prover_ios_simulator && cd build_prover_ios_simulator
cmake .. -GXcode -DTARGET_PLATFORM=IOS -DCMAKE_INSTALL_PREFIX=../package_ios_simulator -DUSE_ASM=NO
mkdir -p build_prover_ios_simulator && cd build_prover_ios_simulator
cmake .. -GXcode -DTARGET_PLATFORM=IOS -DCMAKE_INSTALL_PREFIX=../package_ios_simulator -DUSE_ASM=NO -DBUILD_TESTS=OFF
xcodebuild -destination 'generic/platform=iOS Simulator' -scheme rapidsnarkStatic -project rapidsnark.xcodeproj
cp ../depends/gmp/package_iphone_simulator/lib/libgmp.a src/Debug-iphonesimulator
cd ../
Expand Down
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ project(rapidsnark LANGUAGES CXX C ASM)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

message("BITS_PER_CHUNK=" ${BITS_PER_CHUNK})
message("USE_ASM=" ${USE_ASM})
message("USE_OPENMP=" ${USE_OPENMP})
message("CMAKE_CROSSCOMPILING=" ${CMAKE_CROSSCOMPILING})
Expand Down
2 changes: 1 addition & 1 deletion depends/ffiasm
17 changes: 12 additions & 5 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ if(USE_ASM)
endif()
endif()

if(DEFINED BITS_PER_CHUNK)
add_definitions(-DMSM_BITS_PER_CHUNK=${BITS_PER_CHUNK})
endif()

if(USE_ASM AND ARCH MATCHES "x86_64")

if (CMAKE_HOST_SYSTEM_NAME MATCHES "Darwin")
Expand Down Expand Up @@ -131,12 +135,15 @@ if(USE_SODIUM)
target_link_libraries(prover sodium)
endif()

option(BUILD_TESTS "Build the tests" ON)

enable_testing()
add_executable(test_public_size test_public_size.c)
target_link_libraries(test_public_size rapidsnarkStaticFrFq)
add_test(NAME test_public_size COMMAND test_public_size circuit_final.zkey 86
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/testdata)
if(BUILD_TESTS)
enable_testing()
add_executable(test_public_size test_public_size.c)
target_link_libraries(test_public_size rapidsnarkStaticFrFq pthread)
add_test(NAME test_public_size COMMAND test_public_size circuit_final.zkey 86
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/testdata)
endif()

if(OpenMP_CXX_FOUND)

Expand Down
173 changes: 72 additions & 101 deletions src/groth16.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "random_generator.hpp"
#include "logging.hpp"
#include <future>
#include "misc.hpp"
#include <vector>
#include <mutex>

namespace Groth16 {

Expand Down Expand Up @@ -46,114 +48,84 @@ std::unique_ptr<Prover<Engine>> makeProver(
template <typename Engine>
std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement *wtns) {

#ifdef USE_OPENMP
ThreadPool &threadPool = ThreadPool::defaultPool();

LOG_TRACE("Start Multiexp A");
uint32_t sW = sizeof(wtns[0]);
typename Engine::G1Point pi_a;
E.g1.multiMulByScalar(pi_a, pointsA, (uint8_t *)wtns, sW, nVars);
E.g1.multiMulByScalarMSM(pi_a, pointsA, (uint8_t *)wtns, sW, nVars);
std::ostringstream ss2;
ss2 << "pi_a: " << E.g1.toString(pi_a);
LOG_DEBUG(ss2);

LOG_TRACE("Start Multiexp B1");
typename Engine::G1Point pib1;
E.g1.multiMulByScalar(pib1, pointsB1, (uint8_t *)wtns, sW, nVars);
E.g1.multiMulByScalarMSM(pib1, pointsB1, (uint8_t *)wtns, sW, nVars);
std::ostringstream ss3;
ss3 << "pib1: " << E.g1.toString(pib1);
LOG_DEBUG(ss3);

LOG_TRACE("Start Multiexp B2");
typename Engine::G2Point pi_b;
E.g2.multiMulByScalar(pi_b, pointsB2, (uint8_t *)wtns, sW, nVars);
E.g2.multiMulByScalarMSM(pi_b, pointsB2, (uint8_t *)wtns, sW, nVars);
std::ostringstream ss4;
ss4 << "pi_b: " << E.g2.toString(pi_b);
LOG_DEBUG(ss4);

LOG_TRACE("Start Multiexp C");
typename Engine::G1Point pi_c;
E.g1.multiMulByScalar(pi_c, pointsC, (uint8_t *)((uint64_t)wtns + (nPublic +1)*sW), sW, nVars-nPublic-1);
E.g1.multiMulByScalarMSM(pi_c, pointsC, (uint8_t *)((uint64_t)wtns + (nPublic +1)*sW), sW, nVars-nPublic-1);
std::ostringstream ss5;
ss5 << "pi_c: " << E.g1.toString(pi_c);
LOG_DEBUG(ss5);
#else
LOG_TRACE("Start Multiexp A");
uint32_t sW = sizeof(wtns[0]);
typename Engine::G1Point pi_a;
auto pA_future = std::async([&]() {
E.g1.multiMulByScalar(pi_a, pointsA, (uint8_t *)wtns, sW, nVars);
});

LOG_TRACE("Start Multiexp B1");
typename Engine::G1Point pib1;
auto pB1_future = std::async([&]() {
E.g1.multiMulByScalar(pib1, pointsB1, (uint8_t *)wtns, sW, nVars);
});

LOG_TRACE("Start Multiexp B2");
typename Engine::G2Point pi_b;
auto pB2_future = std::async([&]() {
E.g2.multiMulByScalar(pi_b, pointsB2, (uint8_t *)wtns, sW, nVars);
});

LOG_TRACE("Start Multiexp C");
typename Engine::G1Point pi_c;
auto pC_future = std::async([&]() {
E.g1.multiMulByScalar(pi_c, pointsC, (uint8_t *)((uint64_t)wtns + (nPublic +1)*sW), sW, nVars-nPublic-1);
});
#endif

LOG_TRACE("Start Initializing a b c A");
auto a = new typename Engine::FrElement[domainSize];
auto b = new typename Engine::FrElement[domainSize];
auto c = new typename Engine::FrElement[domainSize];

#pragma omp parallel for
for (u_int32_t i=0; i<domainSize; i++) {
E.fr.copy(a[i], E.fr.zero());
E.fr.copy(b[i], E.fr.zero());
}
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
for (u_int32_t i=begin; i<end; i++) {
E.fr.copy(a[i], E.fr.zero());
E.fr.copy(b[i], E.fr.zero());
}
});

LOG_TRACE("Processing coefs");
#ifdef _OPENMP
#define NLOCKS 1024
omp_lock_t locks[NLOCKS];
for (int i=0; i<NLOCKS; i++) omp_init_lock(&locks[i]);
#pragma omp parallel for
#endif
for (u_int64_t i=0; i<nCoefs; i++) {
typename Engine::FrElement *ab = (coefs[i].m == 0) ? a : b;
typename Engine::FrElement aux;

E.fr.mul(
aux,
wtns[coefs[i].s],
coefs[i].coef
);
#ifdef _OPENMP
omp_set_lock(&locks[coefs[i].c % NLOCKS]);
#endif
E.fr.add(
ab[coefs[i].c],
ab[coefs[i].c],
aux
);
#ifdef _OPENMP
omp_unset_lock(&locks[coefs[i].c % NLOCKS]);
#endif
}
#ifdef _OPENMP
for (int i=0; i<NLOCKS; i++) omp_destroy_lock(&locks[i]);
#endif

#define NLOCKS 1024
std::vector<std::mutex> locks(NLOCKS);

threadPool.parallelFor(0, nCoefs, [&] (int begin, int end, int numThread) {
for (u_int64_t i=begin; i<end; i++) {
typename Engine::FrElement *ab = (coefs[i].m == 0) ? a : b;
typename Engine::FrElement aux;

E.fr.mul(
aux,
wtns[coefs[i].s],
coefs[i].coef
);

std::lock_guard<std::mutex> guard(locks[coefs[i].c % NLOCKS]);

E.fr.add(
ab[coefs[i].c],
ab[coefs[i].c],
aux
);
}
});
LOG_TRACE("Calculating c");
#pragma omp parallel for
for (u_int32_t i=0; i<domainSize; i++) {
E.fr.mul(
c[i],
a[i],
b[i]
);
}
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(
c[i],
a[i],
b[i]
);
}
});

LOG_TRACE("Initializing fft");
u_int32_t domainPower = fft->log2(domainSize);
Expand All @@ -164,10 +136,13 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(a[0]).c_str());
LOG_DEBUG(E.fr.toString(a[1]).c_str());
LOG_TRACE("Start Shift A");
#pragma omp parallel for
for (u_int64_t i=0; i<domainSize; i++) {
E.fr.mul(a[i], a[i], fft->root(domainPower+1, i));
}

threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(a[i], a[i], fft->root(domainPower+1, i));
}
});

LOG_TRACE("a After shift:");
LOG_DEBUG(E.fr.toString(a[0]).c_str());
LOG_DEBUG(E.fr.toString(a[1]).c_str());
Expand All @@ -182,10 +157,11 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(b[0]).c_str());
LOG_DEBUG(E.fr.toString(b[1]).c_str());
LOG_TRACE("Start Shift B");
#pragma omp parallel for
for (u_int64_t i=0; i<domainSize; i++) {
E.fr.mul(b[i], b[i], fft->root(domainPower+1, i));
}
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(b[i], b[i], fft->root(domainPower+1, i));
}
});
LOG_TRACE("b After shift:");
LOG_DEBUG(E.fr.toString(b[0]).c_str());
LOG_DEBUG(E.fr.toString(b[1]).c_str());
Expand All @@ -201,10 +177,11 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(c[0]).c_str());
LOG_DEBUG(E.fr.toString(c[1]).c_str());
LOG_TRACE("Start Shift C");
#pragma omp parallel for
for (u_int64_t i=0; i<domainSize; i++) {
E.fr.mul(c[i], c[i], fft->root(domainPower+1, i));
}
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(c[i], c[i], fft->root(domainPower+1, i));
}
});
LOG_TRACE("c After shift:");
LOG_DEBUG(E.fr.toString(c[0]).c_str());
LOG_DEBUG(E.fr.toString(c[1]).c_str());
Expand All @@ -215,12 +192,13 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
LOG_DEBUG(E.fr.toString(c[1]).c_str());

LOG_TRACE("Start ABC");
#pragma omp parallel for
for (u_int64_t i=0; i<domainSize; i++) {
E.fr.mul(a[i], a[i], b[i]);
E.fr.sub(a[i], a[i], c[i]);
E.fr.fromMontgomery(a[i], a[i]);
}
threadPool.parallelFor(0, domainSize, [&] (int begin, int end, int numThread) {
for (u_int64_t i=begin; i<end; i++) {
E.fr.mul(a[i], a[i], b[i]);
E.fr.sub(a[i], a[i], c[i]);
E.fr.fromMontgomery(a[i], a[i]);
}
});
LOG_TRACE("abc:");
LOG_DEBUG(E.fr.toString(a[0]).c_str());
LOG_DEBUG(E.fr.toString(a[1]).c_str());
Expand All @@ -230,7 +208,7 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement

LOG_TRACE("Start Multiexp H");
typename Engine::G1Point pih;
E.g1.multiMulByScalar(pih, pointsH, (uint8_t *)a, sizeof(a[0]), domainSize);
E.g1.multiMulByScalarMSM(pih, pointsH, (uint8_t *)a, sizeof(a[0]), domainSize);
std::ostringstream ss1;
ss1 << "pih: " << E.g1.toString(pih);
LOG_DEBUG(ss1);
Expand All @@ -247,13 +225,6 @@ std::unique_ptr<Proof<Engine>> Prover<Engine>::prove(typename Engine::FrElement
randombytes_buf((void *)&(r.v[0]), sizeof(r)-1);
randombytes_buf((void *)&(s.v[0]), sizeof(s)-1);

#ifndef USE_OPENMP
pA_future.get();
pB1_future.get();
pB2_future.get();
pC_future.get();
#endif

typename Engine::G1Point p1;
typename Engine::G2Point p2;

Expand Down

0 comments on commit 74dd391

Please sign in to comment.