From 9c54cc415056c9d5f65068fdebfaa64ab73b0294 Mon Sep 17 00:00:00 2001 From: 0xmad <0xmad@users.noreply.github.com> Date: Thu, 25 Jul 2024 11:40:20 -0500 Subject: [PATCH] chore: optimize message processor and tally - [x] Remove sha256 hashing - [x] Support multiple params for verifier - [x] Move circuit input generation to contract utils --- circuits/circom/circuits.json | 32 +++- .../circom/core/non-qv/processMessages.circom | 63 ++----- circuits/circom/core/non-qv/tallyVotes.circom | 80 +++----- .../circom/core/qv/processMessages.circom | 72 +++---- circuits/circom/core/qv/tallyVotes.circom | 89 +++------ .../test/ProcessMessages_10-2-1-2_test.circom | 2 +- .../circom/test/TallyVotes_10-1-2_test.circom | 2 +- circuits/circom/utils/hashers.circom | 40 ---- .../utils/processMessagesInputHasher.circom | 61 ------ .../circom/utils/tallyVotesInputHasher.circom | 36 ---- circuits/ts/__tests__/CeremonyParams.test.ts | 48 +---- circuits/ts/__tests__/Hasher.test.ts | 164 +--------------- circuits/ts/__tests__/ProcessMessages.test.ts | 43 +---- circuits/ts/__tests__/TallyVotes.test.ts | 5 +- circuits/ts/types.ts | 11 +- cli/ts/commands/proveOnChain.ts | 80 +++----- contracts/contracts/MessageProcessor.sol | 176 +++++------------- contracts/contracts/Tally.sol | 83 +++------ contracts/contracts/crypto/MockVerifier.sol | 2 +- contracts/contracts/crypto/Verifier.sol | 26 ++- contracts/contracts/interfaces/IVerifier.sol | 4 +- contracts/contracts/trees/LazyIMT.sol | 4 +- contracts/tasks/helpers/Prover.ts | 83 +++------ contracts/tests/MessageProcessor.test.ts | 51 ++--- contracts/tests/Tally.test.ts | 153 ++++++++++++--- contracts/tests/TallyNonQv.test.ts | 80 +++++--- contracts/tests/Verifier.test.ts | 12 +- contracts/ts/circuitInputs.ts | 81 ++++++++ contracts/ts/index.ts | 12 +- contracts/ts/types.ts | 149 +++++++++++++++ core/ts/Poll.ts | 45 ++--- core/ts/__tests__/e2e.test.ts | 15 -- core/ts/__tests__/utils.test.ts | 37 +--- core/ts/index.ts | 9 +- core/ts/utils/types.ts | 11 +- core/ts/utils/utils.ts | 87 --------- 36 files changed, 764 insertions(+), 1184 deletions(-) delete mode 100644 circuits/circom/utils/processMessagesInputHasher.circom delete mode 100644 circuits/circom/utils/tallyVotesInputHasher.circom create mode 100644 contracts/ts/circuitInputs.ts diff --git a/circuits/circom/circuits.json b/circuits/circom/circuits.json index b6dbd207b..2201950ba 100644 --- a/circuits/circom/circuits.json +++ b/circuits/circom/circuits.json @@ -3,24 +3,48 @@ "file": "./core/qv/processMessages", "template": "ProcessMessages", "params": [10, 2, 1, 2], - "pubs": ["inputHash"] + "pubs": [ + "maxVoteOptions", + "numSignUps", + "index", + "messageBatchSize", + "batchEndIndex", + "coordPubKeyHash", + "msgRoot", + "currentSbCommitment", + "newSbCommitment", + "pollEndTimestamp", + "actualStateTreeDepth" + ] }, "ProcessMessagesNonQv_10-2-1-2_test": { "file": "./core/non-qv/processMessages", "template": "ProcessMessagesNonQv", "params": [10, 2, 1, 2], - "pubs": ["inputHash"] + "pubs": [ + "maxVoteOptions", + "numSignUps", + "index", + "messageBatchSize", + "batchEndIndex", + "coordPubKeyHash", + "msgRoot", + "currentSbCommitment", + "newSbCommitment", + "pollEndTimestamp", + "actualStateTreeDepth" + ] }, "TallyVotes_10-1-2_test": { "file": "./core/qv/tallyVotes", "template": "TallyVotes", "params": [10, 1, 2], - "pubs": ["inputHash"] + "pubs": ["index", "batchSize", "numSignUps", "sbCommitment", "currentTallyCommitment", "newTallyCommitment"] }, "TallyVotesNonQv_10-1-2_test": { "file": "./core/non-qv/tallyVotes", "template": "TallyVotesNonQv", "params": [10, 1, 2], - "pubs": ["inputHash"] + "pubs": ["index", "batchSize", "numSignUps", "sbCommitment", "currentTallyCommitment", "newTallyCommitment"] } } diff --git a/circuits/circom/core/non-qv/processMessages.circom b/circuits/circom/core/non-qv/processMessages.circom index 22a5fa643..d05a77c60 100644 --- a/circuits/circom/core/non-qv/processMessages.circom +++ b/circuits/circom/core/non-qv/processMessages.circom @@ -8,7 +8,6 @@ include "./safe-comparators.circom"; include "../../utils/hashers.circom"; include "../../utils/messageToCommand.circom"; include "../../utils/privToPubKey.circom"; -include "../../utils/processMessagesInputHasher.circom"; include "../../utils/non-qv/stateLeafAndBallotTransformer.circom"; include "../../trees/incrementalMerkleTree.circom"; include "../../trees/incrementalQuinaryTree.circom"; @@ -48,17 +47,12 @@ include "../../trees/incrementalQuinaryTree.circom"; var STATE_LEAF_TIMESTAMP_IDX = 3; var msgTreeZeroValue = 8370432830353022751713833565135785980866757267633941821328460903436894336785; - // nb. The usage of SHA-256 hash is necessary to save some gas costs at verification time - // at the cost of more constraints for the prover. - // Basically, some values from the contract are passed as private inputs and the hash as a public input. - - // The SHA-256 hash of values provided by the contract. - signal input inputHash; - signal input packedVals; // Number of users that have completed the sign up. - signal numSignUps; + signal input numSignUps; + // Mesage batch size + signal input messageBatchSize; // Number of options for this poll. - signal maxVoteOptions; + signal input maxVoteOptions; // Time when the poll ends. signal input pollEndTimestamp; // The existing message tree root. @@ -79,6 +73,12 @@ include "../../trees/incrementalQuinaryTree.circom"; // @note it is a public input to ensure fair processing from // the coordinator (no censoring) signal input actualStateTreeDepth; + // The last batch index + signal input batchEndIndex; + // The batch index of current message batch + signal input index; + // Coordinator public key hash + signal input coordPubKeyHash; // The state leaves upon which messages are applied. // transform(currentStateLeaf[4], message5) => newStateLeaf4 @@ -110,14 +110,6 @@ include "../../trees/incrementalQuinaryTree.circom"; // Therefore, the index of the first message to process does not match the index of the // first message (e.g., [msg1, msg2, msg3] => first message to process has index 3). - // The index of the first message leaf in the batch, inclusive. - signal batchStartIndex; - - // The index of the last message leaf in the batch to process, exclusive. - // This value may be less than batchStartIndex + batchSize if this batch is - // the last batch and the total number of messages is not a multiple of the batch size. - signal batchEndIndex; - // The history of state and ballot roots and temporary intermediate // signals (for processing purposes). signal stateRoots[batchSize + 1]; @@ -131,31 +123,6 @@ include "../../trees/incrementalQuinaryTree.circom"; var computedCurrentSbCommitment = PoseidonHasher(3)([currentStateRoot, currentBallotRoot, currentSbSalt]); computedCurrentSbCommitment === currentSbCommitment; - // Verify public inputs and assign unpacked values. - var ( - computedMaxVoteOptions, - computedNumSignUps, - computedBatchStartIndex, - computedBatchEndIndex, - computedHash - ) = ProcessMessagesInputHasher()( - packedVals, - coordPubKey, - msgRoot, - currentSbCommitment, - newSbCommitment, - pollEndTimestamp, - actualStateTreeDepth - ); - - // The unpacked values from packedVals. - computedMaxVoteOptions ==> maxVoteOptions; - computedNumSignUps ==> numSignUps; - computedBatchStartIndex ==> batchStartIndex; - computedBatchEndIndex ==> batchEndIndex; - // Matching constraints. - computedHash === inputHash; - // ----------------------------------------------------------------------- // 0. Ensure that the maximum vote options signal is valid and if // the maximum users signal is valid. @@ -167,13 +134,15 @@ include "../../trees/incrementalQuinaryTree.circom"; var numSignUpsValid = LessEqThan(32)([numSignUps, STATE_TREE_ARITY ** stateTreeDepth]); numSignUpsValid === 1; + messageBatchSize === batchSize; + // Hash each Message to check their existence in the Message tree. var computedMessageHashers[batchSize]; for (var i = 0; i < batchSize; i++) { computedMessageHashers[i] = MessageHasher()(msgs[i], encPubKeys[i]); } - // If batchEndIndex - batchStartIndex < batchSize, the remaining + // If endIndex - startIndex < batchSize, the remaining // message hashes should be the zero value. // e.g. [m, z, z, z, z] if there is only 1 real message in the batch // This makes possible to have a batch of messages which is only partially full. @@ -182,7 +151,7 @@ include "../../trees/incrementalQuinaryTree.circom"; var computedPathIndex[msgTreeDepth - msgBatchDepth]; for (var i = 0; i < batchSize; i++) { - var batchStartIndexValid = SafeLessThan(32)([batchStartIndex + i, batchEndIndex]); + var batchStartIndexValid = SafeLessThan(32)([index + i, batchEndIndex]); computedLeaves[i] = Mux1()([msgTreeZeroValue, computedMessageHashers[i]], batchStartIndexValid); } @@ -195,8 +164,8 @@ include "../../trees/incrementalQuinaryTree.circom"; // Computing the path_index values. Since msgBatchLeavesExists tests // the existence of a subroot, the length of the proof correspond to the last // n elements of a proof from the root to a leaf, where n = msgTreeDepth - msgBatchDepth. - // e.g. if batchStartIndex = 25, msgTreeDepth = 4, msgBatchDepth = 2, then path_index = [1, 0]. - var computedMsgBatchPathIndices[msgTreeDepth] = QuinGeneratePathIndices(msgTreeDepth)(batchStartIndex); + // e.g. if startIndex = 25, msgTreeDepth = 4, msgBatchDepth = 2, then path_index = [1, 0]. + var computedMsgBatchPathIndices[msgTreeDepth] = QuinGeneratePathIndices(msgTreeDepth)(index); for (var i = msgBatchDepth; i < msgTreeDepth; i++) { computedPathIndex[i - msgBatchDepth] = computedMsgBatchPathIndices[i]; diff --git a/circuits/circom/core/non-qv/tallyVotes.circom b/circuits/circom/core/non-qv/tallyVotes.circom index 2feb0232f..812596e8a 100644 --- a/circuits/circom/core/non-qv/tallyVotes.circom +++ b/circuits/circom/core/non-qv/tallyVotes.circom @@ -9,7 +9,6 @@ include "../../trees/incrementalMerkleTree.circom"; include "../../trees/incrementalQuinaryTree.circom"; include "../../utils/calculateTotal.circom"; include "../../utils/hashers.circom"; -include "../../utils/tallyVotesInputHasher.circom"; /** * Processes batches of votes and verifies their validity in a Merkle tree structure. @@ -32,7 +31,7 @@ template TallyVotesNonQv( var BALLOT_TREE_ARITY = 2; // The number of ballots processed at once, determined by the depth of the intermediate state tree. - var batchSize = BALLOT_TREE_ARITY ** intStateTreeDepth; + var ballotsBatchSize = BALLOT_TREE_ARITY ** intStateTreeDepth; // Number of voting options available, determined by the depth of the vote option tree. var numVoteOptions = TREE_ARITY ** voteOptionTreeDepth; @@ -51,81 +50,54 @@ template TallyVotesNonQv( signal input ballotRoot; // Salt used in commitment to secure the ballot data. signal input sbSalt; - - // Inputs combined into a hash to verify the integrity and authenticity of the data. - signal input packedVals; // Commitment to the state and ballots. signal input sbCommitment; // Commitment to the current tally before this batch. signal input currentTallyCommitment; // Commitment to the new tally after processing this batch. signal input newTallyCommitment; - - // A tally commitment is the hash of the following salted values: - // - the vote results, - // - the number of voice credits spent per vote option, - // - the total number of spent voice credits. - - // Hash of all inputs to ensure they are unchanged and authentic. - signal input inputHash; - + // Start index of given batch + signal input index; + // Size of batch + signal input batchSize; + // Number of users that signup + signal input numSignUps; // Ballots and their corresponding path elements for verification in the tree. - signal input ballots[batchSize][BALLOT_LENGTH]; + signal input ballots[ballotsBatchSize][BALLOT_LENGTH]; signal input ballotPathElements[k][BALLOT_TREE_ARITY - 1]; - signal input votes[batchSize][numVoteOptions]; - + signal input votes[ballotsBatchSize][numVoteOptions]; // Current results for each vote option. signal input currentResults[numVoteOptions]; // Salt for the root of the current results. signal input currentResultsRootSalt; - // Total voice credits spent so far. signal input currentSpentVoiceCreditSubtotal; // Salt for the total spent voice credits. signal input currentSpentVoiceCreditSubtotalSalt; - // Salt for the root of the new results. signal input newResultsRootSalt; // Salt for the new total spent voice credits root. signal input newSpentVoiceCreditSubtotalSalt; - // The number of total registrations, used to validate the batch index. - signal numSignUps; - // Index of the first ballot in this batch. - signal batchStartIndex; // Verify sbCommitment. var computedSbCommitment = PoseidonHasher(3)([stateRoot, ballotRoot, sbSalt]); computedSbCommitment === sbCommitment; - // Verify inputHash. - var ( - computedNumSignUps, - computedBatchNum, - computedHash - ) = TallyVotesInputHasher()( - sbCommitment, - currentTallyCommitment, - newTallyCommitment, - packedVals - ); - - inputHash === computedHash; - numSignUps <== computedNumSignUps; - batchStartIndex <== computedBatchNum * batchSize; + ballotsBatchSize === batchSize; - // Validates that the batchStartIndex is within the valid range of sign-ups. - var numSignUpsValid = LessEqThan(50)([batchStartIndex, numSignUps]); + // Validates that the index is within the valid range of sign-ups. + var numSignUpsValid = LessEqThan(50)([index * ballotsBatchSize, numSignUps]); numSignUpsValid === 1; // Hashes each ballot for subroot generation, and checks the existence of the leaf in the Merkle tree. - var computedBallotHashers[batchSize]; + var computedBallotHashers[ballotsBatchSize]; - for (var i = 0; i < batchSize; i++) { + for (var i = 0; i < ballotsBatchSize; i++) { computedBallotHashers[i] = PoseidonHasher(2)([ballots[i][BALLOT_NONCE_IDX], ballots[i][BALLOT_VO_ROOT_IDX]]); } var computedBallotSubroot = CheckRoot(intStateTreeDepth)(computedBallotHashers); - var computedBallotPathIndices[k] = MerkleGeneratePathIndices(k)(computedBatchNum); + var computedBallotPathIndices[k] = MerkleGeneratePathIndices(k)(index); // Verifies each ballot's existence within the ballot tree. LeafExists(k)( @@ -136,38 +108,38 @@ template TallyVotesNonQv( ); // Processes vote options, verifying each against its declared root. - var computedVoteTree[batchSize]; - for (var i = 0; i < batchSize; i++) { + var computedVoteTree[ballotsBatchSize]; + for (var i = 0; i < ballotsBatchSize; i++) { computedVoteTree[i] = QuinCheckRoot(voteOptionTreeDepth)(votes[i]); computedVoteTree[i] === ballots[i][BALLOT_VO_ROOT_IDX]; } // Calculates new results and spent voice credits based on the current and incoming votes. - var computedIsFirstBatch = IsZero()(batchStartIndex); + var computedIsFirstBatch = IsZero()(index * ballotsBatchSize); var computedIsZero = IsZero()(computedIsFirstBatch); // Tally the new results. var computedCalculateTotalResult[numVoteOptions]; for (var i = 0; i < numVoteOptions; i++) { - var computedNumsRC[batchSize + 1]; - computedNumsRC[batchSize] = currentResults[i] * computedIsZero; - for (var j = 0; j < batchSize; j++) { + var computedNumsRC[ballotsBatchSize + 1]; + computedNumsRC[ballotsBatchSize] = currentResults[i] * computedIsZero; + for (var j = 0; j < ballotsBatchSize; j++) { computedNumsRC[j] = votes[j][i]; } - computedCalculateTotalResult[i] = CalculateTotal(batchSize + 1)(computedNumsRC); + computedCalculateTotalResult[i] = CalculateTotal(ballotsBatchSize + 1)(computedNumsRC); } // Tally the new spent voice credit total. - var computedNumsSVC[batchSize * numVoteOptions + 1]; - computedNumsSVC[batchSize * numVoteOptions] = currentSpentVoiceCreditSubtotal * computedIsZero; - for (var i = 0; i < batchSize; i++) { + var computedNumsSVC[ballotsBatchSize * numVoteOptions + 1]; + computedNumsSVC[ballotsBatchSize * numVoteOptions] = currentSpentVoiceCreditSubtotal * computedIsZero; + for (var i = 0; i < ballotsBatchSize; i++) { for (var j = 0; j < numVoteOptions; j++) { computedNumsSVC[i * numVoteOptions + j] = votes[i][j]; } } - var computedNewSpentVoiceCreditSubtotal = CalculateTotal(batchSize * numVoteOptions + 1)(computedNumsSVC); + var computedNewSpentVoiceCreditSubtotal = CalculateTotal(ballotsBatchSize * numVoteOptions + 1)(computedNumsSVC); // Verifies the updated results and spent credits, ensuring consistency and correctness of tally updates. ResultCommitmentVerifierNonQv(voteOptionTreeDepth)( diff --git a/circuits/circom/core/qv/processMessages.circom b/circuits/circom/core/qv/processMessages.circom index 27df92cfb..8b7020558 100644 --- a/circuits/circom/core/qv/processMessages.circom +++ b/circuits/circom/core/qv/processMessages.circom @@ -8,7 +8,6 @@ include "./safe-comparators.circom"; include "../../utils/hashers.circom"; include "../../utils/messageToCommand.circom"; include "../../utils/privToPubKey.circom"; -include "../../utils/processMessagesInputHasher.circom"; include "../../utils/qv/stateLeafAndBallotTransformer.circom"; include "../../trees/incrementalQuinaryTree.circom"; include "../../trees/incrementalMerkleTree.circom"; @@ -47,18 +46,13 @@ template ProcessMessages( var STATE_LEAF_VOICE_CREDIT_BALANCE_IDX = 2; var STATE_LEAF_TIMESTAMP_IDX = 3; var msgTreeZeroValue = 8370432830353022751713833565135785980866757267633941821328460903436894336785; - - // nb. The usage of SHA-256 hash is necessary to save some gas costs at verification time - // at the cost of more constraints for the prover. - // Basically, some values from the contract are passed as private inputs and the hash as a public input. - - // The SHA-256 hash of values provided by the contract. - signal input inputHash; - signal input packedVals; - // Number of users that have completed the sign up. - signal numSignUps; + + // Number of users that signup + signal input numSignUps; + // Message batch size + signal input messageBatchSize; // Number of options for this poll. - signal maxVoteOptions; + signal input maxVoteOptions; // Time when the poll ends. signal input pollEndTimestamp; // The existing message tree root. @@ -79,6 +73,12 @@ template ProcessMessages( // @note it is a public input to ensure fair processing from // the coordinator (no censoring) signal input actualStateTreeDepth; + // The last batch index + signal input batchEndIndex; + // The batch index of current message batch + signal input index; + // Coordinator public key hash + signal input coordPubKeyHash; // The state leaves upon which messages are applied. // transform(currentStateLeaf[4], message5) => newStateLeaf4 @@ -110,14 +110,6 @@ template ProcessMessages( // Therefore, the index of the first message to process does not match the index of the // first message (e.g., [msg1, msg2, msg3] => first message to process has index 3). - // The index of the first message leaf in the batch, inclusive. - signal batchStartIndex; - - // The index of the last message leaf in the batch to process, exclusive. - // This value may be less than batchStartIndex + batchSize if this batch is - // the last batch and the total number of messages is not a multiple of the batch size. - signal batchEndIndex; - // The history of state and ballot roots and temporary intermediate // signals (for processing purposes). signal stateRoots[batchSize + 1]; @@ -127,31 +119,6 @@ template ProcessMessages( var computedCurrentSbCommitment = PoseidonHasher(3)([currentStateRoot, currentBallotRoot, currentSbSalt]); computedCurrentSbCommitment === currentSbCommitment; - // Verify public inputs and assign unpacked values. - var ( - computedMaxVoteOptions, - computedNumSignUps, - computedBatchStartIndex, - computedBatchEndIndex, - computedHash - ) = ProcessMessagesInputHasher()( - packedVals, - coordPubKey, - msgRoot, - currentSbCommitment, - newSbCommitment, - pollEndTimestamp, - actualStateTreeDepth - ); - - // The unpacked values from packedVals. - computedMaxVoteOptions ==> maxVoteOptions; - computedNumSignUps ==> numSignUps; - computedBatchStartIndex ==> batchStartIndex; - computedBatchEndIndex ==> batchEndIndex; - // Matching constraints. - computedHash === inputHash; - // 0. Ensure that the maximum vote options signal is valid and if // the maximum users signal is valid. var maxVoValid = LessEqThan(32)([maxVoteOptions, MESSAGE_TREE_ARITY ** voteOptionTreeDepth]); @@ -162,13 +129,15 @@ template ProcessMessages( var numSignUpsValid = LessEqThan(32)([numSignUps, STATE_TREE_ARITY ** stateTreeDepth]); numSignUpsValid === 1; + messageBatchSize === batchSize; + // Hash each Message to check their existence in the Message tree. var computedMessageHashers[batchSize]; for (var i = 0; i < batchSize; i++) { computedMessageHashers[i] = MessageHasher()(msgs[i], encPubKeys[i]); } - // If batchEndIndex - batchStartIndex < batchSize, the remaining + // If endIndex - startIndex < batchSize, the remaining // message hashes should be the zero value. // e.g. [m, z, z, z, z] if there is only 1 real message in the batch // This makes possible to have a batch of messages which is only partially full. @@ -177,7 +146,7 @@ template ProcessMessages( var computedPathIndex[msgTreeDepth - msgBatchDepth]; for (var i = 0; i < batchSize; i++) { - var batchStartIndexValid = SafeLessThan(32)([batchStartIndex + i, batchEndIndex]); + var batchStartIndexValid = SafeLessThan(32)([index + i, batchEndIndex]); computedLeaves[i] = Mux1()([msgTreeZeroValue, computedMessageHashers[i]], batchStartIndexValid); } @@ -190,8 +159,8 @@ template ProcessMessages( // Computing the path_index values. Since msgBatchLeavesExists tests // the existence of a subroot, the length of the proof correspond to the last // n elements of a proof from the root to a leaf, where n = msgTreeDepth - msgBatchDepth. - // e.g. if batchStartIndex = 25, msgTreeDepth = 4, msgBatchDepth = 2, then path_index = [1, 0]. - var computedMsgBatchPathIndices[msgTreeDepth] = QuinGeneratePathIndices(msgTreeDepth)(batchStartIndex); + // e.g. if startIndex = 25, msgTreeDepth = 4, msgBatchDepth = 2, then path_index = [1, 0]. + var computedMsgBatchPathIndices[msgTreeDepth] = QuinGeneratePathIndices(msgTreeDepth)(index); for (var i = msgBatchDepth; i < msgTreeDepth; i++) { computedPathIndex[i - msgBatchDepth] = computedMsgBatchPathIndices[i]; @@ -341,11 +310,12 @@ template ProcessOne(stateTreeDepth, voteOptionTreeDepth) { // Timestamp. var STATE_LEAF_TIMESTAMP_IDX = 3; - // Inputs representing the message and the current state. + // Number of users that have completed the sign up. signal input numSignUps; + // Number of options for this poll. signal input maxVoteOptions; + // Time when the poll ends. signal input pollEndTimestamp; - // The current value of the state tree root. signal input currentStateRoot; // The current value of the ballot tree root. diff --git a/circuits/circom/core/qv/tallyVotes.circom b/circuits/circom/core/qv/tallyVotes.circom index 24d663c50..462591443 100644 --- a/circuits/circom/core/qv/tallyVotes.circom +++ b/circuits/circom/core/qv/tallyVotes.circom @@ -9,7 +9,6 @@ include "../../trees/incrementalMerkleTree.circom"; include "../../trees/incrementalQuinaryTree.circom"; include "../../utils/calculateTotal.circom"; include "../../utils/hashers.circom"; -include "../../utils/tallyVotesInputHasher.circom"; /** * Processes batches of votes and verifies their validity in a Merkle tree structure. @@ -32,7 +31,7 @@ template TallyVotes( var BALLOT_TREE_ARITY = 2; // The number of ballots processed at once, determined by the depth of the intermediate state tree. - var batchSize = BALLOT_TREE_ARITY ** intStateTreeDepth; + var ballotsBatchSize = BALLOT_TREE_ARITY ** intStateTreeDepth; // Number of voting options available, determined by the depth of the vote option tree. var numVoteOptions = TREE_ARITY ** voteOptionTreeDepth; @@ -51,88 +50,58 @@ template TallyVotes( signal input ballotRoot; // Salt used in commitment to secure the ballot data. signal input sbSalt; - - // Inputs combined into a hash to verify the integrity and authenticity of the data. - signal input packedVals; // Commitment to the state and ballots. signal input sbCommitment; // Commitment to the current tally before this batch. signal input currentTallyCommitment; // Commitment to the new tally after processing this batch. signal input newTallyCommitment; - - // A tally commitment is the hash of the following salted values: - // - the vote results, - // - the number of voice credits spent per vote option, - // - the total number of spent voice credits. - - // Hash of all inputs to ensure they are unchanged and authentic. - signal input inputHash; - + // Start index of given batch + signal input index; + // Size of batch + signal input batchSize; + // Number of users that signup + signal input numSignUps; // Ballots and their corresponding path elements for verification in the tree. - signal input ballots[batchSize][BALLOT_LENGTH]; + signal input ballots[ballotsBatchSize][BALLOT_LENGTH]; signal input ballotPathElements[k][BALLOT_TREE_ARITY - 1]; - signal input votes[batchSize][numVoteOptions]; - + signal input votes[ballotsBatchSize][numVoteOptions]; // Current results for each vote option. signal input currentResults[numVoteOptions]; // Salt for the root of the current results. signal input currentResultsRootSalt; - // Total voice credits spent so far. signal input currentSpentVoiceCreditSubtotal; // Salt for the total spent voice credits. signal input currentSpentVoiceCreditSubtotalSalt; - // Spent voice credits per vote option. signal input currentPerVOSpentVoiceCredits[numVoteOptions]; // Salt for the root of spent credits per option. signal input currentPerVOSpentVoiceCreditsRootSalt; - // Salt for the root of the new results. signal input newResultsRootSalt; // Salt for the new spent credits per vote option root. signal input newPerVOSpentVoiceCreditsRootSalt; // Salt for the new total spent voice credits root. signal input newSpentVoiceCreditSubtotalSalt; - // The number of total registrations, used to validate the batch index. - signal numSignUps; - // Index of the first ballot in this batch. - signal batchStartIndex; // Verify sbCommitment. var computedSbCommitment = PoseidonHasher(3)([stateRoot, ballotRoot, sbSalt]); computedSbCommitment === sbCommitment; - // Verify inputHash. - var ( - computedNumSignUps, - computedBatchNum, - computedHash - ) = TallyVotesInputHasher()( - sbCommitment, - currentTallyCommitment, - newTallyCommitment, - packedVals - ); - - inputHash === computedHash; - numSignUps <== computedNumSignUps; - batchStartIndex <== computedBatchNum * batchSize; - - // Validates that the batchStartIndex is within the valid range of sign-ups. - var numSignUpsValid = LessEqThan(50)([batchStartIndex, numSignUps]); + // Validates that the index is within the valid range of sign-ups. + var numSignUpsValid = LessEqThan(50)([index * ballotsBatchSize, numSignUps]); numSignUpsValid === 1; // Hashes each ballot for subroot generation, and checks the existence of the leaf in the Merkle tree. - var computedBallotHashers[batchSize]; + var computedBallotHashers[ballotsBatchSize]; - for (var i = 0; i < batchSize; i++) { + for (var i = 0; i < ballotsBatchSize; i++) { computedBallotHashers[i] = PoseidonHasher(2)([ballots[i][BALLOT_NONCE_IDX], ballots[i][BALLOT_VO_ROOT_IDX]]); } var computedBallotSubroot = CheckRoot(intStateTreeDepth)(computedBallotHashers); - var computedBallotPathIndices[k] = MerkleGeneratePathIndices(k)(computedBatchNum); + var computedBallotPathIndices[k] = MerkleGeneratePathIndices(k)(index); // Verifies each ballot's existence within the ballot tree. LeafExists(k)( @@ -143,50 +112,50 @@ template TallyVotes( ); // Processes vote options, verifying each against its declared root. - var computedVoteTree[batchSize]; - for (var i = 0; i < batchSize; i++) { + var computedVoteTree[ballotsBatchSize]; + for (var i = 0; i < ballotsBatchSize; i++) { computedVoteTree[i] = QuinCheckRoot(voteOptionTreeDepth)(votes[i]); computedVoteTree[i] === ballots[i][BALLOT_VO_ROOT_IDX]; } // Calculates new results and spent voice credits based on the current and incoming votes. - var computedIsFirstBatch = IsZero()(batchStartIndex); + var computedIsFirstBatch = IsZero()(index * ballotsBatchSize); var computedIsZero = IsZero()(computedIsFirstBatch); // Tally the new results. var computedCalculateTotalResult[numVoteOptions]; for (var i = 0; i < numVoteOptions; i++) { - var numsRC[batchSize + 1]; - numsRC[batchSize] = currentResults[i] * computedIsZero; - for (var j = 0; j < batchSize; j++) { + var numsRC[ballotsBatchSize + 1]; + numsRC[ballotsBatchSize] = currentResults[i] * computedIsZero; + for (var j = 0; j < ballotsBatchSize; j++) { numsRC[j] = votes[j][i]; } - computedCalculateTotalResult[i] = CalculateTotal(batchSize + 1)(numsRC); + computedCalculateTotalResult[i] = CalculateTotal(ballotsBatchSize + 1)(numsRC); } // Tally the new spent voice credit total. - var numsSVC[batchSize * numVoteOptions + 1]; - numsSVC[batchSize * numVoteOptions] = currentSpentVoiceCreditSubtotal * computedIsZero; - for (var i = 0; i < batchSize; i++) { + var numsSVC[ballotsBatchSize * numVoteOptions + 1]; + numsSVC[ballotsBatchSize * numVoteOptions] = currentSpentVoiceCreditSubtotal * computedIsZero; + for (var i = 0; i < ballotsBatchSize; i++) { for (var j = 0; j < numVoteOptions; j++) { numsSVC[i * numVoteOptions + j] = votes[i][j] * votes[i][j]; } } - var computedNewSpentVoiceCreditSubtotal = CalculateTotal(batchSize * numVoteOptions + 1)(numsSVC); + var computedNewSpentVoiceCreditSubtotal = CalculateTotal(ballotsBatchSize * numVoteOptions + 1)(numsSVC); // Tally the spent voice credits per vote option. var computedNewPerVOSpentVoiceCredits[numVoteOptions]; for (var i = 0; i < numVoteOptions; i++) { - var computedNumsSVC[batchSize + 1]; - computedNumsSVC[batchSize] = currentPerVOSpentVoiceCredits[i] * computedIsZero; - for (var j = 0; j < batchSize; j++) { + var computedNumsSVC[ballotsBatchSize + 1]; + computedNumsSVC[ballotsBatchSize] = currentPerVOSpentVoiceCredits[i] * computedIsZero; + for (var j = 0; j < ballotsBatchSize; j++) { computedNumsSVC[j] = votes[j][i] * votes[j][i]; } - computedNewPerVOSpentVoiceCredits[i] = CalculateTotal(batchSize + 1)(computedNumsSVC); + computedNewPerVOSpentVoiceCredits[i] = CalculateTotal(ballotsBatchSize + 1)(computedNumsSVC); } // Verifies the updated results and spent credits, ensuring consistency and correctness of tally updates. diff --git a/circuits/circom/test/ProcessMessages_10-2-1-2_test.circom b/circuits/circom/test/ProcessMessages_10-2-1-2_test.circom index 843111f37..9f2b53c01 100644 --- a/circuits/circom/test/ProcessMessages_10-2-1-2_test.circom +++ b/circuits/circom/test/ProcessMessages_10-2-1-2_test.circom @@ -3,4 +3,4 @@ pragma circom 2.0.0; include ".././core/qv/processMessages.circom"; -component main {public[inputHash]} = ProcessMessages(10, 2, 1, 2); +component main {public[maxVoteOptions, numSignUps, index, messageBatchSize, batchEndIndex, coordPubKeyHash, msgRoot, currentSbCommitment, newSbCommitment, pollEndTimestamp, actualStateTreeDepth]} = ProcessMessages(10, 2, 1, 2); diff --git a/circuits/circom/test/TallyVotes_10-1-2_test.circom b/circuits/circom/test/TallyVotes_10-1-2_test.circom index 3c7028c05..0e806f55d 100644 --- a/circuits/circom/test/TallyVotes_10-1-2_test.circom +++ b/circuits/circom/test/TallyVotes_10-1-2_test.circom @@ -3,4 +3,4 @@ pragma circom 2.0.0; include ".././core/qv/tallyVotes.circom"; -component main {public[inputHash]} = TallyVotes(10, 1, 2); +component main {public[index, batchSize, numSignUps, sbCommitment, currentTallyCommitment, newTallyCommitment]} = TallyVotes(10, 1, 2); diff --git a/circuits/circom/utils/hashers.circom b/circuits/circom/utils/hashers.circom index 4be653196..35b65054b 100644 --- a/circuits/circom/utils/hashers.circom +++ b/circuits/circom/utils/hashers.circom @@ -1,48 +1,8 @@ pragma circom 2.0.0; -// circomlib import -include "./sha256/sha256.circom"; -include "./bitify.circom"; // zk-kit imports include "./poseidon-cipher.circom"; -/** - * Computes the SHA-256 hash of an array of input signals. Each input is first - * converted to a 256-bit representation, then these are concatenated and passed - * to the SHA-256 hash function. The output is the 256 hash value of the inputs bits - * converted back to numbers. - */ -template Sha256Hasher(length) { - var SHA_LENGTH = 256; - var inBits = SHA_LENGTH * length; - - signal input in[length]; - signal output hash; - - // Array to store all bits of inputs for SHA-256 input. - var computedBits[inBits]; - - // Convert each input into bits and store them in the `bits` array. - for (var i = 0; i < length; i++) { - var computedBitsInput[SHA_LENGTH] = Num2Bits(SHA_LENGTH)(in[i]); - for (var j = 0; j < SHA_LENGTH; j++) { - computedBits[(i * SHA_LENGTH) + (SHA_LENGTH - 1) - j] = computedBitsInput[j]; - } - } - - // SHA-256 hash computation. - var computedSha256Bits[SHA_LENGTH] = Sha256(inBits)(computedBits); - - // Convert SHA-256 output back to number. - var computedBitsToNumInput[SHA_LENGTH]; - for (var i = 0; i < SHA_LENGTH; i++) { - computedBitsToNumInput[i] = computedSha256Bits[(SHA_LENGTH - 1) - i]; - } - var computedSha256Number = Bits2Num(256)(computedBitsToNumInput); - - hash <== computedSha256Number; -} - /** * Computes the Poseidon hash for an array of n inputs, including a default initial state * of zero not counted in n. First, extends the inputs by prepending a zero, creating an array [0, inputs]. diff --git a/circuits/circom/utils/processMessagesInputHasher.circom b/circuits/circom/utils/processMessagesInputHasher.circom deleted file mode 100644 index 33b7b97e9..000000000 --- a/circuits/circom/utils/processMessagesInputHasher.circom +++ /dev/null @@ -1,61 +0,0 @@ -pragma circom 2.0.0; - -// zk-kit imports -include "./unpack-element.circom"; -// local imports -include "hashers.circom"; - -/** - * Processes various inputs, including packed values and public keys, to produce a SHA256 hash. - * It unpacks consolidated inputs like vote options and sign-up counts, validates them, - * hashes the coordinator's public key, and finally hashes all inputs together. - */ -template ProcessMessagesInputHasher() { - // Combine the following into 1 input element: - // - maxVoteOptions (50 bits) - // - numSignUps (50 bits) - // - batchStartIndex (50 bits) - // - batchEndIndex (50 bits) - // Hash coordPubKey: - // - coordPubKeyHash - // Other inputs that can't be compressed or packed: - // - msgRoot, currentSbCommitment, newSbCommitment - var UNPACK_ELEM_LENGTH = 4; - - signal input packedVals; - signal input coordPubKey[2]; - signal input msgRoot; - // The current state and ballot root commitment (hash(stateRoot, ballotRoot, salt)). - signal input currentSbCommitment; - signal input newSbCommitment; - signal input pollEndTimestamp; - signal input actualStateTreeDepth; - - signal output maxVoteOptions; - signal output numSignUps; - signal output batchStartIndex; - signal output batchEndIndex; - signal output hash; - - // 1. Unpack packedVals and ensure that it is valid. - var computedUnpackElement[UNPACK_ELEM_LENGTH] = UnpackElement(UNPACK_ELEM_LENGTH)(packedVals); - - maxVoteOptions <== computedUnpackElement[3]; - numSignUps <== computedUnpackElement[2]; - batchStartIndex <== computedUnpackElement[1]; - batchEndIndex <== computedUnpackElement[0]; - - // 2. Hash coordPubKey. - var computedPubKey = PoseidonHasher(2)(coordPubKey); - - // 3. Hash the 7 inputs with SHA256. - hash <== Sha256Hasher(7)([ - packedVals, - computedPubKey, - msgRoot, - currentSbCommitment, - newSbCommitment, - pollEndTimestamp, - actualStateTreeDepth - ]); -} \ No newline at end of file diff --git a/circuits/circom/utils/tallyVotesInputHasher.circom b/circuits/circom/utils/tallyVotesInputHasher.circom deleted file mode 100644 index 30f22b482..000000000 --- a/circuits/circom/utils/tallyVotesInputHasher.circom +++ /dev/null @@ -1,36 +0,0 @@ -pragma circom 2.0.0; - -// zk-kit imports -include "./safe-comparators.circom"; -// local imports -include "./hashers.circom"; - -/** - * Generates a sha256 hash of the provided tally inputs. - */ -template TallyVotesInputHasher() { - // Commitment to the state and ballots. - signal input sbCommitment; - // Commitment to the current tally before this batch. - signal input currentTallyCommitment; - // Commitment to the new tally after processing this batch. - signal input newTallyCommitment; - // Packed values. - signal input packedVals; - - signal output numSignUps; - signal output batchNum; - signal output hash; - - // Unpack the elements. - var computedUnpackedElement[2] = UnpackElement(2)(packedVals); - batchNum <== computedUnpackedElement[1]; - numSignUps <== computedUnpackedElement[0]; - - hash <== Sha256Hasher(4)([ - packedVals, - sbCommitment, - currentTallyCommitment, - newTallyCommitment - ]); -} \ No newline at end of file diff --git a/circuits/ts/__tests__/CeremonyParams.test.ts b/circuits/ts/__tests__/CeremonyParams.test.ts index ebc48b08f..d9cd1f8e3 100644 --- a/circuits/ts/__tests__/CeremonyParams.test.ts +++ b/circuits/ts/__tests__/CeremonyParams.test.ts @@ -1,12 +1,12 @@ import { expect } from "chai"; import { type WitnessTester } from "circomkit"; -import { MaciState, Poll, packProcessMessageSmallVals, STATE_TREE_ARITY, MESSAGE_TREE_ARITY } from "maci-core"; +import { MaciState, Poll, STATE_TREE_ARITY, MESSAGE_TREE_ARITY } from "maci-core"; import { hash5, IncrementalQuinTree } from "maci-crypto"; import { PrivKey, Keypair, PCommand, Message, Ballot } from "maci-domainobjs"; import { IProcessMessagesInputs, ITallyVotesInputs } from "../types"; -import { generateRandomIndex, getSignal, circomkitInstance } from "./utils/utils"; +import { generateRandomIndex, circomkitInstance } from "./utils/utils"; describe("Ceremony param tests", () => { const params = { @@ -47,8 +47,10 @@ describe("Ceremony param tests", () => { let circuit: WitnessTester< [ - "inputHash", - "packedVals", + "numSignUps", + "batchEndIndex", + "index", + "maxVoteOptions", "pollEndTimestamp", "msgRoot", "msgs", @@ -71,22 +73,12 @@ describe("Ceremony param tests", () => { ] >; - let hasherCircuit: WitnessTester< - ["packedVals", "coordPubKey", "msgRoot", "currentSbCommitment", "newSbCommitment", "pollEndTimestamp"], - ["maxVoteOptions", "numSignUps", "batchStartIndex", "batchEndIndex", "hash"] - >; - before(async () => { circuit = await circomkitInstance.WitnessTester("processMessages", { file: "./core/qv/processMessages", template: "ProcessMessages", params: [6, 9, 2, 3], }); - - hasherCircuit = await circomkitInstance.WitnessTester("processMessageInputHasher", { - file: "./utils/processMessagesInputHasher", - template: "ProcessMessagesInputHasher", - }); }); describe("1 user, 2 messages", () => { @@ -185,29 +177,6 @@ describe("Ceremony param tests", () => { expect(newStateRoot?.toString()).not.to.be.eq(currentStateRoot?.toString()); expect(newBallotRoot?.toString()).not.to.be.eq(currentBallotRoot.toString()); - - const packedVals = packProcessMessageSmallVals( - BigInt(maxValues.maxVoteOptions), - BigInt(poll.maciStateRef.numSignUps), - 0, - 2, - ); - - // Test the ProcessMessagesInputHasher circuit - const hasherCircuitInputs = { - packedVals, - coordPubKey: inputs.coordPubKey, - msgRoot: inputs.msgRoot, - currentSbCommitment: inputs.currentSbCommitment, - newSbCommitment: inputs.newSbCommitment, - pollEndTimestamp: inputs.pollEndTimestamp, - actualStateTreeDepth: inputs.actualStateTreeDepth, - }; - - const hasherWitness = await hasherCircuit.calculateWitness(hasherCircuitInputs); - await hasherCircuit.expectConstraintPass(hasherWitness); - const hash = await getSignal(hasherCircuit, hasherWitness, "hash"); - expect(hash.toString()).to.be.eq(inputs.inputHash.toString()); }); }); @@ -219,11 +188,12 @@ describe("Ceremony param tests", () => { "stateRoot", "ballotRoot", "sbSalt", - "packedVals", "sbCommitment", + "index", + "batchSize", + "numSignUps", "currentTallyCommitment", "newTallyCommitment", - "inputHash", "ballots", "ballotPathElements", "votes", diff --git a/circuits/ts/__tests__/Hasher.test.ts b/circuits/ts/__tests__/Hasher.test.ts index ff0d2f6f1..03d49e5e8 100644 --- a/circuits/ts/__tests__/Hasher.test.ts +++ b/circuits/ts/__tests__/Hasher.test.ts @@ -2,7 +2,7 @@ import { r } from "@zk-kit/baby-jubjub"; import { expect } from "chai"; import { type WitnessTester } from "circomkit"; import fc from "fast-check"; -import { genRandomSalt, sha256Hash, hash5, hash4, hash3, hash2 } from "maci-crypto"; +import { genRandomSalt, hash5, hash4, hash3, hash2 } from "maci-crypto"; import { PCommand, Keypair } from "maci-domainobjs"; import { getSignal, circomkitInstance } from "./utils/utils"; @@ -10,168 +10,6 @@ import { getSignal, circomkitInstance } from "./utils/utils"; describe("Poseidon hash circuits", function test() { this.timeout(900000); - describe("SHA256", () => { - describe("Sha256Hasher", () => { - let circuit: WitnessTester<["in"], ["hash"]>; - - it("correctly hashes 2 random values in order", async () => { - const n = 2; - - circuit = await circomkitInstance.WitnessTester("sha256hasher", { - file: "./utils/hashers", - template: "Sha256Hasher", - params: [n], - }); - - await fc.assert( - fc.asyncProperty( - fc.array(fc.bigInt({ min: 0n, max: r - 1n }), { minLength: n, maxLength: n }), - async (preImages: bigint[]) => { - const witness = await circuit.calculateWitness({ - in: preImages, - }); - await circuit.expectConstraintPass(witness); - const output = await getSignal(circuit, witness, "hash"); - const outputJS = sha256Hash(preImages); - - return output === outputJS; - }, - ), - ); - }); - - it("correctly hashes 3 random values", async () => { - const n = 3; - - circuit = await circomkitInstance.WitnessTester("sha256hasher", { - file: "./utils/hashers", - template: "Sha256Hasher", - params: [n], - }); - - await fc.assert( - fc.asyncProperty( - fc.array(fc.bigInt({ min: 0n, max: r - 1n }), { minLength: n, maxLength: n }), - async (preImages: bigint[]) => { - const witness = await circuit.calculateWitness({ - in: preImages, - }); - await circuit.expectConstraintPass(witness); - const output = await getSignal(circuit, witness, "hash"); - const outputJS = sha256Hash(preImages); - - return output === outputJS; - }, - ), - ); - }); - - it("correctly hashes 4 random values", async () => { - const n = 4; - - circuit = await circomkitInstance.WitnessTester("sha256hasher", { - file: "./utils/hashers", - template: "Sha256Hasher", - params: [n], - }); - - await fc.assert( - fc.asyncProperty( - fc.array(fc.bigInt({ min: 0n, max: r - 1n }), { minLength: n, maxLength: n }), - async (preImages: bigint[]) => { - const witness = await circuit.calculateWitness({ - in: preImages, - }); - await circuit.expectConstraintPass(witness); - const output = await getSignal(circuit, witness, "hash"); - const outputJS = sha256Hash(preImages); - - return output === outputJS; - }, - ), - ); - }); - - it("correctly hashes 5 random values", async () => { - const n = 5; - - circuit = await circomkitInstance.WitnessTester("sha256hasher", { - file: "./utils/hashers", - template: "Sha256Hasher", - params: [n], - }); - - await fc.assert( - fc.asyncProperty( - fc.array(fc.bigInt({ min: 0n, max: r - 1n }), { minLength: n, maxLength: n }), - async (preImages: bigint[]) => { - const witness = await circuit.calculateWitness({ - in: preImages, - }); - await circuit.expectConstraintPass(witness); - const output = await getSignal(circuit, witness, "hash"); - const outputJS = sha256Hash(preImages); - - return output === outputJS; - }, - ), - ); - }); - - it("correctly hashes 6 random values", async () => { - const n = 6; - - circuit = await circomkitInstance.WitnessTester("sha256hasher", { - file: "./utils/hashers", - template: "Sha256Hasher", - params: [n], - }); - - await fc.assert( - fc.asyncProperty( - fc.array(fc.bigInt({ min: 0n, max: r - 1n }), { minLength: n, maxLength: n }), - async (preImages: bigint[]) => { - const witness = await circuit.calculateWitness({ - in: preImages, - }); - await circuit.expectConstraintPass(witness); - const output = await getSignal(circuit, witness, "hash"); - const outputJS = sha256Hash(preImages); - - return output === outputJS; - }, - ), - ); - }); - - it("correctly hashes 10 random values", async () => { - const n = 10; - - circuit = await circomkitInstance.WitnessTester("sha256hasher", { - file: "./utils/hashers", - template: "Sha256Hasher", - params: [n], - }); - - await fc.assert( - fc.asyncProperty( - fc.array(fc.bigInt({ min: 0n, max: r - 1n }), { minLength: n, maxLength: n }), - async (preImages: bigint[]) => { - const witness = await circuit.calculateWitness({ - in: preImages, - }); - await circuit.expectConstraintPass(witness); - const output = await getSignal(circuit, witness, "hash"); - const outputJS = sha256Hash(preImages); - - return output === outputJS; - }, - ), - ); - }); - }); - }); - describe("Poseidon", () => { describe("PoseidonHasher", () => { let circuit: WitnessTester<["inputs"], ["out"]>; diff --git a/circuits/ts/__tests__/ProcessMessages.test.ts b/circuits/ts/__tests__/ProcessMessages.test.ts index 6562ec650..3d039a7f9 100644 --- a/circuits/ts/__tests__/ProcessMessages.test.ts +++ b/circuits/ts/__tests__/ProcessMessages.test.ts @@ -1,6 +1,6 @@ import { expect } from "chai"; import { type WitnessTester } from "circomkit"; -import { MaciState, Poll, packProcessMessageSmallVals, STATE_TREE_ARITY } from "maci-core"; +import { MaciState, Poll, STATE_TREE_ARITY } from "maci-core"; import { IncrementalQuinTree, hash2 } from "maci-crypto"; import { PrivKey, Keypair, PCommand, Message, Ballot, PubKey } from "maci-domainobjs"; @@ -14,7 +14,7 @@ import { treeDepths, voiceCreditBalance, } from "./utils/constants"; -import { getSignal, circomkitInstance } from "./utils/utils"; +import { circomkitInstance } from "./utils/utils"; describe("ProcessMessage circuit", function test() { this.timeout(900000); @@ -22,8 +22,10 @@ describe("ProcessMessage circuit", function test() { const coordinatorKeypair = new Keypair(); type ProcessMessageCircuitInputs = [ - "inputHash", - "packedVals", + "numSignUps", + "batchEndIndex", + "index", + "maxVoteOptions", "pollEndTimestamp", "msgRoot", "msgs", @@ -49,11 +51,6 @@ describe("ProcessMessage circuit", function test() { let circuitNonQv: WitnessTester; - let hasherCircuit: WitnessTester< - ["packedVals", "coordPubKey", "msgRoot", "currentSbCommitment", "newSbCommitment", "pollEndTimestamp"], - ["maxVoteOptions", "numSignUps", "batchStartIndex", "batchEndIndex", "hash"] - >; - before(async () => { circuit = await circomkitInstance.WitnessTester("processMessages", { file: "./core/qv/processMessages", @@ -66,11 +63,6 @@ describe("ProcessMessage circuit", function test() { template: "ProcessMessagesNonQv", params: [10, 2, 1, 2], }); - - hasherCircuit = await circomkitInstance.WitnessTester("ProcessMessagesInputHasher", { - file: "./utils/processMessagesInputHasher", - template: "ProcessMessagesInputHasher", - }); }); describe("5 users, 1 messages", () => { @@ -246,29 +238,6 @@ describe("ProcessMessage circuit", function test() { expect(newStateRoot?.toString()).not.to.be.eq(currentStateRoot?.toString()); expect(newBallotRoot?.toString()).not.to.be.eq(currentBallotRoot.toString()); - - const packedVals = packProcessMessageSmallVals( - BigInt(maxValues.maxVoteOptions), - BigInt(poll.maciStateRef.numSignUps), - 0, - 2, - ); - - // Test the ProcessMessagesInputHasher circuit - const hasherCircuitInputs = { - packedVals, - coordPubKey: inputs.coordPubKey, - msgRoot: inputs.msgRoot, - currentSbCommitment: inputs.currentSbCommitment, - newSbCommitment: inputs.newSbCommitment, - pollEndTimestamp: inputs.pollEndTimestamp, - actualStateTreeDepth: inputs.actualStateTreeDepth, - }; - - const hasherWitness = await hasherCircuit.calculateWitness(hasherCircuitInputs); - await hasherCircuit.expectConstraintPass(hasherWitness); - const hash = await getSignal(hasherCircuit, hasherWitness, "hash"); - expect(hash.toString()).to.be.eq(inputs.inputHash.toString()); }); }); diff --git a/circuits/ts/__tests__/TallyVotes.test.ts b/circuits/ts/__tests__/TallyVotes.test.ts index e907652e0..d645721eb 100644 --- a/circuits/ts/__tests__/TallyVotes.test.ts +++ b/circuits/ts/__tests__/TallyVotes.test.ts @@ -23,11 +23,12 @@ describe("TallyVotes circuit", function test() { "stateRoot", "ballotRoot", "sbSalt", - "packedVals", + "index", + "batchSize", + "numSignUps", "sbCommitment", "currentTallyCommitment", "newTallyCommitment", - "inputHash", "ballots", "ballotPathElements", "votes", diff --git a/circuits/ts/types.ts b/circuits/ts/types.ts index 20ed2cd39..5ed5395d0 100644 --- a/circuits/ts/types.ts +++ b/circuits/ts/types.ts @@ -45,9 +45,11 @@ export interface IGenProofOptions { */ export interface IProcessMessagesInputs { actualStateTreeDepth: bigint; - inputHash: bigint; - packedVals: bigint; pollEndTimestamp: bigint; + numSignUps: bigint; + batchEndIndex: bigint; + index: bigint; + maxVoteOptions: bigint; msgRoot: bigint; msgs: bigint[]; msgSubrootPathElements: bigint[][]; @@ -75,11 +77,12 @@ export interface ITallyVotesInputs { stateRoot: bigint; ballotRoot: bigint; sbSalt: bigint; - packedVals: bigint; + index: bigint; + batchSize: bigint; + numSignUps: bigint; sbCommitment: bigint; currentTallyCommitment: bigint; newTallyCommitment: bigint; - inputHash: bigint; ballots: bigint[]; ballotPathElements: bigint[]; votes: bigint[][]; diff --git a/cli/ts/commands/proveOnChain.ts b/cli/ts/commands/proveOnChain.ts index a7e425674..52b8f6ecc 100644 --- a/cli/ts/commands/proveOnChain.ts +++ b/cli/ts/commands/proveOnChain.ts @@ -1,6 +1,11 @@ /* eslint-disable no-await-in-loop */ import { type BigNumberish } from "ethers"; -import { type IVerifyingKeyStruct, formatProofForVerifierContract } from "maci-contracts"; +import { + type IVerifyingKeyStruct, + formatProofForVerifierContract, + generateProcessMessagesPublicInputs, + generateTallyVotesPublicInputs, +} from "maci-contracts"; import { MACI__factory as MACIFactory, AccQueue__factory as AccQueueFactory, @@ -18,7 +23,6 @@ import fs from "fs"; import path from "path"; import { - asHex, banner, contractExists, error, @@ -243,36 +247,21 @@ export const proveOnChain = async ({ logError("coordPubKey mismatch."); } - const packedValsOnChain = BigInt( - await mpContract.genProcessMessagesPackedVals( - currentMessageBatchIndex, - numSignUps, - numMessages, - treeDepths.messageTreeSubDepth, - treeDepths.voteOptionTreeDepth, - ), - ).toString(); - - if (circuitInputs.packedVals !== packedValsOnChain) { - logError("packedVals mismatch."); - } - const formattedProof = formatProofForVerifierContract(proof); - const publicInputHashOnChain = BigInt( - await mpContract.genProcessMessagesPublicInputHash( - currentMessageBatchIndex, - messageRootOnChain.toString(), - numSignUps, - numMessages, - circuitInputs.currentSbCommitment as BigNumberish, - circuitInputs.newSbCommitment as BigNumberish, - treeDepths.messageTreeSubDepth, - treeDepths.voteOptionTreeDepth, - ), - ); + const publicInputsOnChain = await generateProcessMessagesPublicInputs({ + pollContract, + messageAqContract, + currentMessageBatchIndex, + currentSbCommitment: BigInt(circuitInputs.currentSbCommitment as BigNumberish), + newSbCommitment: BigInt(circuitInputs.newSbCommitment as BigNumberish), + }); - if (publicInputHashOnChain.toString() !== publicInputs[0].toString()) { + if ( + Object.values(publicInputsOnChain).every( + (value: BigNumberish, index) => value.toString() === publicInputs[index].toString(), + ) + ) { logError("Public input mismatch."); } @@ -289,7 +278,7 @@ export const proveOnChain = async ({ const isValidOnChain = await verifierContract.verify( formattedProof, vk.asContractParam() as IVerifyingKeyStruct, - publicInputHashOnChain.toString(), + Object.values(publicInputsOnChain), ); if (!isValidOnChain) { @@ -298,7 +287,7 @@ export const proveOnChain = async ({ try { // validate process messaging proof and store the new state and ballot root commitment - const tx = await mpContract.processMessages(asHex(circuitInputs.newSbCommitment as BigNumberish), formattedProof); + const tx = await mpContract.processMessages(publicInputsOnChain, formattedProof); const receipt = await tx.wait(); if (receipt?.status !== 1) { @@ -345,30 +334,24 @@ export const proveOnChain = async ({ logError("currentTallyCommitment mismatch."); } - const packedValsOnChain = BigInt( - await tallyContract.genTallyVotesPackedVals(numSignUps, batchStartIndex, tallyBatchSize), - ); - - if (circuitInputs.packedVals !== packedValsOnChain.toString()) { - logError("packedVals mismatch."); - } - const currentSbCommitmentOnChain = await mpContract.sbCommitment(); if (currentSbCommitmentOnChain.toString() !== circuitInputs.sbCommitment) { logError("currentSbCommitment mismatch."); } - const publicInputHashOnChain = await tallyContract.genTallyVotesPublicInputHash( - numSignUps, + const publicInputsOnChain = await generateTallyVotesPublicInputs({ + pollContract, + tallyContract, batchStartIndex, - tallyBatchSize, - circuitInputs.newTallyCommitment as BigNumberish, - ); + newTallyCommitment: BigInt(circuitInputs.newTallyCommitment as BigNumberish), + }); - if (publicInputHashOnChain.toString() !== publicInputs[0]) { + const publicInputsArray = Object.values(publicInputsOnChain); + + if (publicInputsArray.every((value: BigNumberish, index) => value.toString() === publicInputs[index].toString())) { logError( - `public input mismatch. tallyBatchNum=${i}, onchain=${publicInputHashOnChain.toString()}, offchain=${publicInputs[0].toString()}`, + `public input mismatch. tallyBatchNum=${i}, onchain=${publicInputsArray.toString()}, offchain=${publicInputs.toString()}`, ); } @@ -377,10 +360,7 @@ export const proveOnChain = async ({ try { // verify the proof on chain - const tx = await tallyContract.tallyVotes( - asHex(circuitInputs.newTallyCommitment as BigNumberish), - formattedProof, - ); + const tx = await tallyContract.tallyVotes(publicInputsOnChain, formattedProof); const receipt = await tx.wait(); if (receipt?.status !== 1) { diff --git a/contracts/contracts/MessageProcessor.sol b/contracts/contracts/MessageProcessor.sol index 1161139c8..3a4d463bc 100644 --- a/contracts/contracts/MessageProcessor.sol +++ b/contracts/contracts/MessageProcessor.sol @@ -28,6 +28,21 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes error CurrentMessageBatchIndexTooLarge(); error BatchEndIndexTooLarge(); + /// @notice Circuit public inputs + struct CircuitPublicInputs { + uint256 maxVoteOptions; + uint256 numSignUps; + uint256 index; + uint256 messageBatchSize; + uint256 batchEndIndex; + uint256 coordinatorPubKeyHash; + uint256 messageRoot; + uint256 currentSbCommitment; + uint256 newSbCommitment; + uint256 pollEndTimestamp; + uint256 actualStateTreeDepth; + } + // the number of children per node in the merkle trees uint256 internal constant TREE_ARITY = 5; @@ -70,10 +85,12 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes } /// @notice Update the Poll's currentSbCommitment if the proof is valid. - /// @param _newSbCommitment The new state root and ballot root commitment - /// after all messages are processed + /// @param _publicCircuitInputs The public circuit inputs /// @param _proof The zk-SNARK proof - function processMessages(uint256 _newSbCommitment, uint256[8] memory _proof) external onlyOwner { + function processMessages( + CircuitPublicInputs memory _publicCircuitInputs, + uint256[8] memory _proof + ) external onlyOwner { // ensure the voting period is over _votingPeriodOver(poll); @@ -88,7 +105,7 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes } // Retrieve stored vals - (, uint8 messageTreeSubDepth, uint8 messageTreeDepth, uint8 voteOptionTreeDepth) = poll.treeDepths(); + (, uint8 messageTreeSubDepth, uint8 messageTreeDepth, ) = poll.treeDepths(); // calculate the message batch size from the message tree subdepth uint256 messageBatchSize = TREE_ARITY ** messageTreeSubDepth; @@ -119,18 +136,7 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes } } - if ( - !verifyProcessProof( - currentMessageBatchIndex, - messageRoot, - sbCommitment, - _newSbCommitment, - messageTreeSubDepth, - messageTreeDepth, - voteOptionTreeDepth, - _proof - ) - ) { + if (!verifyProcessProof(_publicCircuitInputs, _proof)) { revert InvalidProcessMessageProof(); } @@ -143,7 +149,7 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes } updateMessageProcessingData( - _newSbCommitment, + _publicCircuitInputs.newSbCommitment, currentMessageBatchIndex, numMessages <= messageBatchSize * (numBatchesProcessed + 1) ); @@ -152,139 +158,41 @@ contract MessageProcessor is Ownable, SnarkCommon, Hasher, CommonUtilities, IMes /// @notice Verify the proof for processMessage /// @dev used to update the sbCommitment - /// @param _currentMessageBatchIndex The batch index of current message batch - /// @param _messageRoot The message tree root - /// @param _currentSbCommitment The current sbCommitment (state and ballot) - /// @param _newSbCommitment The new sbCommitment after we update this message batch - /// @param _messageTreeSubDepth The message tree subdepth - /// @param _messageTreeDepth The message tree depth - /// @param _voteOptionTreeDepth The vote option tree depth + /// @param circuitPublicInputs The circuit public inputs /// @param _proof The zk-SNARK proof /// @return isValid Whether the proof is valid function verifyProcessProof( - uint256 _currentMessageBatchIndex, - uint256 _messageRoot, - uint256 _currentSbCommitment, - uint256 _newSbCommitment, - uint8 _messageTreeSubDepth, - uint8 _messageTreeDepth, - uint8 _voteOptionTreeDepth, + CircuitPublicInputs memory circuitPublicInputs, uint256[8] memory _proof ) internal view returns (bool isValid) { // get the tree depths - // get the message batch size from the message tree subdepth - // get the number of signups - (uint256 numSignUps, uint256 numMessages) = poll.numSignUpsAndMessages(); + (, uint8 messageTreeSubDepth, uint8 messageTreeDepth, uint8 voteOptionTreeDepth) = poll.treeDepths(); (IMACI maci, ) = poll.extContracts(); // Calculate the public input hash (a SHA256 hash of several values) - uint256 publicInputHash = genProcessMessagesPublicInputHash( - _currentMessageBatchIndex, - _messageRoot, - numSignUps, - numMessages, - _currentSbCommitment, - _newSbCommitment, - _messageTreeSubDepth, - _voteOptionTreeDepth - ); + uint256[] memory publicInputs = new uint256[](11); + publicInputs[0] = circuitPublicInputs.maxVoteOptions; + publicInputs[1] = circuitPublicInputs.numSignUps; + publicInputs[2] = circuitPublicInputs.index; + publicInputs[3] = circuitPublicInputs.messageBatchSize; + publicInputs[4] = circuitPublicInputs.batchEndIndex; + publicInputs[5] = circuitPublicInputs.coordinatorPubKeyHash; + publicInputs[6] = circuitPublicInputs.messageRoot; + publicInputs[7] = circuitPublicInputs.currentSbCommitment; + publicInputs[8] = circuitPublicInputs.newSbCommitment; + publicInputs[9] = circuitPublicInputs.pollEndTimestamp; + publicInputs[10] = circuitPublicInputs.actualStateTreeDepth; // Get the verifying key from the VkRegistry VerifyingKey memory vk = vkRegistry.getProcessVk( maci.stateTreeDepth(), - _messageTreeDepth, - _voteOptionTreeDepth, - TREE_ARITY ** _messageTreeSubDepth, + messageTreeDepth, + voteOptionTreeDepth, + TREE_ARITY ** messageTreeSubDepth, mode ); - isValid = verifier.verify(_proof, vk, publicInputHash); - } - - /// @notice Returns the SHA256 hash of the packed values (see - /// genProcessMessagesPackedVals), the hash of the coordinator's public key, - /// the message root, and the commitment to the current state root and - /// ballot root. By passing the SHA256 hash of these values to the circuit - /// as a single public input and the preimage as private inputs, we reduce - /// its verification gas cost though the number of constraints will be - /// higher and proving time will be longer. - /// @param _currentMessageBatchIndex The batch index of current message batch - /// @param _numSignUps The number of users that signup - /// @param _numMessages The number of messages - /// @param _currentSbCommitment The current sbCommitment (state and ballot root) - /// @param _newSbCommitment The new sbCommitment after we update this message batch - /// @param _messageTreeSubDepth The message tree subdepth - /// @return inputHash Returns the SHA256 hash of the packed values - function genProcessMessagesPublicInputHash( - uint256 _currentMessageBatchIndex, - uint256 _messageRoot, - uint256 _numSignUps, - uint256 _numMessages, - uint256 _currentSbCommitment, - uint256 _newSbCommitment, - uint8 _messageTreeSubDepth, - uint8 _voteOptionTreeDepth - ) public view returns (uint256 inputHash) { - uint256 coordinatorPubKeyHash = poll.coordinatorPubKeyHash(); - - uint8 actualStateTreeDepth = poll.actualStateTreeDepth(); - - // pack the values - uint256 packedVals = genProcessMessagesPackedVals( - _currentMessageBatchIndex, - _numSignUps, - _numMessages, - _messageTreeSubDepth, - _voteOptionTreeDepth - ); - - (uint256 deployTime, uint256 duration) = poll.getDeployTimeAndDuration(); - - // generate the circuit only public input - uint256[] memory input = new uint256[](7); - input[0] = packedVals; - input[1] = coordinatorPubKeyHash; - input[2] = _messageRoot; - input[3] = _currentSbCommitment; - input[4] = _newSbCommitment; - input[5] = deployTime + duration; - input[6] = actualStateTreeDepth; - inputHash = sha256Hash(input); - } - - /// @notice One of the inputs to the ProcessMessages circuit is a 250-bit - /// representation of four 50-bit values. This function generates this - /// 250-bit value, which consists of the maximum number of vote options, the - /// number of signups, the current message batch index, and the end index of - /// the current batch. - /// @param _currentMessageBatchIndex batch index of current message batch - /// @param _numSignUps number of users that signup - /// @param _numMessages number of messages - /// @param _messageTreeSubDepth message tree subdepth - /// @param _voteOptionTreeDepth vote option tree depth - /// @return result The packed value - function genProcessMessagesPackedVals( - uint256 _currentMessageBatchIndex, - uint256 _numSignUps, - uint256 _numMessages, - uint8 _messageTreeSubDepth, - uint8 _voteOptionTreeDepth - ) public pure returns (uint256 result) { - uint256 maxVoteOptions = TREE_ARITY ** _voteOptionTreeDepth; - - // calculate the message batch size from the message tree subdepth - uint256 messageBatchSize = TREE_ARITY ** _messageTreeSubDepth; - uint256 batchEndIndex = _currentMessageBatchIndex + messageBatchSize; - if (batchEndIndex > _numMessages) { - batchEndIndex = _numMessages; - } - - if (maxVoteOptions >= 2 ** 50) revert MaxVoteOptionsTooLarge(); - if (_numSignUps >= 2 ** 50) revert NumSignUpsTooLarge(); - if (_currentMessageBatchIndex >= 2 ** 50) revert CurrentMessageBatchIndexTooLarge(); - if (batchEndIndex >= 2 ** 50) revert BatchEndIndexTooLarge(); - - result = maxVoteOptions + (_numSignUps << 50) + (_currentMessageBatchIndex << 100) + (batchEndIndex << 150); + isValid = verifier.verify(_proof, vk, publicInputs); } /// @notice update message processing state variables diff --git a/contracts/contracts/Tally.sol b/contracts/contracts/Tally.sol index c6e4d12b7..161a2427c 100644 --- a/contracts/contracts/Tally.sol +++ b/contracts/contracts/Tally.sol @@ -60,6 +60,16 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher, DomainObjs { error TallyBatchSizeTooLarge(); error NotSupported(); + /// @notice Circuit public inputs + struct CircuitPublicInputs { + uint256 index; + uint256 batchSize; + uint256 numSignUps; + uint256 sbCommitment; + uint256 currentTallyCommitment; + uint256 newTallyCommitment; + } + /// @notice Create a new Tally contract /// @param _verifier The Verifier contract /// @param _vkRegistry The VkRegistry contract @@ -82,23 +92,6 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher, DomainObjs { mode = _mode; } - /// @notice Pack the batch start index and number of signups into a 100-bit value. - /// @param _numSignUps: number of signups - /// @param _batchStartIndex: the start index of given batch - /// @param _tallyBatchSize: size of batch - /// @return result an uint256 representing the 3 inputs packed together - function genTallyVotesPackedVals( - uint256 _numSignUps, - uint256 _batchStartIndex, - uint256 _tallyBatchSize - ) public pure returns (uint256 result) { - if (_numSignUps >= 2 ** 50) revert NumSignUpsTooLarge(); - if (_batchStartIndex >= 2 ** 50) revert BatchStartIndexTooLarge(); - if (_tallyBatchSize >= 2 ** 50) revert TallyBatchSizeTooLarge(); - - result = (_batchStartIndex / _tallyBatchSize) + (_numSignUps << uint256(50)); - } - /// @notice Check if all ballots are tallied /// @return tallied whether all ballots are tallied function isTallied() public view returns (bool tallied) { @@ -109,27 +102,6 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher, DomainObjs { tallied = tallyBatchNum * (TREE_ARITY ** intStateTreeDepth) >= numSignUps; } - /// @notice generate hash of public inputs for tally circuit - /// @param _numSignUps: number of signups - /// @param _batchStartIndex: the start index of given batch - /// @param _tallyBatchSize: size of batch - /// @param _newTallyCommitment: the new tally commitment to be updated - /// @return inputHash hash of public inputs - function genTallyVotesPublicInputHash( - uint256 _numSignUps, - uint256 _batchStartIndex, - uint256 _tallyBatchSize, - uint256 _newTallyCommitment - ) public view returns (uint256 inputHash) { - uint256 packedVals = genTallyVotesPackedVals(_numSignUps, _batchStartIndex, _tallyBatchSize); - uint256[] memory input = new uint256[](4); - input[0] = packedVals; - input[1] = sbCommitment; - input[2] = tallyCommitment; - input[3] = _newTallyCommitment; - inputHash = sha256Hash(input); - } - /// @notice Update the state and ballot root commitment function updateSbCommitment() public onlyOwner { // Require that all messages have been processed @@ -143,9 +115,9 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher, DomainObjs { } /// @notice Verify the result of a tally batch - /// @param _newTallyCommitment the new tally commitment to be verified + /// @param _circuitPublicInputs circuit public inputs /// @param _proof the proof generated after tallying this batch - function tallyVotes(uint256 _newTallyCommitment, uint256[8] calldata _proof) public onlyOwner { + function tallyVotes(CircuitPublicInputs memory _circuitPublicInputs, uint256[8] calldata _proof) public onlyOwner { _votingPeriodOver(poll); updateSbCommitment(); @@ -166,29 +138,23 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher, DomainObjs { revert AllBallotsTallied(); } - bool isValid = verifyTallyProof(_proof, numSignUps, batchStartIndex, tallyBatchSize, _newTallyCommitment); + bool isValid = verifyTallyProof(_circuitPublicInputs, _proof); if (!isValid) { revert InvalidTallyVotesProof(); } // Update the tally commitment and the tally batch num - tallyCommitment = _newTallyCommitment; + tallyCommitment = _circuitPublicInputs.newTallyCommitment; } /// @notice Verify the tally proof using the verifying key + /// @param _circuitPublicInputs circuit public inputs /// @param _proof the proof generated after processing all messages - /// @param _numSignUps number of signups for a given poll - /// @param _batchStartIndex the number of batches multiplied by the size of the batch - /// @param _tallyBatchSize batch size for the tally - /// @param _newTallyCommitment the tally commitment to be verified at a given batch index /// @return isValid whether the proof is valid function verifyTallyProof( - uint256[8] calldata _proof, - uint256 _numSignUps, - uint256 _batchStartIndex, - uint256 _tallyBatchSize, - uint256 _newTallyCommitment + CircuitPublicInputs memory _circuitPublicInputs, + uint256[8] calldata _proof ) public view returns (bool isValid) { (uint8 intStateTreeDepth, , , uint8 voteOptionTreeDepth) = poll.treeDepths(); @@ -198,15 +164,16 @@ contract Tally is Ownable, SnarkCommon, CommonUtilities, Hasher, DomainObjs { VerifyingKey memory vk = vkRegistry.getTallyVk(maci.stateTreeDepth(), intStateTreeDepth, voteOptionTreeDepth, mode); // Get the public inputs - uint256 publicInputHash = genTallyVotesPublicInputHash( - _numSignUps, - _batchStartIndex, - _tallyBatchSize, - _newTallyCommitment - ); + uint256[] memory publicInputs = new uint256[](6); + publicInputs[0] = _circuitPublicInputs.index; + publicInputs[1] = _circuitPublicInputs.batchSize; + publicInputs[2] = _circuitPublicInputs.numSignUps; + publicInputs[3] = _circuitPublicInputs.sbCommitment; + publicInputs[4] = _circuitPublicInputs.currentTallyCommitment; + publicInputs[5] = _circuitPublicInputs.newTallyCommitment; // Verify the proof - isValid = verifier.verify(_proof, vk, publicInputHash); + isValid = verifier.verify(_proof, vk, publicInputs); } /// @notice Compute the merkle root from the path elements diff --git a/contracts/contracts/crypto/MockVerifier.sol b/contracts/contracts/crypto/MockVerifier.sol index 2b8507a02..004739b33 100644 --- a/contracts/contracts/crypto/MockVerifier.sol +++ b/contracts/contracts/crypto/MockVerifier.sol @@ -10,7 +10,7 @@ import { IVerifier } from "../interfaces/IVerifier.sol"; contract MockVerifier is IVerifier, SnarkConstants, SnarkCommon { /// @notice Verify a zk-SNARK proof (test only return always true) /// @return result Whether the proof is valid given the verifying key and public - function verify(uint256[8] memory, VerifyingKey memory, uint256) public pure override returns (bool result) { + function verify(uint256[8] memory, VerifyingKey memory, uint256[] memory) public pure override returns (bool result) { result = true; } } diff --git a/contracts/contracts/crypto/Verifier.sol b/contracts/contracts/crypto/Verifier.sol index 68363e7e2..289e09e2b 100644 --- a/contracts/contracts/crypto/Verifier.sol +++ b/contracts/contracts/crypto/Verifier.sol @@ -26,7 +26,7 @@ contract Verifier is IVerifier, SnarkConstants, SnarkCommon { /// @notice Verify a zk-SNARK proof /// @param _proof The proof /// @param vk The verifying key - /// @param input The public inputs to the circuit + /// @param inputs The public inputs to the circuit /// @return isValid Whether the proof is valid given the verifying key and public /// input. Note that this function only supports one public input. /// Refer to the Semaphore source code for a verifier that supports @@ -34,7 +34,7 @@ contract Verifier is IVerifier, SnarkConstants, SnarkCommon { function verify( uint256[8] memory _proof, VerifyingKey memory vk, - uint256 input + uint256[] memory inputs ) public view override returns (bool isValid) { Proof memory proof; proof.a = Pairing.G1Point(_proof[0], _proof[1]); @@ -51,15 +51,25 @@ contract Verifier is IVerifier, SnarkConstants, SnarkCommon { checkPoint(proof.c.x); checkPoint(proof.c.y); - // Make sure that the input is less than the snark scalar field - if (input >= SNARK_SCALAR_FIELD) { - revert InvalidInputVal(); - } - // Compute the linear combination vk_x Pairing.G1Point memory vkX = Pairing.G1Point(0, 0); - vkX = Pairing.plus(vkX, Pairing.scalarMul(vk.ic[1], input)); + uint256 inputsLength = inputs.length; + + for (uint256 i = 0; i < inputsLength; ) { + uint256 input = inputs[i]; + + // Make sure that the input is less than the snark scalar field + if (input >= SNARK_SCALAR_FIELD) { + revert InvalidInputVal(); + } + + vkX = Pairing.plus(vkX, Pairing.scalarMul(vk.ic[1], input)); + + unchecked { + i++; + } + } vkX = Pairing.plus(vkX, vk.ic[0]); diff --git a/contracts/contracts/interfaces/IVerifier.sol b/contracts/contracts/interfaces/IVerifier.sol index aaf5891b3..3f3970e7a 100644 --- a/contracts/contracts/interfaces/IVerifier.sol +++ b/contracts/contracts/interfaces/IVerifier.sol @@ -9,7 +9,7 @@ interface IVerifier { /// @notice Verify a zk-SNARK proof /// @param _proof The proof /// @param vk The verifying key - /// @param input The public inputs to the circuit + /// @param inputs The public inputs to the circuit /// @return Whether the proof is valid given the verifying key and public /// input. Note that this function only supports one public input. /// Refer to the Semaphore source code for a verifier that supports @@ -17,6 +17,6 @@ interface IVerifier { function verify( uint256[8] memory _proof, SnarkCommon.VerifyingKey memory vk, - uint256 input + uint256[] memory inputs ) external view returns (bool); } diff --git a/contracts/contracts/trees/LazyIMT.sol b/contracts/contracts/trees/LazyIMT.sol index bd2df1a93..92af140fb 100644 --- a/contracts/contracts/trees/LazyIMT.sol +++ b/contracts/contracts/trees/LazyIMT.sol @@ -119,7 +119,7 @@ library InternalLazyIMT { /// @return The index for the element function _indexForElement(uint8 level, uint40 index) internal pure returns (uint40) { // store the elements sparsely - return MAX_INDEX * level + index; + return (uint40(level) << 32) - level + index; } /// @notice Inserts a leaf into the LazyIMT @@ -153,7 +153,7 @@ library InternalLazyIMT { uint40 numberOfLeaves = self.numberOfLeaves; // dynamically determine a depth uint8 depth = 1; - while (uint40(2) ** uint40(depth) < numberOfLeaves) { + while (uint40(1 << depth) < numberOfLeaves) { depth++; } return _root(self, numberOfLeaves, depth); diff --git a/contracts/tasks/helpers/Prover.ts b/contracts/tasks/helpers/Prover.ts index 8e2ffb494..1c5801b1d 100644 --- a/contracts/tasks/helpers/Prover.ts +++ b/contracts/tasks/helpers/Prover.ts @@ -6,7 +6,8 @@ import type { IVerifyingKeyStruct, Proof } from "../../ts/types"; import type { AccQueue, MACI, MessageProcessor, Poll, Tally, Verifier, VkRegistry } from "../../typechain-types"; import type { BigNumberish } from "ethers"; -import { formatProofForVerifierContract, asHex } from "../../ts/utils"; +import { generateProcessMessagesPublicInputs, generateTallyVotesPublicInputs } from "../../ts/circuitInputs"; +import { formatProofForVerifierContract } from "../../ts/utils"; import { STATE_TREE_ARITY } from "./constants"; import { IProverParams } from "./types"; @@ -91,7 +92,6 @@ export class Prover { this.mpContract.mode(), ]); - const numSignUps = Number(numSignUpsAndMessages[0]); const numMessages = Number(numSignUpsAndMessages[1]); const messageBatchSize = STATE_TREE_ARITY ** Number(treeDepths[1]); let totalMessageBatches = numMessages <= messageBatchSize ? 1 : Math.floor(numMessages / messageBatchSize); @@ -167,32 +167,16 @@ export class Prover { throw new Error("coordPubKey mismatch."); } - const packedValsOnChain = BigInt( - await this.mpContract.genProcessMessagesPackedVals( - currentMessageBatchIndex, - numSignUps, - numMessages, - treeDepths.messageTreeSubDepth, - treeDepths.voteOptionTreeDepth, - ), - ).toString(); - this.validatePackedValues(circuitInputs.packedVals as BigNumberish, packedValsOnChain); - const formattedProof = formatProofForVerifierContract(proof); - const publicInputHashOnChain = BigInt( - await this.mpContract.genProcessMessagesPublicInputHash( - currentMessageBatchIndex, - messageRootOnChain.toString(), - numSignUps, - numMessages, - circuitInputs.currentSbCommitment as BigNumberish, - circuitInputs.newSbCommitment as BigNumberish, - treeDepths.messageTreeSubDepth, - treeDepths.voteOptionTreeDepth, - ), - ); - this.validatePublicInput(publicInputs[0] as BigNumberish, publicInputHashOnChain); + const publicInputsOnChain = await generateProcessMessagesPublicInputs({ + pollContract: this.pollContract, + messageAqContract: this.messageAqContract, + currentSbCommitment: BigInt(circuitInputs.currentSbCommitment as BigNumberish), + newSbCommitment: BigInt(circuitInputs.newSbCommitment as BigNumberish), + currentMessageBatchIndex, + }); + this.validatePublicInput(publicInputs, Object.values(publicInputsOnChain) as BigNumberish[]); const vk = new VerifyingKey( new G1Point(onChainProcessVk.alpha1[0], onChainProcessVk.alpha1[1]), @@ -207,7 +191,7 @@ export class Prover { const isValidOnChain = await this.verifierContract.verify( formattedProof, vk.asContractParam() as IVerifyingKeyStruct, - publicInputHashOnChain.toString(), + Object.values(publicInputsOnChain), ); if (!isValidOnChain) { @@ -216,9 +200,8 @@ export class Prover { try { // validate process messaging proof and store the new state and ballot root commitment - const receipt = await this.mpContract - .processMessages(asHex(circuitInputs.newSbCommitment as BigNumberish), formattedProof) + .processMessages(publicInputsOnChain, formattedProof) .then((tx) => tx.wait()); if (receipt?.status !== 1) { @@ -280,30 +263,23 @@ export class Prover { const currentTallyCommitmentOnChain = await this.tallyContract.tallyCommitment(); this.validateCommitment(circuitInputs.currentTallyCommitment as BigNumberish, currentTallyCommitmentOnChain); - const packedValsOnChain = BigInt( - await this.tallyContract.genTallyVotesPackedVals(numSignUps, batchStartIndex, tallyBatchSize), - ); - this.validatePackedValues(circuitInputs.packedVals as BigNumberish, packedValsOnChain); - const currentSbCommitmentOnChain = await this.mpContract.sbCommitment(); console.log(currentSbCommitmentOnChain, circuitInputs); this.validateCommitment(circuitInputs.sbCommitment as BigNumberish, currentSbCommitmentOnChain); - const publicInputHashOnChain = await this.tallyContract.genTallyVotesPublicInputHash( - numSignUps, + const publicInputsOnChain = await generateTallyVotesPublicInputs({ + pollContract: this.pollContract, + tallyContract: this.tallyContract, batchStartIndex, - tallyBatchSize, - circuitInputs.newTallyCommitment as BigNumberish, - ); - this.validatePublicInput(publicInputs[0] as BigNumberish, publicInputHashOnChain); + newTallyCommitment: BigInt(circuitInputs.newTallyCommitment as BigNumberish), + }); + this.validatePublicInput(publicInputs, Object.values(publicInputsOnChain) as BigNumberish[]); // format the tally proof so it can be verified on chain const formattedProof = formatProofForVerifierContract(proof); // verify the proof on chain - const receipt = await this.tallyContract - .tallyVotes(asHex(circuitInputs.newTallyCommitment as BigNumberish), formattedProof) - .then((tx) => tx.wait()); + const receipt = await this.tallyContract.tallyVotes(publicInputsOnChain, formattedProof).then((tx) => tx.wait()); if (receipt?.status !== 1) { throw new Error("tallyVotes() failed"); @@ -359,28 +335,15 @@ export class Prover { } } - /** - * Validate packed values - * - * @param packedVals - off-chain packed values - * @param packedValsOnChain - on-chain packed values - * @throws error if packed values don't match - */ - private validatePackedValues(packedVals: BigNumberish, packedValsOnChain: BigNumberish) { - if (packedVals.toString() !== packedValsOnChain.toString()) { - throw new Error("packedVals mismatch."); - } - } - /** * Validate public input hash * - * @param publicInputHash - off-chain public input hash - * @param publicInputHashOnChain - on-chain public input hash + * @param publicInputs - off-chain public input hash + * @param publicInputsOnChain - on-chain public input hash * @throws error if public input hashes don't match */ - private validatePublicInput(publicInputHash: BigNumberish, publicInputHashOnChain: BigNumberish) { - if (publicInputHashOnChain.toString() !== publicInputHash.toString()) { + private validatePublicInput(publicInputs: BigNumberish[], publicInputsOnChain: BigNumberish[]) { + if (publicInputsOnChain.every((value, index) => value.toString() === publicInputs[index].toString())) { throw new Error("public input mismatch."); } } diff --git a/contracts/tests/MessageProcessor.test.ts b/contracts/tests/MessageProcessor.test.ts index c3e93155c..e63f9fe89 100644 --- a/contracts/tests/MessageProcessor.test.ts +++ b/contracts/tests/MessageProcessor.test.ts @@ -2,14 +2,17 @@ import { expect } from "chai"; import { Signer } from "ethers"; import { EthereumProvider } from "hardhat/types"; -import { MaciState, Poll, packProcessMessageSmallVals, IProcessMessagesCircuitInputs } from "maci-core"; +import { MaciState, Poll, IProcessMessagesCircuitInputs } from "maci-core"; import { NOTHING_UP_MY_SLEEVE } from "maci-crypto"; import { Keypair, Message, PubKey } from "maci-domainobjs"; +import { generateProcessMessagesPublicInputs } from "../ts/circuitInputs"; import { EMode } from "../ts/constants"; import { IVerifyingKeyStruct } from "../ts/types"; import { getDefaultSigner } from "../ts/utils"; import { + AccQueue, + AccQueue__factory as AccQueueFactory, MACI, MessageProcessor, MessageProcessor__factory as MessageProcessorFactory, @@ -35,6 +38,7 @@ describe("MessageProcessor", () => { // contracts let maciContract: MACI; let pollContract: PollContract; + let messageAqContract: AccQueue; let verifierContract: Verifier; let vkRegistryContract: VkRegistry; let mpContract: MessageProcessor; @@ -49,8 +53,6 @@ describe("MessageProcessor", () => { let generatedInputs: IProcessMessagesCircuitInputs; const coordinator = new Keypair(); - const users = [new Keypair(), new Keypair()]; - before(async () => { signer = await getDefaultSigner(); // deploy test contracts @@ -92,6 +94,10 @@ describe("MessageProcessor", () => { const pollContractAddress = await maciContract.getPoll(pollId); pollContract = PollFactory.connect(pollContractAddress, signer); + messageAqContract = AccQueueFactory.connect( + await pollContract.extContracts().then((value) => value.messageAq), + signer, + ); mpContract = MessageProcessorFactory.connect(event.args.pollAddr.messageProcessor, signer); @@ -144,7 +150,15 @@ describe("MessageProcessor", () => { }); it("processMessages() should fail if the state AQ has not been merged", async () => { - await expect(mpContract.processMessages(0, [0, 0, 0, 0, 0, 0, 0, 0])).to.be.revertedWithCustomError( + const inputs = await generateProcessMessagesPublicInputs({ + pollContract, + messageAqContract, + newSbCommitment: BigInt(generatedInputs.newSbCommitment), + currentSbCommitment: BigInt(generatedInputs.currentSbCommitment), + currentMessageBatchIndex: poll.currentMessageBatchIndex ?? 0, + }); + + await expect(mpContract.processMessages(inputs, [0, 0, 0, 0, 0, 0, 0, 0])).to.be.revertedWithCustomError( mpContract, "StateNotMerged", ); @@ -159,28 +173,17 @@ describe("MessageProcessor", () => { await pollContract.mergeMessageAq(); }); - it("genProcessMessagesPackedVals() should generate the correct value", async () => { - const packedVals = packProcessMessageSmallVals( - BigInt(maxValues.maxVoteOptions), - BigInt(users.length), - 0, - poll.messages.length, - ); - const onChainPackedVals = BigInt( - await mpContract.genProcessMessagesPackedVals( - 0, - users.length, - poll.messages.length, - treeDepths.messageTreeSubDepth, - treeDepths.voteOptionTreeDepth, - ), - ); - expect(packedVals.toString()).to.eq(onChainPackedVals.toString()); - }); - it("processMessages() should update the state and ballot root commitment", async () => { + const inputs = await generateProcessMessagesPublicInputs({ + pollContract, + messageAqContract, + newSbCommitment: BigInt(generatedInputs.newSbCommitment), + currentSbCommitment: BigInt(generatedInputs.currentSbCommitment), + currentMessageBatchIndex: poll.currentMessageBatchIndex ?? 0, + }); + // Submit the proof - const tx = await mpContract.processMessages(generatedInputs.newSbCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); + const tx = await mpContract.processMessages(inputs, [0, 0, 0, 0, 0, 0, 0, 0]); const receipt = await tx.wait(); expect(receipt?.status).to.eq(1); diff --git a/contracts/tests/Tally.test.ts b/contracts/tests/Tally.test.ts index 37ce7d600..63185d118 100644 --- a/contracts/tests/Tally.test.ts +++ b/contracts/tests/Tally.test.ts @@ -2,16 +2,11 @@ import { expect } from "chai"; import { AbiCoder, Signer } from "ethers"; import { EthereumProvider } from "hardhat/types"; -import { - MaciState, - Poll, - packTallyVotesSmallVals, - IProcessMessagesCircuitInputs, - ITallyCircuitInputs, -} from "maci-core"; +import { MaciState, Poll, IProcessMessagesCircuitInputs, ITallyCircuitInputs } from "maci-core"; import { NOTHING_UP_MY_SLEEVE } from "maci-crypto"; import { Keypair, Message, PubKey } from "maci-domainobjs"; +import { generateProcessMessagesPublicInputs, generateTallyVotesPublicInputs } from "../ts/circuitInputs"; import { EMode } from "../ts/constants"; import { IVerifyingKeyStruct } from "../ts/types"; import { getDefaultSigner } from "../ts/utils"; @@ -25,6 +20,8 @@ import { MessageProcessor__factory as MessageProcessorFactory, Poll__factory as PollFactory, Tally__factory as TallyFactory, + AccQueue__factory as AccQueueFactory, + AccQueue, } from "../typechain-types"; import { @@ -33,7 +30,6 @@ import { initialVoiceCreditBalance, maxValues, messageBatchSize, - tallyBatchSize, testProcessVk, testTallyVk, treeDepths, @@ -45,6 +41,7 @@ describe("TallyVotes", () => { let maciContract: MACI; let pollContract: PollContract; let tallyContract: Tally; + let messageAqContract: AccQueue; let mpContract: MessageProcessor; let verifierContract: Verifier; let vkRegistryContract: VkRegistry; @@ -110,6 +107,10 @@ describe("TallyVotes", () => { pollContract = PollFactory.connect(pollContractAddress, signer); mpContract = MessageProcessorFactory.connect(event.args.pollAddr.messageProcessor, signer); tallyContract = TallyFactory.connect(event.args.pollAddr.tally, signer); + messageAqContract = AccQueueFactory.connect( + await pollContract.extContracts().then((value) => value.messageAq), + signer, + ); // deploy local poll const p = maciState.deployPoll(BigInt(deployTime + duration), maxValues, treeDepths, messageBatchSize, coordinator); @@ -151,18 +152,19 @@ describe("TallyVotes", () => { }); it("should not be possible to tally votes before the poll has ended", async () => { - await expect(tallyContract.tallyVotes(0, [0, 0, 0, 0, 0, 0, 0, 0])).to.be.revertedWithCustomError( + const inputs = await generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex: 0, + newTallyCommitment: 0n, + }); + + await expect(tallyContract.tallyVotes(inputs, [0, 0, 0, 0, 0, 0, 0, 0])).to.be.revertedWithCustomError( tallyContract, "VotingPeriodNotPassed", ); }); - it("genTallyVotesPackedVals() should generate the correct value", async () => { - const onChainPackedVals = BigInt(await tallyContract.genTallyVotesPackedVals(users.length, 0, tallyBatchSize)); - const packedVals = packTallyVotesSmallVals(0, tallyBatchSize, users.length); - expect(onChainPackedVals.toString()).to.eq(packedVals.toString()); - }); - it("updateSbCommitment() should revert when the messages have not been processed yet", async () => { // go forward in time await timeTravel(signer.provider! as unknown as EthereumProvider, duration + 1); @@ -174,7 +176,14 @@ describe("TallyVotes", () => { }); it("tallyVotes() should fail as the messages have not been processed yet", async () => { - await expect(tallyContract.tallyVotes(0, [0, 0, 0, 0, 0, 0, 0, 0])).to.be.revertedWithCustomError( + const inputs = await generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex: 0, + newTallyCommitment: 0n, + }); + + await expect(tallyContract.tallyVotes(inputs, [0, 0, 0, 0, 0, 0, 0, 0])).to.be.revertedWithCustomError( tallyContract, "ProcessingNotComplete", ); @@ -182,6 +191,7 @@ describe("TallyVotes", () => { describe("after merging acc queues", () => { let tallyGeneratedInputs: ITallyCircuitInputs; + before(async () => { await pollContract.mergeMaciState(); @@ -196,10 +206,26 @@ describe("TallyVotes", () => { }); it("tallyVotes() should update the tally commitment", async () => { + const [mpInputs, tallyInputs] = await Promise.all([ + generateProcessMessagesPublicInputs({ + pollContract, + messageAqContract, + newSbCommitment: BigInt(generatedInputs.newSbCommitment), + currentSbCommitment: BigInt(generatedInputs.currentSbCommitment), + currentMessageBatchIndex: poll.currentMessageBatchIndex ?? 0, + }), + generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex: 0, + newTallyCommitment: BigInt(tallyGeneratedInputs.newTallyCommitment), + }), + ]); + // do the processing on the message processor contract - await mpContract.processMessages(generatedInputs.newSbCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); + await mpContract.processMessages(mpInputs, [0, 0, 0, 0, 0, 0, 0, 0]); - await tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); + await tallyContract.tallyVotes(tallyInputs, [0, 0, 0, 0, 0, 0, 0, 0]); const onChainNewTallyCommitment = await tallyContract.tallyCommitment(); expect(tallyGeneratedInputs.newTallyCommitment).to.eq(onChainNewTallyCommitment.toString()); @@ -211,9 +237,17 @@ describe("TallyVotes", () => { }); it("tallyVotes() should revert when votes have already been tallied", async () => { - await expect( - tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]), - ).to.be.revertedWithCustomError(tallyContract, "AllBallotsTallied"); + const inputs = await generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex: 0, + newTallyCommitment: BigInt(tallyGeneratedInputs.newTallyCommitment), + }); + + await expect(tallyContract.tallyVotes(inputs, [0, 0, 0, 0, 0, 0, 0, 0])).to.be.revertedWithCustomError( + tallyContract, + "AllBallotsTallied", + ); }); }); @@ -291,6 +325,10 @@ describe("TallyVotes", () => { pollContract = PollFactory.connect(pollContractAddress, signer); mpContract = MessageProcessorFactory.connect(event.args.pollAddr.messageProcessor, signer); tallyContract = TallyFactory.connect(event.args.pollAddr.tally, signer); + messageAqContract = AccQueueFactory.connect( + await pollContract.extContracts().then((value) => value.messageAq), + signer, + ); // deploy local poll const p = maciState.deployPoll( @@ -343,21 +381,48 @@ describe("TallyVotes", () => { await pollContract.mergeMessageAq(); const processMessagesInputs = poll.processMessages(pollId); - await mpContract.processMessages(processMessagesInputs.newSbCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); + + const inputs = await generateProcessMessagesPublicInputs({ + pollContract, + messageAqContract, + newSbCommitment: BigInt(processMessagesInputs.newSbCommitment), + currentSbCommitment: BigInt(generatedInputs.currentSbCommitment), + currentMessageBatchIndex: poll.currentMessageBatchIndex ?? 0, + }); + + await mpContract.processMessages(inputs, [0, 0, 0, 0, 0, 0, 0, 0]); }); it("should tally votes correctly", async () => { let tallyGeneratedInputs: ITallyCircuitInputs; + while (poll.hasUntalliedBallots()) { tallyGeneratedInputs = poll.tallyVotes(); // eslint-disable-next-line no-await-in-loop - await tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); + const inputs = await generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex: poll.currentMessageBatchIndex ?? 0, + newTallyCommitment: BigInt(tallyGeneratedInputs.newTallyCommitment), + }); + // eslint-disable-next-line no-await-in-loop + await tallyContract.tallyVotes(inputs, [0, 0, 0, 0, 0, 0, 0, 0]); } + const tallyInputs = await generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex: 0, + newTallyCommitment: BigInt(tallyGeneratedInputs!.newTallyCommitment), + }); + const onChainNewTallyCommitment = await tallyContract.tallyCommitment(); expect(tallyGeneratedInputs!.newTallyCommitment).to.eq(onChainNewTallyCommitment.toString()); await expect( - tallyContract.tallyVotes(tallyGeneratedInputs!.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]), + tallyContract.tallyVotes( + { ...tallyInputs, newTallyCommitment: tallyGeneratedInputs!.newTallyCommitment }, + [0, 0, 0, 0, 0, 0, 0, 0], + ), ).to.be.revertedWithCustomError(tallyContract, "AllBallotsTallied"); }); }); @@ -436,6 +501,10 @@ describe("TallyVotes", () => { pollContract = PollFactory.connect(pollContractAddress, signer); mpContract = MessageProcessorFactory.connect(event.args.pollAddr.messageProcessor, signer); tallyContract = TallyFactory.connect(event.args.pollAddr.tally, signer); + messageAqContract = AccQueueFactory.connect( + await pollContract.extContracts().then((value) => value.messageAq), + signer, + ); // deploy local poll const p = maciState.deployPoll( @@ -488,14 +557,30 @@ describe("TallyVotes", () => { await pollContract.mergeMessageAq(); const processMessagesInputs = poll.processMessages(pollId); - await mpContract.processMessages(processMessagesInputs.newSbCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); + + const inputs = await generateProcessMessagesPublicInputs({ + pollContract, + messageAqContract, + newSbCommitment: BigInt(processMessagesInputs.newSbCommitment), + currentSbCommitment: BigInt(generatedInputs.currentSbCommitment), + currentMessageBatchIndex: poll.currentMessageBatchIndex ?? 0, + }); + + await mpContract.processMessages(inputs, [0, 0, 0, 0, 0, 0, 0, 0]); }); it("should tally votes correctly", async () => { // tally first batch let tallyGeneratedInputs = poll.tallyVotes(); - await tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); + const inputs = await generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex: poll.currentMessageBatchIndex ?? 0, + newTallyCommitment: BigInt(tallyGeneratedInputs.newTallyCommitment), + }); + + await tallyContract.tallyVotes(inputs, [0, 0, 0, 0, 0, 0, 0, 0]); // check commitment const onChainNewTallyCommitment = await tallyContract.tallyCommitment(); @@ -504,18 +589,28 @@ describe("TallyVotes", () => { // tally second batch tallyGeneratedInputs = poll.tallyVotes(); - await tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); + await tallyContract.tallyVotes(inputs, [0, 0, 0, 0, 0, 0, 0, 0]); // tally everything else while (poll.hasUntalliedBallots()) { tallyGeneratedInputs = poll.tallyVotes(); // eslint-disable-next-line no-await-in-loop - await tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); + const tallyInputs = await generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex: poll.currentMessageBatchIndex ?? 0, + newTallyCommitment: BigInt(tallyGeneratedInputs.newTallyCommitment), + }); + // eslint-disable-next-line no-await-in-loop + await tallyContract.tallyVotes(tallyInputs, [0, 0, 0, 0, 0, 0, 0, 0]); } // check that it fails to tally again await expect( - tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]), + tallyContract.tallyVotes( + { ...inputs, newTallyCommitment: BigInt(tallyGeneratedInputs.newTallyCommitment) }, + [0, 0, 0, 0, 0, 0, 0, 0], + ), ).to.be.revertedWithCustomError(tallyContract, "AllBallotsTallied"); }); }); diff --git a/contracts/tests/TallyNonQv.test.ts b/contracts/tests/TallyNonQv.test.ts index 8a9e6d2a4..6f3a41d35 100644 --- a/contracts/tests/TallyNonQv.test.ts +++ b/contracts/tests/TallyNonQv.test.ts @@ -2,16 +2,11 @@ import { expect } from "chai"; import { Signer } from "ethers"; import { EthereumProvider } from "hardhat/types"; -import { - MaciState, - Poll, - packTallyVotesSmallVals, - IProcessMessagesCircuitInputs, - ITallyCircuitInputs, -} from "maci-core"; +import { MaciState, Poll, IProcessMessagesCircuitInputs, ITallyCircuitInputs } from "maci-core"; import { NOTHING_UP_MY_SLEEVE } from "maci-crypto"; import { Keypair, Message, PubKey } from "maci-domainobjs"; +import { generateProcessMessagesPublicInputs, generateTallyVotesPublicInputs } from "../ts/circuitInputs"; import { EMode } from "../ts/constants"; import { IVerifyingKeyStruct } from "../ts/types"; import { getDefaultSigner } from "../ts/utils"; @@ -25,6 +20,8 @@ import { MessageProcessor__factory as MessageProcessorFactory, Poll__factory as PollFactory, Tally__factory as TallyFactory, + AccQueue, + AccQueue__factory as AccQueueFactory, } from "../typechain-types"; import { @@ -32,7 +29,6 @@ import { duration, maxValues, messageBatchSize, - tallyBatchSize, testProcessVk, testTallyVk, treeDepths, @@ -43,13 +39,13 @@ describe("TallyVotesNonQv", () => { let signer: Signer; let maciContract: MACI; let pollContract: PollContract; + let messageAqContract: AccQueue; let tallyContract: Tally; let mpContract: MessageProcessor; let verifierContract: Verifier; let vkRegistryContract: VkRegistry; const coordinator = new Keypair(); - let users: Keypair[]; let maciState: MaciState; let pollId: bigint; @@ -60,8 +56,6 @@ describe("TallyVotesNonQv", () => { before(async () => { maciState = new MaciState(STATE_TREE_DEPTH); - users = [new Keypair(), new Keypair()]; - signer = await getDefaultSigner(); const r = await deployTestContracts(100, STATE_TREE_DEPTH, signer, true); @@ -109,6 +103,10 @@ describe("TallyVotesNonQv", () => { pollContract = PollFactory.connect(pollContractAddress, signer); mpContract = MessageProcessorFactory.connect(event.args.pollAddr.messageProcessor, signer); tallyContract = TallyFactory.connect(event.args.pollAddr.tally, signer); + messageAqContract = AccQueueFactory.connect( + await pollContract.extContracts().then((value) => value.messageAq), + signer, + ); // deploy local poll const p = maciState.deployPoll(BigInt(deployTime + duration), maxValues, treeDepths, messageBatchSize, coordinator); @@ -150,18 +148,19 @@ describe("TallyVotesNonQv", () => { }); it("should not be possible to tally votes before the poll has ended", async () => { - await expect(tallyContract.tallyVotes(0, [0, 0, 0, 0, 0, 0, 0, 0])).to.be.revertedWithCustomError( + const inputs = await generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex: 0, + newTallyCommitment: 0n, + }); + + await expect(tallyContract.tallyVotes(inputs, [0, 0, 0, 0, 0, 0, 0, 0])).to.be.revertedWithCustomError( tallyContract, "VotingPeriodNotPassed", ); }); - it("genTallyVotesPackedVals() should generate the correct value", async () => { - const onChainPackedVals = BigInt(await tallyContract.genTallyVotesPackedVals(users.length, 0, tallyBatchSize)); - const packedVals = packTallyVotesSmallVals(0, tallyBatchSize, users.length); - expect(onChainPackedVals.toString()).to.eq(packedVals.toString()); - }); - it("updateSbCommitment() should revert when the messages have not been processed yet", async () => { // go forward in time await timeTravel(signer.provider! as unknown as EthereumProvider, duration + 1); @@ -173,7 +172,14 @@ describe("TallyVotesNonQv", () => { }); it("tallyVotes() should fail as the messages have not been processed yet", async () => { - await expect(tallyContract.tallyVotes(0, [0, 0, 0, 0, 0, 0, 0, 0])).to.be.revertedWithCustomError( + const inputs = await generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex: 0, + newTallyCommitment: 0n, + }); + + await expect(tallyContract.tallyVotes(inputs, [0, 0, 0, 0, 0, 0, 0, 0])).to.be.revertedWithCustomError( tallyContract, "ProcessingNotComplete", ); @@ -196,9 +202,25 @@ describe("TallyVotesNonQv", () => { it("tallyVotes() should update the tally commitment", async () => { // do the processing on the message processor contract - await mpContract.processMessages(generatedInputs.newSbCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); - - await tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]); + const [mpInputs, tallyInputs] = await Promise.all([ + generateProcessMessagesPublicInputs({ + pollContract, + messageAqContract, + newSbCommitment: BigInt(generatedInputs.newSbCommitment), + currentSbCommitment: BigInt(generatedInputs.currentSbCommitment), + currentMessageBatchIndex: poll.currentMessageBatchIndex ?? 0, + }), + generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex: 0, + newTallyCommitment: BigInt(tallyGeneratedInputs.newTallyCommitment), + }), + ]); + + await mpContract.processMessages(mpInputs, [0, 0, 0, 0, 0, 0, 0, 0]); + + await tallyContract.tallyVotes(tallyInputs, [0, 0, 0, 0, 0, 0, 0, 0]); const onChainNewTallyCommitment = await tallyContract.tallyCommitment(); expect(tallyGeneratedInputs.newTallyCommitment).to.eq(onChainNewTallyCommitment.toString()); @@ -217,9 +239,17 @@ describe("TallyVotesNonQv", () => { }); it("tallyVotes() should revert when votes have already been tallied", async () => { - await expect( - tallyContract.tallyVotes(tallyGeneratedInputs.newTallyCommitment, [0, 0, 0, 0, 0, 0, 0, 0]), - ).to.be.revertedWithCustomError(tallyContract, "AllBallotsTallied"); + const inputs = await generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex: 0, + newTallyCommitment: BigInt(tallyGeneratedInputs.newTallyCommitment), + }); + + await expect(tallyContract.tallyVotes(inputs, [0, 0, 0, 0, 0, 0, 0, 0])).to.be.revertedWithCustomError( + tallyContract, + "AllBallotsTallied", + ); }); }); }); diff --git a/contracts/tests/Verifier.test.ts b/contracts/tests/Verifier.test.ts index d48d737c0..78ab52e6a 100644 --- a/contracts/tests/Verifier.test.ts +++ b/contracts/tests/Verifier.test.ts @@ -83,20 +83,18 @@ describe("DomainObjs", () => { }); it("should correctly verify a proof", async () => { - const isValid = await verifierContract.verify( - proof, - vk.asContractParam() as IVerifyingKeyStruct, - publicInputs[0], - { gasLimit: 1000000 }, - ); + const isValid = await verifierContract.verify(proof, vk.asContractParam() as IVerifyingKeyStruct, publicInputs, { + gasLimit: 1000000, + }); expect(isValid).to.eq(true); }); + it("should return false for a proof that is not valid", async () => { const isValid = await verifierContract.verify( proof, vk.asContractParam() as IVerifyingKeyStruct, - BigInt(publicInputs[0]) + BigInt(1), + publicInputs.concat(BigInt(1)), { gasLimit: 1000000 }, ); diff --git a/contracts/ts/circuitInputs.ts b/contracts/ts/circuitInputs.ts new file mode 100644 index 000000000..aa9c74224 --- /dev/null +++ b/contracts/ts/circuitInputs.ts @@ -0,0 +1,81 @@ +import { MESSAGE_TREE_ARITY, STATE_TREE_ARITY } from "maci-core"; + +import type { + IMpCircuitPublicArgs, + IMpCircuitPublicInputs, + ITallyCircuitPublicArgs, + ITallyCircuitPublicInputs, +} from "./types"; + +/** + * Generate circuit public inputs for MessageProcessor + * + * @param params - arguments for circuit public inputs generator + * @returns circuit public inputs + */ +export async function generateProcessMessagesPublicInputs({ + pollContract, + messageAqContract, + currentMessageBatchIndex, + currentSbCommitment, + newSbCommitment, +}: IMpCircuitPublicArgs): Promise { + const [dd, [numSignUps], coordinatorPubKeyHash, actualStateTreeDepth, treeDepths] = await Promise.all([ + pollContract.getDeployTimeAndDuration(), + pollContract.numSignUpsAndMessages(), + pollContract.coordinatorPubKeyHash(), + pollContract.actualStateTreeDepth(), + pollContract.treeDepths(), + ]); + const pollEndTimestamp = dd[0] + dd[1]; + + const maxVoteOptions = BigInt(MESSAGE_TREE_ARITY) ** treeDepths.voteOptionTreeDepth; + const messageBatchSize = BigInt(MESSAGE_TREE_ARITY) ** treeDepths.messageTreeSubDepth; + const batchEndIndex = BigInt(currentMessageBatchIndex) + messageBatchSize; + + const messageRoot = await messageAqContract.getMainRoot(treeDepths.messageTreeDepth); + + return { + maxVoteOptions, + numSignUps, + index: BigInt(currentMessageBatchIndex), + messageBatchSize, + batchEndIndex, + coordinatorPubKeyHash, + messageRoot, + currentSbCommitment, + newSbCommitment, + pollEndTimestamp, + actualStateTreeDepth, + }; +} + +/** + * Generate circuit public inputs for Tally + * + * @param params - arguments for circuit public inputs generator + * @returns circuit public inputs + */ +export async function generateTallyVotesPublicInputs({ + pollContract, + tallyContract, + batchStartIndex, + newTallyCommitment, +}: ITallyCircuitPublicArgs): Promise { + const [[numSignUps], treeDepths, sbCommitment, currentTallyCommitment] = await Promise.all([ + pollContract.numSignUpsAndMessages(), + pollContract.treeDepths(), + tallyContract.sbCommitment(), + tallyContract.tallyCommitment(), + ]); + const batchSize = STATE_TREE_ARITY ** Number(treeDepths.intStateTreeDepth); + + return { + index: batchStartIndex, + batchSize, + numSignUps, + sbCommitment, + currentTallyCommitment, + newTallyCommitment, + }; +} diff --git a/contracts/ts/index.ts b/contracts/ts/index.ts index aa7c15bbb..4236b52c7 100644 --- a/contracts/ts/index.ts +++ b/contracts/ts/index.ts @@ -30,6 +30,16 @@ export { type TallyData, } from "../tasks/helpers/types"; export { linkPoseidonLibraries } from "../tasks/helpers/abi"; +export { generateProcessMessagesPublicInputs, generateTallyVotesPublicInputs } from "./circuitInputs"; -export type { IVerifyingKeyStruct, SnarkProof, Groth16Proof, Proof } from "./types"; +export type { + IVerifyingKeyStruct, + SnarkProof, + Groth16Proof, + Proof, + IMpCircuitPublicArgs, + IMpCircuitPublicInputs, + ITallyCircuitPublicArgs, + ITallyCircuitPublicInputs, +} from "./types"; export * from "../typechain-types"; diff --git a/contracts/ts/types.ts b/contracts/ts/types.ts index 3aa2c72f4..8faa3b27e 100644 --- a/contracts/ts/types.ts +++ b/contracts/ts/types.ts @@ -3,11 +3,14 @@ import type { FreeForAllGatekeeper, MACI, MockVerifier, + Poll, PollFactory, PoseidonT3, PoseidonT4, PoseidonT5, PoseidonT6, + Tally, + AccQueue, VkRegistry, } from "../typechain-types"; import type { BigNumberish, Signer } from "ethers"; @@ -168,3 +171,149 @@ export interface IDeployedMaci { poseidonT6: string; }; } + +/** + * Interface that represents arguments for circuit public inputs generator for MessageProcessor + */ +export interface IMpCircuitPublicArgs { + /** + * Poll contract + */ + pollContract: Poll; + + /** + * MessageAq contract + */ + messageAqContract: AccQueue; + + /** + * Current sbCommitment (state and ballot root) + */ + currentSbCommitment: bigint; + + /** + * New sbCommitment after we update this message batch + */ + newSbCommitment: bigint; + + /** + * Batch index of current message batch + */ + currentMessageBatchIndex: number; +} + +/** + * Interface that represents public circuit inputs for MessageProcessor + */ +export interface IMpCircuitPublicInputs { + maxVoteOptions: bigint; + /** + * Number of users that signup + */ + numSignUps: bigint; + + /** + * Batch index of current message batch + */ + index: bigint; + + /** + * Message batch size + */ + messageBatchSize: bigint; + + /** + * Last batch index + */ + batchEndIndex: bigint; + + /** + * Coordinator public key hash + */ + coordinatorPubKeyHash: bigint; + + /** + * Root of the message tree + */ + messageRoot: bigint; + + /** + * Current sbCommitment (state and ballot root) + */ + currentSbCommitment: bigint; + + /** + * New sbCommitment after we update this message batch + */ + newSbCommitment: bigint; + + /** + * Poll end timestamp + */ + pollEndTimestamp: bigint; + + /** + * Dynamic depth of the state tree at the time of poll finalization (based on the number of leaves inserted) + */ + actualStateTreeDepth: bigint; +} + +/** + * Interface that represents arguments for circuit public inputs generator for Tally + */ +export interface ITallyCircuitPublicArgs { + /** + * Poll contract + */ + pollContract: Poll; + + /** + * Tally contract + */ + tallyContract: Tally; + + /** + * Start index of given batch + */ + batchStartIndex: number; + + /** + * New tally commitment to be updated + */ + newTallyCommitment: bigint; +} + +/** + * Interface that represents public circuit inputs for Tally + */ +export interface ITallyCircuitPublicInputs { + /** + * Start index of given batch + */ + index: number; + + /** + * Size of batch + */ + batchSize: number; + + /** + * Number of users that signup + */ + numSignUps: bigint; + + /** + * Current sbCommitment (state and ballot root) + */ + sbCommitment: bigint; + + /** + * Current tally commitment + */ + currentTallyCommitment: bigint; + + /** + * New tally commitment to be updated + */ + newTallyCommitment: bigint; +} diff --git a/core/ts/Poll.ts b/core/ts/Poll.ts index 330a24f44..d53878dc7 100644 --- a/core/ts/Poll.ts +++ b/core/ts/Poll.ts @@ -6,7 +6,6 @@ import { hashLeftRight, hash3, hash5, - sha256Hash, stringifyBigInts, genTreeCommitment, hash2, @@ -43,7 +42,6 @@ import type { PathElements } from "maci-crypto"; import { STATE_TREE_ARITY, MESSAGE_TREE_ARITY } from "./utils/constants"; import { ProcessMessageErrors, ProcessMessageError } from "./utils/errors"; -import { packTallyVotesSmallVals } from "./utils/utils"; /** * A representation of the Poll contract. @@ -649,17 +647,8 @@ export class Poll implements IPoll { // here is important that a user validates it matches the one in the // smart contract const coordPubKeyHash = this.coordinatorKeypair.pubKey.hash(); - // create the input hash which is the only public input to the - // process messages circuit - circuitInputs.inputHash = sha256Hash([ - circuitInputs.packedVals as bigint, - coordPubKeyHash, - circuitInputs.msgRoot as bigint, - circuitInputs.currentSbCommitment as bigint, - circuitInputs.newSbCommitment, - this.pollEndTimestamp, - BigInt(this.actualStateTreeDepth), - ]); + circuitInputs.coordPubKeyHash = coordPubKeyHash; + circuitInputs.pollEndTimestamp = this.pollEndTimestamp; // If this is the last batch, release the lock if (this.numBatchesProcessed * batchSize >= this.messages.length) { @@ -751,18 +740,12 @@ export class Poll implements IPoll { this.sbSalts[this.currentMessageBatchIndex!], ]); - // Generate a SHA256 hash of inputs which the contract provides - /* eslint-disable no-bitwise */ - const packedVals = - BigInt(this.maxValues.maxVoteOptions) + - (BigInt(this.numSignups) << 50n) + - (BigInt(index) << 100n) + - (BigInt(batchEndIndex) << 150n); - /* eslint-enable no-bitwise */ - return stringifyBigInts({ pollEndTimestamp: this.pollEndTimestamp, - packedVals, + maxVoteOptions: BigInt(this.maxValues.maxVoteOptions), + numSignUps: BigInt(this.numSignups), + batchEndIndex: BigInt(batchEndIndex), + index: BigInt(index), msgRoot, msgs, msgSubrootPathElements: messageSubrootPath.pathElements, @@ -944,9 +927,6 @@ export class Poll implements IPoll { const sbSalt = this.sbSalts[this.currentMessageBatchIndex!]; const sbCommitment = hash3([stateRoot, ballotRoot, sbSalt]); - const packedVals = packTallyVotesSmallVals(batchStartIndex, batchSize, Number(this.numSignups)); - const inputHash = sha256Hash([packedVals, sbCommitment, currentTallyCommitment, newTallyCommitment]); - const ballotSubrootProof = this.ballotTree?.genSubrootProof(batchStartIndex, batchStartIndex + batchSize); const votes = ballots.map((x) => x.votes); @@ -955,11 +935,12 @@ export class Poll implements IPoll { stateRoot, ballotRoot, sbSalt, - packedVals, // contains numSignUps and batchStartIndex + index: BigInt(batchStartIndex), + batchSize: BigInt(batchSize), + numSignUps: BigInt(this.numSignups), sbCommitment, currentTallyCommitment, newTallyCommitment, - inputHash, ballots: ballots.map((x) => x.asCircuitInputs()), ballotPathElements: ballotSubrootProof!.pathElements, votes, @@ -1084,9 +1065,6 @@ export class Poll implements IPoll { const sbSalt = this.sbSalts[this.currentMessageBatchIndex!]; const sbCommitment = hash3([stateRoot, ballotRoot, sbSalt]); - const packedVals = packTallyVotesSmallVals(batchStartIndex, batchSize, Number(this.numSignups)); - const inputHash = sha256Hash([packedVals, sbCommitment, currentTallyCommitment, newTallyCommitment]); - const ballotSubrootProof = this.ballotTree?.genSubrootProof(batchStartIndex, batchStartIndex + batchSize); const votes = ballots.map((x) => x.votes); @@ -1095,11 +1073,12 @@ export class Poll implements IPoll { stateRoot, ballotRoot, sbSalt, - packedVals, // contains numSignUps and batchStartIndex + index: BigInt(batchStartIndex), + batchSize: BigInt(batchSize), + numSignUps: BigInt(this.numSignups), sbCommitment, currentTallyCommitment, newTallyCommitment, - inputHash, ballots: ballots.map((x) => x.asCircuitInputs()), ballotPathElements: ballotSubrootProof!.pathElements, votes, diff --git a/core/ts/__tests__/e2e.test.ts b/core/ts/__tests__/e2e.test.ts index 45d526dc9..0c354409f 100644 --- a/core/ts/__tests__/e2e.test.ts +++ b/core/ts/__tests__/e2e.test.ts @@ -5,7 +5,6 @@ import { PCommand, Keypair, StateLeaf, blankStateLeafHash } from "maci-domainobj import { MaciState } from "../MaciState"; import { Poll } from "../Poll"; import { STATE_TREE_DEPTH, STATE_TREE_ARITY, MESSAGE_TREE_ARITY } from "../utils/constants"; -import { packProcessMessageSmallVals, unpackProcessMessageSmallVals } from "../utils/utils"; import { coordinatorKeypair, @@ -424,20 +423,6 @@ describe("MaciState/Poll e2e", function test() { expect(accumulatorQueue.getRoot(treeDepths.messageTreeDepth)?.toString()).to.eq(poll.messageTree.root.toString()); }); - it("packProcessMessageSmallVals and unpackProcessMessageSmallVals", () => { - const maxVoteOptions = 1n; - const numUsers = 2n; - const batchStartIndex = 5; - const batchEndIndex = 10; - const packedVals = packProcessMessageSmallVals(maxVoteOptions, numUsers, batchStartIndex, batchEndIndex); - - const unpacked = unpackProcessMessageSmallVals(packedVals); - expect(unpacked.maxVoteOptions.toString()).to.eq(maxVoteOptions.toString()); - expect(unpacked.numUsers.toString()).to.eq(numUsers.toString()); - expect(unpacked.batchStartIndex.toString()).to.eq(batchStartIndex.toString()); - expect(unpacked.batchEndIndex.toString()).to.eq(batchEndIndex.toString()); - }); - it("Process a batch of messages (though only 1 message is in the batch)", () => { poll.processMessages(pollId); diff --git a/core/ts/__tests__/utils.test.ts b/core/ts/__tests__/utils.test.ts index 964aa897e..34aef4647 100644 --- a/core/ts/__tests__/utils.test.ts +++ b/core/ts/__tests__/utils.test.ts @@ -1,13 +1,6 @@ import { expect } from "chai"; -import { - genProcessVkSig, - genTallyVkSig, - packProcessMessageSmallVals, - unpackProcessMessageSmallVals, - packTallyVotesSmallVals, - unpackTallyVotesSmallVals, -} from "../utils/utils"; +import { genProcessVkSig, genTallyVkSig } from "../utils/utils"; describe("Utils", () => { it("genProcessVkSig should work", () => { @@ -19,32 +12,4 @@ describe("Utils", () => { const result = genTallyVkSig(1, 2, 3); expect(result).to.equal(340282366920938463500268095579187314691n); }); - - it("packProcessMessageSmallVals should work", () => { - const result = packProcessMessageSmallVals(1n, 2n, 3, 4); - expect(result).to.equal(5708990770823843327184944562488436835454287873n); - }); - - it("unpackProcessMessageSmallVals should work", () => { - const result = unpackProcessMessageSmallVals(5708990770823843327184944562488436835454287873n); - expect(result).to.deep.equal({ - maxVoteOptions: 1n, - numUsers: 2n, - batchStartIndex: 3n, - batchEndIndex: 4n, - }); - }); - - it("packTallyVotesSmallVals should work", () => { - const result = packTallyVotesSmallVals(1, 2, 3); - expect(result).to.equal(3377699720527872n); - }); - - it("unpackTallyVotesSmallVals should work", () => { - const result = unpackTallyVotesSmallVals(3377699720527872n); - expect(result).to.deep.equal({ - numSignUps: 3n, - batchStartIndex: 0n, - }); - }); }); diff --git a/core/ts/index.ts b/core/ts/index.ts index 54fbe4716..521eb92f7 100644 --- a/core/ts/index.ts +++ b/core/ts/index.ts @@ -2,14 +2,7 @@ export { MaciState } from "./MaciState"; export { Poll } from "./Poll"; -export { - genProcessVkSig, - genTallyVkSig, - packProcessMessageSmallVals, - unpackProcessMessageSmallVals, - packTallyVotesSmallVals, - unpackTallyVotesSmallVals, -} from "./utils/utils"; +export { genProcessVkSig, genTallyVkSig } from "./utils/utils"; export type { ITallyCircuitInputs, diff --git a/core/ts/utils/types.ts b/core/ts/utils/types.ts index 93647507f..9eaad573f 100644 --- a/core/ts/utils/types.ts +++ b/core/ts/utils/types.ts @@ -145,7 +145,10 @@ export interface IProcessMessagesOutput { export interface IProcessMessagesCircuitInputs { actualStateTreeDepth: string; pollEndTimestamp: string; - packedVals: string; + numSignUps: string; + batchEndIndex: string; + index: string; + maxVoteOptions: string; msgRoot: string; msgs: string[]; msgSubrootPathElements: string[][]; @@ -162,7 +165,6 @@ export interface IProcessMessagesCircuitInputs { currentBallotsPathElements: string[][]; currentVoteWeights: string[]; currentVoteWeightsPathElements: string[][]; - inputHash: string; newSbSalt: string; newSbCommitment: string; } @@ -175,10 +177,11 @@ export interface ITallyCircuitInputs { ballotRoot: string; sbSalt: string; sbCommitment: string; + index: bigint; + batchSize: bigint; currentTallyCommitment: string; newTallyCommitment: string; - packedVals: string; - inputHash: string; + numSignUps: bigint; ballots: string[]; ballotPathElements: PathElements; votes: string[][]; diff --git a/core/ts/utils/utils.ts b/core/ts/utils/utils.ts index 14b1622eb..ef77e40b3 100644 --- a/core/ts/utils/utils.ts +++ b/core/ts/utils/utils.ts @@ -1,5 +1,4 @@ /* eslint-disable no-bitwise */ -import assert from "assert"; /** * This function generates the signature of a ProcessMessage Verifying Key(VK). @@ -37,89 +36,3 @@ export const genTallyVkSig = ( _intStateTreeDepth: number, _voteOptionTreeDepth: number, ): bigint => (BigInt(_stateTreeDepth) << 128n) + (BigInt(_intStateTreeDepth) << 64n) + BigInt(_voteOptionTreeDepth); - -/** - * This function packs it's parameters into a single bigint. - * @param maxVoteOptions - The maximum number of vote options. - * @param numUsers - The number of users. - * @param batchStartIndex - The start index of the batch. - * @param batchEndIndex - The end index of the batch. - * @returns Returns a single bigint that contains the packed values. - */ -export const packProcessMessageSmallVals = ( - maxVoteOptions: bigint, - numUsers: bigint, - batchStartIndex: number, - batchEndIndex: number, -): bigint => { - const packedVals = - // Note: the << operator has lower precedence than + - BigInt(`${maxVoteOptions}`) + - (BigInt(`${numUsers}`) << 50n) + - (BigInt(batchStartIndex) << 100n) + - (BigInt(batchEndIndex) << 150n); - - return packedVals; -}; - -/** - * This function unpacks partial values for the ProcessMessages circuit from a single bigint. - * @param packedVals - The single bigint that contains the packed values. - * @returns Returns an object that contains the unpacked values. - */ -export const unpackProcessMessageSmallVals = ( - packedVals: bigint, -): { - maxVoteOptions: bigint; - numUsers: bigint; - batchStartIndex: bigint; - batchEndIndex: bigint; -} => { - let asBin = packedVals.toString(2); - assert(asBin.length <= 200); - while (asBin.length < 200) { - asBin = `0${asBin}`; - } - const maxVoteOptions = BigInt(`0b${asBin.slice(150, 200)}`); - const numUsers = BigInt(`0b${asBin.slice(100, 150)}`); - const batchStartIndex = BigInt(`0b${asBin.slice(50, 100)}`); - const batchEndIndex = BigInt(`0b${asBin.slice(0, 50)}`); - - return { - maxVoteOptions, - numUsers, - batchStartIndex, - batchEndIndex, - }; -}; - -/** - * This function packs it's parameters into a single bigint. - * @param batchStartIndex - The start index of the batch. - * @param batchSize - The size of the batch. - * @param numSignUps - The number of signups. - * @returns Returns a single bigint that contains the packed values. - */ -export const packTallyVotesSmallVals = (batchStartIndex: number, batchSize: number, numSignUps: number): bigint => { - // Note: the << operator has lower precedence than + - const packedVals = BigInt(batchStartIndex) / BigInt(batchSize) + (BigInt(numSignUps) << 50n); - - return packedVals; -}; - -/** - * This function unpacks partial values for the TallyVotes circuit from a single bigint. - * @param packedVals - The single bigint that contains the packed values. - * @returns Returns an object that contains the unpacked values. - */ -export const unpackTallyVotesSmallVals = (packedVals: bigint): { numSignUps: bigint; batchStartIndex: bigint } => { - let asBin = packedVals.toString(2); - assert(asBin.length <= 100); - while (asBin.length < 100) { - asBin = `0${asBin}`; - } - const numSignUps = BigInt(`0b${asBin.slice(0, 50)}`); - const batchStartIndex = BigInt(`0b${asBin.slice(50, 100)}`); - - return { numSignUps, batchStartIndex }; -};