Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for wtns calc and prover options #517

Merged
merged 12 commits into from
Oct 18, 2024
272 changes: 185 additions & 87 deletions browser_tests/package-lock.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion browser_tests/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"test": "node test/launch-groth16.js"
},
"devDependencies": {
"puppeteer": "22.15.0",
"puppeteer": "^23.5.3",
"ffjavascript": "^0.3.0",
"st": "3.0.0"
}
Expand Down
7 changes: 5 additions & 2 deletions browser_tests/test/launch-groth16.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ const server = http
.listen(1337);

const browser = await puppeteer.launch({
headless: "new",
headless: true,
args: [
// Necessary to have WebCrypto on localhost
"--allow-insecure-localhost",
// Necessary to download the PTAU file from AWS within the tests
"--disable-web-security"
"--disable-web-security",
// Disable the sandbox to run in GHA
"--no-sandbox",

],
});
const page = await browser.newPage();
Expand Down
94 changes: 51 additions & 43 deletions build/browser.esm.js
Original file line number Diff line number Diff line change
Expand Up @@ -1089,37 +1089,41 @@ const bn128r$1 = Scalar.e("21888242871839275222246405745257275088548364400416034
const bls12381q = Scalar.e("1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab", 16);
const bn128q = Scalar.e("21888242871839275222246405745257275088696311157297823662689037894645226208583");

async function getCurveFromR(r) {
async function getCurveFromR(r, options) {
let curve;
// check that options param is defined and that options.singleThread is defined
let singleThread = options && options.singleThread;
if (Scalar.eq(r, bn128r$1)) {
curve = await buildBn128();
curve = await buildBn128(singleThread);
} else if (Scalar.eq(r, bls12381r$1)) {
curve = await buildBls12381();
curve = await buildBls12381(singleThread);
} else {
throw new Error(`Curve not supported: ${Scalar.toString(r)}`);
}
return curve;
}

async function getCurveFromQ(q) {
async function getCurveFromQ(q, options) {
let curve;
let singleThread = options && options.singleThread;
if (Scalar.eq(q, bn128q)) {
curve = await buildBn128();
curve = await buildBn128(singleThread);
} else if (Scalar.eq(q, bls12381q)) {
curve = await buildBls12381();
curve = await buildBls12381(singleThread);
} else {
throw new Error(`Curve not supported: ${Scalar.toString(q)}`);
}
return curve;
}

async function getCurveFromName(name) {
async function getCurveFromName(name, options) {
let curve;
let singleThread = options && options.singleThread;
const normName = normalizeName(name);
if (["BN128", "BN254", "ALTBN128"].indexOf(normName) >= 0) {
curve = await buildBn128();
curve = await buildBn128(singleThread);
} else if (["BLS12381"].indexOf(normName) >= 0) {
curve = await buildBls12381();
curve = await buildBls12381(singleThread);
} else {
throw new Error(`Curve not supported: ${name}`);
}
Expand Down Expand Up @@ -2489,19 +2493,19 @@ async function readG2(fd, curve, toObject) {
}


async function readHeader$1(fd, sections, toObject) {
async function readHeader$1(fd, sections, toObject, options) {
// Read Header
/////////////////////
await startReadUniqueSection(fd, sections, 1);
const protocolId = await fd.readULE32();
await endReadSection(fd);

if (protocolId === GROTH16_PROTOCOL_ID) {
return await readHeaderGroth16(fd, sections, toObject);
return await readHeaderGroth16(fd, sections, toObject, options);
} else if (protocolId === PLONK_PROTOCOL_ID) {
return await readHeaderPlonk(fd, sections, toObject);
return await readHeaderPlonk(fd, sections, toObject, options);
} else if (protocolId === FFLONK_PROTOCOL_ID) {
return await readHeaderFFlonk(fd, sections, toObject);
return await readHeaderFFlonk(fd, sections, toObject, options);
} else {
throw new Error("Protocol not supported: ");
}
Expand All @@ -2510,7 +2514,7 @@ async function readHeader$1(fd, sections, toObject) {



async function readHeaderGroth16(fd, sections, toObject) {
async function readHeaderGroth16(fd, sections, toObject, options) {
const zkey = {};

zkey.protocol = "groth16";
Expand All @@ -2525,7 +2529,7 @@ async function readHeaderGroth16(fd, sections, toObject) {
const n8r = await fd.readULE32();
zkey.n8r = n8r;
zkey.r = await readBigInt(fd, n8r);
zkey.curve = await getCurveFromQ(zkey.q);
zkey.curve = await getCurveFromQ(zkey.q, options);
zkey.nVars = await fd.readULE32();
zkey.nPublic = await fd.readULE32();
zkey.domainSize = await fd.readULE32();
Expand All @@ -2542,7 +2546,7 @@ async function readHeaderGroth16(fd, sections, toObject) {

}

async function readHeaderPlonk(fd, sections, toObject) {
async function readHeaderPlonk(fd, sections, toObject, options) {
const zkey = {};

zkey.protocol = "plonk";
Expand All @@ -2557,7 +2561,7 @@ async function readHeaderPlonk(fd, sections, toObject) {
const n8r = await fd.readULE32();
zkey.n8r = n8r;
zkey.r = await readBigInt(fd, n8r);
zkey.curve = await getCurveFromQ(zkey.q);
zkey.curve = await getCurveFromQ(zkey.q, options);
zkey.nVars = await fd.readULE32();
zkey.nPublic = await fd.readULE32();
zkey.domainSize = await fd.readULE32();
Expand All @@ -2582,7 +2586,7 @@ async function readHeaderPlonk(fd, sections, toObject) {
return zkey;
}

async function readHeaderFFlonk(fd, sections, toObject) {
async function readHeaderFFlonk(fd, sections, toObject, options) {
const zkey = {};

zkey.protocol = "fflonk";
Expand All @@ -2592,7 +2596,7 @@ async function readHeaderFFlonk(fd, sections, toObject) {
const n8q = await fd.readULE32();
zkey.n8q = n8q;
zkey.q = await readBigInt(fd, n8q);
zkey.curve = await getCurveFromQ(zkey.q);
zkey.curve = await getCurveFromQ(zkey.q, options);

const n8r = await fd.readULE32();
zkey.n8r = n8r;
Expand Down Expand Up @@ -2955,14 +2959,14 @@ async function read(fileName) {
*/
const {stringifyBigInts: stringifyBigInts$4} = utils;

async function groth16Prove(zkeyFileName, witnessFileName, logger) {
async function groth16Prove(zkeyFileName, witnessFileName, logger, options) {
const {fd: fdWtns, sections: sectionsWtns} = await readBinFile(witnessFileName, "wtns", 2);

const wtns = await readHeader(fdWtns, sectionsWtns);

const {fd: fdZKey, sections: sectionsZKey} = await readBinFile(zkeyFileName, "zkey", 2);

const zkey = await readHeader$1(fdZKey, sectionsZKey);
const zkey = await readHeader$1(fdZKey, sectionsZKey, undefined, options);

if (zkey.protocol != "groth16") {
throw new Error("zkey file is not groth16");
Expand Down Expand Up @@ -3394,11 +3398,13 @@ async function builder(code, options) {
// If we can't look up the patch version, assume the lowest
let patchVersion = 0;

let codeIsWebAssemblyInstance = false;

// If code is already prepared WebAssembly.Instance, we use it directly
if (code instanceof WebAssembly.Instance) {
instance = code;
codeIsWebAssemblyInstance = true;
} else {

let memorySize = 32767;

if (options.memorySize) {
Expand Down Expand Up @@ -3558,9 +3564,13 @@ async function builder(code, options) {
// We explicitly check for major version 2 in case there's a circom v3 in the future
if (majorVersion === 2) {
wc = new WitnessCalculatorCircom2(instance, sanityCheck);
} else {
// TODO: Maybe we want to check for the explicit version 1 before choosing this?
} else if (majorVersion === 1) {
if (codeIsWebAssemblyInstance) {
throw new Error('Loading code from WebAssembly instance is not supported for circom version 1');
}
wc = new WitnessCalculatorCircom1(memory, instance, sanityCheck);
} else {
throw new Error(`Unsupported circom version: ${majorVersion}`);
}
return wc;

Expand Down Expand Up @@ -3938,7 +3948,7 @@ async function wtnsCalculate(_input, wasmFileName, wtnsFileName, options) {
await fdWasm.close();

const wc = await builder(wasm, options);
if (wc.circom_version() == 1) {
if (wc.circom_version() === 1) {
const w = await wc.calculateBinWitness(input);

const fdWtns = await createBinFile(wtnsFileName, "wtns", 2, 2);
Expand Down Expand Up @@ -3975,14 +3985,14 @@ async function wtnsCalculate(_input, wasmFileName, wtnsFileName, options) {
*/
const {unstringifyBigInts: unstringifyBigInts$a} = utils;

async function groth16FullProve(_input, wasmFile, zkeyFileName, logger) {
async function groth16FullProve(_input, wasmFile, zkeyFileName, logger, wtnsCalcOptions, proverOptions) {
const input = unstringifyBigInts$a(_input);

const wtns= {
type: "mem"
};
await wtnsCalculate(input, wasmFile, wtns);
return await groth16Prove(zkeyFileName, wtns, logger);
await wtnsCalculate(input, wasmFile, wtns, wtnsCalcOptions);
return await groth16Prove(zkeyFileName, wtns, logger, proverOptions);
}

/*
Expand Down Expand Up @@ -7043,10 +7053,7 @@ async function wtnsDebug(_input, wasmFileName, wtnsFileName, symName, options, l
const wasm = await fdWasm.read(fdWasm.totalSize);
await fdWasm.close();


let wcOps = {
sanityCheck: true
};
const wcOps = {...options, sanityCheck: true};
let sym = await loadSymbols(symName);
if (options.set) {
if (!sym) sym = await loadSymbols(symName);
Expand Down Expand Up @@ -7074,7 +7081,7 @@ async function wtnsDebug(_input, wasmFileName, wtnsFileName, symName, options, l
wcOps.sym = sym;

const wc = await builder(wasm, wcOps);
const w = await wc.calculateWitness(input);
const w = await wc.calculateWitness(input, true);

const fdWtns = await createBinFile(wtnsFileName, "wtns", 2, 2);

Expand Down Expand Up @@ -11937,7 +11944,7 @@ class Evaluations {
*/
const {stringifyBigInts: stringifyBigInts$1} = utils;

async function plonk16Prove(zkeyFileName, witnessFileName, logger) {
async function plonk16Prove(zkeyFileName, witnessFileName, logger, options) {
const {fd: fdWtns, sections: sectionsWtns} = await readBinFile(witnessFileName, "wtns", 2);

// Read witness file
Expand All @@ -11948,7 +11955,7 @@ async function plonk16Prove(zkeyFileName, witnessFileName, logger) {
if (logger) logger.debug("> Reading zkey file");
const {fd: fdZKey, sections: zkeySections} = await readBinFile(zkeyFileName, "zkey", 2);

const zkey = await readHeader$1(fdZKey, zkeySections);
const zkey = await readHeader$1(fdZKey, zkeySections, undefined, options);
if (zkey.protocol != "plonk") {
throw new Error("zkey file is not plonk");
}
Expand Down Expand Up @@ -12801,14 +12808,14 @@ async function plonk16Prove(zkeyFileName, witnessFileName, logger) {
*/
const {unstringifyBigInts: unstringifyBigInts$5} = utils;

async function plonkFullProve(_input, wasmFile, zkeyFileName, logger) {
async function plonkFullProve(_input, wasmFile, zkeyFileName, logger, wtnsCalcOptions, proverOptions) {
const input = unstringifyBigInts$5(_input);

const wtns= {
type: "mem"
};
await wtnsCalculate(input, wasmFile, wtns);
return await plonk16Prove(zkeyFileName, wtns, logger);
await wtnsCalculate(input, wasmFile, wtns, wtnsCalcOptions);
return await plonk16Prove(zkeyFileName, wtns, logger, proverOptions);
}

/*
Expand Down Expand Up @@ -14138,7 +14145,7 @@ async function fflonkSetup(r1csFilename, ptauFilename, zkeyFilename, logger) {
const { stringifyBigInts } = utils;


async function fflonkProve(zkeyFileName, witnessFileName, logger) {
async function fflonkProve(zkeyFileName, witnessFileName, logger, options) {
if (logger) logger.info("FFLONK PROVER STARTED");

// Read witness file
Expand All @@ -14155,7 +14162,8 @@ async function fflonkProve(zkeyFileName, witnessFileName, logger) {
fd: fdZKey,
sections: zkeySections
} = await readBinFile(zkeyFileName, "zkey", 2);
const zkey = await readHeader$1(fdZKey, zkeySections);

const zkey = await readHeader$1(fdZKey, zkeySections, undefined, options);

if (zkey.protocolId !== FFLONK_PROTOCOL_ID) {
throw new Error("zkey file is not fflonk");
Expand Down Expand Up @@ -15392,16 +15400,16 @@ async function fflonkProve(zkeyFileName, witnessFileName, logger) {
*/
const {unstringifyBigInts: unstringifyBigInts$2} = utils;

async function fflonkFullProve(_input, wasmFilename, zkeyFilename, logger) {
async function fflonkFullProve(_input, wasmFilename, zkeyFilename, logger, wtnsCalcOptions, proverOptions) {
const input = unstringifyBigInts$2(_input);

const wtns= {type: "mem"};

// Compute the witness
await wtnsCalculate(input, wasmFilename, wtns);
await wtnsCalculate(input, wasmFilename, wtns, wtnsCalcOptions);

// Compute the proof
return await fflonkProve(zkeyFilename, wtns, logger);
return await fflonkProve(zkeyFilename, wtns, logger, proverOptions);
}

/*
Expand Down
Loading
Loading