From 813bfc7b25e403902e6779ef4304ee48bb5640e9 Mon Sep 17 00:00:00 2001 From: Anton <14254374+0xmad@users.noreply.github.com> Date: Tue, 28 Jan 2025 09:27:28 -0600 Subject: [PATCH] feat(relayer): add auth guard for message publishing - [x] Add authorization guard - [x] Add more e2e tests with real scenarios - [x] Update cli join poll - [x] Update public circuit inputs for poll contract - [x] Update poll joined offchain circuit inputs - [x] Use actual state tree depth for poll joined circuit - [x] Save poll state tree roots for public inputs onchain --- .github/workflows/relayer-build.yml | 52 ++++++ apps/relayer/.env.example | 2 + apps/relayer/.gitignore | 3 + apps/relayer/package.json | 2 + apps/relayer/tests/constants.ts | 38 +++++ apps/relayer/tests/messages.test.ts | 159 +++++++++++++++++- .../message/__tests__/message.guard.test.ts | 129 ++++++++++++++ apps/relayer/ts/message/__tests__/utils.ts | 2 +- apps/relayer/ts/message/message.controller.ts | 9 +- .../ts/message/{dto.ts => message.dto.ts} | 30 +++- apps/relayer/ts/message/message.guard.ts | 109 ++++++++++++ apps/relayer/ts/message/message.repository.ts | 2 +- apps/relayer/ts/message/message.service.ts | 2 +- .../__tests__/messageBatch.service.test.ts | 2 +- .../ts/messageBatch/__tests__/utils.ts | 2 +- .../{dto.ts => messageBatch.dto.ts} | 0 .../messageBatch/messageBatch.repository.ts | 2 +- .../ts/messageBatch/messageBatch.service.ts | 2 +- package.json | 2 + packages/circuits/circom/anon/poll.circom | 5 +- .../circuits/ts/__tests__/PollJoined.test.ts | 4 +- packages/circuits/ts/types.ts | 1 + packages/cli/ts/commands/index.ts | 2 +- packages/cli/ts/commands/joinPoll.ts | 11 +- packages/cli/ts/index.ts | 1 + packages/cli/ts/utils/interfaces.ts | 15 ++ packages/contracts/contracts/Poll.sol | 32 ++-- .../contracts/contracts/trees/LazyIMT.sol | 4 +- packages/core/ts/Poll.ts | 22 ++- packages/core/ts/utils/types.ts | 2 + pnpm-lock.yaml | 6 + 31 files changed, 607 insertions(+), 47 deletions(-) create mode 100644 apps/relayer/tests/constants.ts create mode 100644 apps/relayer/ts/message/__tests__/message.guard.test.ts rename apps/relayer/ts/message/{dto.ts => message.dto.ts} (75%) create mode 100644 apps/relayer/ts/message/message.guard.ts rename apps/relayer/ts/messageBatch/{dto.ts => messageBatch.dto.ts} (100%) diff --git a/.github/workflows/relayer-build.yml b/.github/workflows/relayer-build.yml index 9997ea183..8eb950f9d 100644 --- a/.github/workflows/relayer-build.yml +++ b/.github/workflows/relayer-build.yml @@ -10,6 +10,7 @@ env: TTL: ${{ vars.RELAYER_TTL }} LIMIT: ${{ vars.RELAYER_LIMIT }} ALLOWED_ORIGINS: ${{ vars.ALLOWED_ORIGINS }} + MAX_MESSAGES: ${{ vars.RELAYER_MAX_MESSAGES || 20 }} MONGO_DB_URI: ${{ secrets.RELAYER_MONGO_DB_URI }} MONGODB_USER: ${{ secrets.MONGODB_USER }} MONGODB_PASSWORD: ${{ secrets.MONGODB_PASSWORD }} @@ -36,6 +37,34 @@ jobs: node-version: 20 cache: "pnpm" + - name: Get changed files + id: get-changed-files + uses: jitterbit/get-changed-files@v1 + with: + format: "csv" + + - name: Check for changes in 'circuit' folder + id: check_changes + run: | + CHANGED_FILES=${{ steps.get-changed-files.outputs.all }} + if echo "$CHANGED_FILES" | grep -q "\.circom"; then + echo "CHANGED=true" >> $GITHUB_ENV + echo "Circuits have changes." + else + echo "CHANGED=false" >> $GITHUB_ENV + echo "No changes on circuits." + fi + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install --yes \ + build-essential \ + libgmp-dev \ + libsodium-dev \ + nasm \ + nlohmann-json3-dev + - name: Install run: | pnpm install --frozen-lockfile --prefer-offline @@ -44,6 +73,29 @@ jobs: run: | pnpm run build + - name: Download rapidsnark (1c137) + run: | + mkdir -p ~/rapidsnark/build + wget -qO ~/rapidsnark/build/prover https://maci-devops-zkeys.s3.ap-northeast-2.amazonaws.com/rapidsnark-linux-amd64-1c137 + chmod +x ~/rapidsnark/build/prover + + - name: Download circom Binary v2.1.6 + run: | + wget -qO ${{ github.workspace }}/circom https://github.com/iden3/circom/releases/download/v2.1.6/circom-linux-amd64 + chmod +x ${{ github.workspace }}/circom + sudo mv ${{ github.workspace }}/circom /bin/circom + + - name: Compile Circuits And Generate zkeys + if: ${{ env.CHANGED == 'true' }} + run: | + pnpm build:circuits-c -- --outPath ../../apps/relayer/zkeys + pnpm setup:zkeys -- --outPath ../../apps/relayer/zkeys + + - name: Download zkeys + if: ${{ env.CHANGED == 'false' }} + run: | + pnpm download-zkeys:test:relayer + - name: Run hardhat run: | pnpm run hardhat & diff --git a/apps/relayer/.env.example b/apps/relayer/.env.example index b0628a7ce..3905512b0 100644 --- a/apps/relayer/.env.example +++ b/apps/relayer/.env.example @@ -19,3 +19,5 @@ PORT= # Mnemonic phrase MNEMONIC="" + +MAX_MESSAGES=20 diff --git a/apps/relayer/.gitignore b/apps/relayer/.gitignore index 801ee0564..122d244bb 100644 --- a/apps/relayer/.gitignore +++ b/apps/relayer/.gitignore @@ -1,3 +1,6 @@ build/ coverage/ .env +zkeys +deploy-config.json +deployed-contracts.json diff --git a/apps/relayer/package.json b/apps/relayer/package.json index a5da2735b..b8b1dde01 100644 --- a/apps/relayer/package.json +++ b/apps/relayer/package.json @@ -45,6 +45,8 @@ "helia": "^5.2.0", "helmet": "^8.0.0", "lodash": "^4.17.21", + "maci-circuits": "workspace:^2.5.0", + "maci-cli": "workspace:^2.5.0", "maci-contracts": "workspace:^2.5.0", "maci-domainobjs": "workspace:^2.5.0", "mongoose": "^8.9.5", diff --git a/apps/relayer/tests/constants.ts b/apps/relayer/tests/constants.ts new file mode 100644 index 000000000..b6f5386a0 --- /dev/null +++ b/apps/relayer/tests/constants.ts @@ -0,0 +1,38 @@ +import { homedir } from "os"; +import path from "path"; +import url from "url"; + +export const STATE_TREE_DEPTH = 10; +export const INT_STATE_TREE_DEPTH = 1; +export const VOTE_OPTION_TREE_DEPTH = 2; +export const MESSAGE_BATCH_SIZE = 20; + +export const dirname = url.fileURLToPath(new URL(".", import.meta.url)); + +export const pollJoiningZkey = path.resolve(dirname, "../zkeys/PollJoining_10_test/PollJoining_10_test.0.zkey"); +export const pollJoinedZkey = path.resolve(dirname, "../zkeys/PollJoined_10_test/PollJoined_10_test.0.zkey"); +export const pollWasm = path.resolve( + dirname, + "../zkeys/PollJoining_10_test/PollJoining_10_test_js/PollJoining_10_test.wasm", +); +export const pollWitgen = path.resolve( + dirname, + "../zkeys/PollJoining_10_test/PollJoining_10_test_cpp/PollJoining_10_test", +); +export const pollJoinedWasm = path.resolve( + dirname, + "../zkeys/PollJoined_10_test/PollJoined_10_test_js/PollJoined_10_test.wasm", +); +export const pollJoinedWitgen = path.resolve( + dirname, + "../zkeys/PollJoined_10_test/PollJoined_10_test_cpp/PollJoined_10_test", +); +export const rapidsnark = `${homedir()}/rapidsnark/build/prover`; +export const processMessagesZkeyPathNonQv = path.resolve( + dirname, + "../zkeys/ProcessMessagesNonQv_10-20-2_test/ProcessMessagesNonQv_10-20-2_test.0.zkey", +); +export const tallyVotesZkeyPathNonQv = path.resolve( + dirname, + "../zkeys/TallyVotesNonQv_10-1-2_test/TallyVotesNonQv_10-1-2_test.0.zkey", +); diff --git a/apps/relayer/tests/messages.test.ts b/apps/relayer/tests/messages.test.ts index 401d3971a..6ea12737d 100644 --- a/apps/relayer/tests/messages.test.ts +++ b/apps/relayer/tests/messages.test.ts @@ -1,6 +1,10 @@ +import { jest } from "@jest/globals"; import { HttpStatus, ValidationPipe, type INestApplication } from "@nestjs/common"; import { Test } from "@nestjs/testing"; -import { ZeroAddress } from "ethers"; +import hardhat from "hardhat"; +import { genProof } from "maci-circuits"; +import { deploy, deployPoll, deployVkRegistryContract, joinPoll, setVerifyingKeys, signup } from "maci-cli"; +import { formatProofForVerifierContract, genMaciStateFromContract } from "maci-contracts"; import { Keypair } from "maci-domainobjs"; import request from "supertest"; @@ -8,10 +12,102 @@ import type { App } from "supertest/types"; import { AppModule } from "../ts/app.module"; +import { + INT_STATE_TREE_DEPTH, + MESSAGE_BATCH_SIZE, + STATE_TREE_DEPTH, + VOTE_OPTION_TREE_DEPTH, + pollJoinedZkey, + pollJoiningZkey, + processMessagesZkeyPathNonQv, + tallyVotesZkeyPathNonQv, + pollWasm, + pollWitgen, + rapidsnark, + pollJoinedWitgen, + pollJoinedWasm, +} from "./constants"; + +jest.unmock("maci-contracts/typechain-types"); + describe("Integration messages", () => { let app: INestApplication; + let circuitInputs: Record; + let stateLeafIndex: number; + let maciContractAddress: string; + + const coordinatorKeypair = new Keypair(); + const user = new Keypair(); beforeAll(async () => { + const [signer] = await hardhat.ethers.getSigners(); + + const vkRegistry = await deployVkRegistryContract({ signer }); + await setVerifyingKeys({ + quiet: true, + vkRegistry, + stateTreeDepth: STATE_TREE_DEPTH, + intStateTreeDepth: INT_STATE_TREE_DEPTH, + voteOptionTreeDepth: VOTE_OPTION_TREE_DEPTH, + messageBatchSize: MESSAGE_BATCH_SIZE, + processMessagesZkeyPathNonQv, + tallyVotesZkeyPathNonQv, + pollJoiningZkeyPath: pollJoiningZkey, + pollJoinedZkeyPath: pollJoinedZkey, + useQuadraticVoting: false, + signer, + }); + + const maciAddresses = await deploy({ stateTreeDepth: 10, signer }); + + maciContractAddress = maciAddresses.maciAddress; + + await deployPoll({ + pollDuration: 30, + intStateTreeDepth: INT_STATE_TREE_DEPTH, + messageBatchSize: MESSAGE_BATCH_SIZE, + voteOptionTreeDepth: VOTE_OPTION_TREE_DEPTH, + coordinatorPubkey: coordinatorKeypair.pubKey.serialize(), + useQuadraticVoting: false, + signer, + }); + + await signup({ maciAddress: maciAddresses.maciAddress, maciPubKey: user.pubKey.serialize(), signer }); + + const { pollStateIndex, timestamp, voiceCredits } = await joinPoll({ + maciAddress: maciAddresses.maciAddress, + pollId: 0n, + privateKey: user.privKey.serialize(), + stateIndex: 1n, + pollJoiningZkey, + pollWasm, + pollWitgen, + rapidsnark, + signer, + useWasm: true, + quiet: true, + }); + + const maciState = await genMaciStateFromContract( + signer.provider, + maciAddresses.maciAddress, + coordinatorKeypair, + 0n, + ); + + const poll = maciState.polls.get(0n); + + poll!.updatePoll(BigInt(maciState.pubKeys.length)); + + stateLeafIndex = Number(pollStateIndex); + + circuitInputs = poll!.joinedCircuitInputs({ + maciPrivKey: user.privKey, + stateLeafIndex: BigInt(pollStateIndex), + voiceCreditsBalance: BigInt(voiceCredits), + joinTimestamp: BigInt(timestamp), + }) as unknown as typeof circuitInputs; + const moduleFixture = await Test.createTestingModule({ imports: [AppModule], }).compile(); @@ -29,8 +125,9 @@ describe("Integration messages", () => { const keypair = new Keypair(); const defaultSaveMessagesArgs = { - maciContractAddress: ZeroAddress, + maciContractAddress, poll: 0, + stateLeafIndex, messages: [ { data: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], @@ -39,10 +136,40 @@ describe("Integration messages", () => { ], }; + test("should throw an error if there is no valid proof", async () => { + const result = await request(app.getHttpServer() as App) + .post("/v1/messages/publish") + .send({ + ...defaultSaveMessagesArgs, + maciContractAddress, + stateLeafIndex, + poll: 0, + proof: ["0", "0", "0", "0", "0", "0", "0", "0"], + }) + .expect(HttpStatus.FORBIDDEN); + + expect(result.body).toStrictEqual({ + error: "Forbidden", + statusCode: HttpStatus.FORBIDDEN, + message: "Forbidden resource", + }); + }); + test("should throw an error if dto is invalid", async () => { + const { proof } = await genProof({ + inputs: circuitInputs, + zkeyPath: pollJoinedZkey, + useWasm: true, + rapidsnarkExePath: rapidsnark, + witnessExePath: pollJoinedWitgen, + wasmPath: pollJoinedWasm, + }); + const result = await request(app.getHttpServer() as App) .post("/v1/messages/publish") .send({ + stateLeafIndex, + proof: formatProofForVerifierContract(proof), maciContractAddress: "invalid", poll: "-1", messages: [], @@ -62,10 +189,22 @@ describe("Integration messages", () => { }); test("should throw an error if messages dto is invalid", async () => { + const { proof } = await genProof({ + inputs: circuitInputs, + zkeyPath: pollJoinedZkey, + useWasm: true, + rapidsnarkExePath: rapidsnark, + witnessExePath: pollJoinedWitgen, + wasmPath: pollJoinedWasm, + }); + const result = await request(app.getHttpServer() as App) .post("/v1/messages/publish") .send({ ...defaultSaveMessagesArgs, + maciContractAddress, + stateLeafIndex, + proof: formatProofForVerifierContract(proof), messages: [{ data: [], publicKey: "invalid" }], }) .expect(HttpStatus.BAD_REQUEST); @@ -78,9 +217,23 @@ describe("Integration messages", () => { }); test("should publish user messages properly", async () => { + const { proof } = await genProof({ + inputs: circuitInputs, + zkeyPath: pollJoinedZkey, + useWasm: true, + rapidsnarkExePath: rapidsnark, + witnessExePath: pollJoinedWitgen, + wasmPath: pollJoinedWasm, + }); + const result = await request(app.getHttpServer() as App) .post("/v1/messages/publish") - .send(defaultSaveMessagesArgs) + .send({ + ...defaultSaveMessagesArgs, + maciContractAddress, + stateLeafIndex, + proof: formatProofForVerifierContract(proof), + }) .expect(HttpStatus.CREATED); expect(result.status).toBe(HttpStatus.CREATED); diff --git a/apps/relayer/ts/message/__tests__/message.guard.test.ts b/apps/relayer/ts/message/__tests__/message.guard.test.ts new file mode 100644 index 000000000..bdcf983f7 --- /dev/null +++ b/apps/relayer/ts/message/__tests__/message.guard.test.ts @@ -0,0 +1,129 @@ +import { jest } from "@jest/globals"; +import { HttpException, type ExecutionContext } from "@nestjs/common"; +import { Reflector } from "@nestjs/core"; +import dotenv from "dotenv"; +import { ZeroAddress } from "ethers"; +import { MACI__factory as MACIFactory, Poll__factory as PollFactory } from "maci-contracts/typechain-types"; +import { Keypair } from "maci-domainobjs"; + +import { MessageGuard, PUBLIC_METADATA_KEY, Public } from "../message.guard"; + +dotenv.config(); + +jest.mock("maci-contracts/typechain-types", (): unknown => ({ + MACI__factory: { + connect: jest.fn(), + }, + Poll__factory: { + connect: jest.fn(), + }, +})); + +describe("MessageGuard", () => { + const mockRequest = { + body: { + stateLeafIndex: 0, + proof: ["0", "0", "0", "0", "0", "0", "0", "0"], + poll: 0, + maciContractAddress: ZeroAddress, + messages: [ + { + data: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], + publicKey: new Keypair().pubKey.serialize(), + }, + ], + }, + }; + + const mockContext = { + getHandler: jest.fn(), + switchToHttp: jest.fn().mockReturnValue({ + getRequest: jest.fn(() => mockRequest), + }), + } as unknown as ExecutionContext; + + const reflector = { + get: jest.fn(), + } as Reflector & { get: jest.Mock }; + + const mockMaciContract = { + polls: jest.fn().mockImplementation(() => Promise.resolve({ poll: ZeroAddress })), + }; + + const mockPollContract = { + verifyJoinedPollProof: jest.fn().mockImplementation(() => Promise.resolve(true)), + }; + + beforeEach(() => { + reflector.get.mockReturnValue(false); + + mockMaciContract.polls = jest.fn().mockImplementation(() => Promise.resolve({ poll: ZeroAddress })); + mockPollContract.verifyJoinedPollProof = jest.fn().mockImplementation(() => Promise.resolve(true)); + + MACIFactory.connect = jest.fn().mockImplementation(() => mockMaciContract) as typeof MACIFactory.connect; + PollFactory.connect = jest.fn().mockImplementation(() => mockPollContract) as typeof PollFactory.connect; + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + test("should create public decorator properly", () => { + const decorator = Public(); + + expect(decorator.KEY).toBe(PUBLIC_METADATA_KEY); + }); + + test("should throw an error if there is empty body", async () => { + const ctx = { + getHandler: jest.fn(), + switchToHttp: jest.fn().mockReturnValue({ + getRequest: jest.fn(() => ({ + body: {}, + })), + }), + } as unknown as ExecutionContext; + + const guard = new MessageGuard(reflector); + + await expect(guard.canActivate(ctx)).rejects.toThrow(HttpException); + }); + + test("should return false if proof is invalid", async () => { + mockPollContract.verifyJoinedPollProof = jest.fn().mockImplementation(() => Promise.resolve(false)); + + const guard = new MessageGuard(reflector); + + const result = await guard.canActivate(mockContext); + + expect(result).toBe(false); + }); + + test("should return false if validation is failed", async () => { + const error = new Error("error"); + mockPollContract.verifyJoinedPollProof = jest.fn().mockImplementation(() => Promise.reject(error)); + + const guard = new MessageGuard(reflector); + + const result = await guard.canActivate(mockContext); + + expect(result).toBe(false); + }); + + test("should return true if proof is valid", async () => { + const guard = new MessageGuard(reflector); + + const result = await guard.canActivate(mockContext); + + expect(result).toBe(true); + }); + + test("should return true if can skip authorization", async () => { + reflector.get.mockReturnValue(true); + const guard = new MessageGuard(reflector); + + const result = await guard.canActivate(mockContext); + + expect(result).toBe(true); + }); +}); diff --git a/apps/relayer/ts/message/__tests__/utils.ts b/apps/relayer/ts/message/__tests__/utils.ts index 2aeac0278..f10351ff5 100644 --- a/apps/relayer/ts/message/__tests__/utils.ts +++ b/apps/relayer/ts/message/__tests__/utils.ts @@ -2,7 +2,7 @@ import { ZeroAddress } from "ethers"; import { Keypair } from "maci-domainobjs"; import { defaultMessageBatches } from "../../messageBatch/__tests__/utils"; -import { PublishMessagesDto } from "../dto"; +import { PublishMessagesDto } from "../message.dto"; const keypair = new Keypair(); diff --git a/apps/relayer/ts/message/message.controller.ts b/apps/relayer/ts/message/message.controller.ts index 5bf1a8a51..4b49a33a8 100644 --- a/apps/relayer/ts/message/message.controller.ts +++ b/apps/relayer/ts/message/message.controller.ts @@ -1,14 +1,15 @@ /* eslint-disable @typescript-eslint/no-shadow */ -import { Body, Controller, HttpException, HttpStatus, Logger, Post } from "@nestjs/common"; -import { ApiBearerAuth, ApiBody, ApiResponse, ApiTags } from "@nestjs/swagger"; +import { Body, Controller, HttpException, HttpStatus, Logger, Post, UseGuards } from "@nestjs/common"; +import { ApiBody, ApiResponse, ApiTags } from "@nestjs/swagger"; -import { PublishMessagesDto } from "./dto"; +import { PublishMessagesDto } from "./message.dto"; +import { MessageGuard } from "./message.guard"; import { Message } from "./message.schema"; import { MessageService } from "./message.service"; @ApiTags("v1/messages") -@ApiBearerAuth() @Controller("v1/messages") +@UseGuards(MessageGuard) export class MessageController { /** * Logger diff --git a/apps/relayer/ts/message/dto.ts b/apps/relayer/ts/message/message.dto.ts similarity index 75% rename from apps/relayer/ts/message/dto.ts rename to apps/relayer/ts/message/message.dto.ts index a5e25e701..af5b4fd14 100644 --- a/apps/relayer/ts/message/dto.ts +++ b/apps/relayer/ts/message/message.dto.ts @@ -9,6 +9,8 @@ import { ArrayMinSize, ArrayMaxSize, ValidateNested, + ArrayNotEmpty, + IsString, } from "class-validator"; import { Message } from "maci-domainobjs"; @@ -17,7 +19,7 @@ import { PublicKeyValidator } from "./validation"; /** * Max messages per batch */ -const MAX_MESSAGES = 20; +export const MAX_MESSAGES = Number(process.env.MAX_MESSAGES); /** * Data transfer object for user message @@ -50,6 +52,32 @@ export class MessageContractParamsDto { * Data transfer object for publish messages */ export class PublishMessagesDto { + /** + * State leaf index + */ + @ApiProperty({ + description: "State leaf index", + minimum: 0, + type: Number, + }) + @IsInt() + @Min(0) + stateLeafIndex!: number; + + /** + * User poll joined proof + */ + @ApiProperty({ + description: "User poll joined proof", + type: [String], + }) + @IsArray() + @ArrayNotEmpty() + @ArrayMinSize(8) + @ArrayMaxSize(8) + @IsString({ each: true }) + proof!: string[]; + /** * Poll id */ diff --git a/apps/relayer/ts/message/message.guard.ts b/apps/relayer/ts/message/message.guard.ts new file mode 100644 index 000000000..8dbbcec9b --- /dev/null +++ b/apps/relayer/ts/message/message.guard.ts @@ -0,0 +1,109 @@ +import { + Logger, + type CanActivate, + Injectable, + SetMetadata, + type ExecutionContext, + type CustomDecorator, + HttpException, + HttpStatus, +} from "@nestjs/common"; +import { Reflector } from "@nestjs/core"; +import { validate } from "class-validator"; +import hardhat from "hardhat"; +import flatMap from "lodash/flatMap"; +import flatten from "lodash/flatten"; +import map from "lodash/map"; +import values from "lodash/values"; +import { MACI__factory as MACIFactory, Poll__factory as PollFactory } from "maci-contracts/typechain-types"; + +import type { Request as Req } from "express"; + +import { MAX_MESSAGES, MessageContractParamsDto, PublishMessagesDto } from "./message.dto"; + +/** + * Public metadata key + */ +export const PUBLIC_METADATA_KEY = "isPublic"; + +/** + * Public decorator to by-pass auth checks + * + * @returns public decorator + */ +export const Public = (): CustomDecorator => SetMetadata(PUBLIC_METADATA_KEY, true); + +@Injectable() +export class MessageGuard implements CanActivate { + /** + * Logger + */ + private readonly logger: Logger; + + /** + * Initialized MessageGuard + * + * @param reflector Reflector + */ + constructor(private readonly reflector: Reflector) { + this.logger = new Logger(MessageGuard.name); + } + + /** + * This function should return a boolean, indicating whether the request is allowed or not based on proof. + * + * @param ctx - execution context + * @returns whether the request is allowed or not + */ + async canActivate(ctx: ExecutionContext): Promise { + const isPublic = this.reflector.get(PUBLIC_METADATA_KEY, ctx.getHandler()); + + if (isPublic) { + return true; + } + + const request = ctx.switchToHttp().getRequest>>(); + + const messages = + Array.isArray(request.body?.messages) && request.body.messages.length <= MAX_MESSAGES + ? request.body.messages + : []; + + const messageErrors = await Promise.all( + map(messages, (message) => validate(Object.assign(new MessageContractParamsDto(), message))), + ).then((errors) => flatten(errors)); + + const dto = Object.assign(new PublishMessagesDto(), request.body); + const dtoErrors = await validate(dto); + + if (dtoErrors.length > 0 || messageErrors.length > 0) { + this.logger.warn("Invalid body: ", dtoErrors); + throw new HttpException( + { + error: "Bad Request", + message: flatMap(dtoErrors, (error) => values(error.constraints)).concat( + flatMap(messageErrors, ({ constraints }) => + values(constraints).map((value, index) => `messages.${index}.${value}`), + ), + ), + statusCode: HttpStatus.BAD_REQUEST, + }, + HttpStatus.BAD_REQUEST, + ); + } + + try { + const [signer] = await hardhat.ethers.getSigners(); + const maciContract = MACIFactory.connect(dto.maciContractAddress, signer); + const pollAddresses = await maciContract.polls(dto.poll); + const pollContract = PollFactory.connect(pollAddresses.poll, signer); + + const isValid = await pollContract.verifyJoinedPollProof(dto.stateLeafIndex, dto.proof); + + return isValid; + } catch (error) { + this.logger.error("Activate error: ", error); + return false; + } + } +} diff --git a/apps/relayer/ts/message/message.repository.ts b/apps/relayer/ts/message/message.repository.ts index 5b65ac7e2..c983ee072 100644 --- a/apps/relayer/ts/message/message.repository.ts +++ b/apps/relayer/ts/message/message.repository.ts @@ -2,7 +2,7 @@ import { Injectable, Logger } from "@nestjs/common"; import { InjectModel } from "@nestjs/mongoose"; import { Model, RootFilterQuery } from "mongoose"; -import { PublishMessagesDto } from "./dto"; +import { PublishMessagesDto } from "./message.dto"; import { Message, MESSAGES_LIMIT } from "./message.schema"; /** diff --git a/apps/relayer/ts/message/message.service.ts b/apps/relayer/ts/message/message.service.ts index 231ce3d98..db034f836 100644 --- a/apps/relayer/ts/message/message.service.ts +++ b/apps/relayer/ts/message/message.service.ts @@ -1,7 +1,7 @@ import { Injectable, Logger } from "@nestjs/common"; import { Cron, CronExpression } from "@nestjs/schedule"; -import type { PublishMessagesDto } from "./dto"; +import type { PublishMessagesDto } from "./message.dto"; import { MessageBatchService } from "../messageBatch/messageBatch.service"; diff --git a/apps/relayer/ts/messageBatch/__tests__/messageBatch.service.test.ts b/apps/relayer/ts/messageBatch/__tests__/messageBatch.service.test.ts index 2b9ebe2ab..d4958796d 100644 --- a/apps/relayer/ts/messageBatch/__tests__/messageBatch.service.test.ts +++ b/apps/relayer/ts/messageBatch/__tests__/messageBatch.service.test.ts @@ -3,7 +3,7 @@ import { ZeroAddress } from "ethers"; import { MACI__factory as MACIFactory, Poll__factory as PollFactory } from "maci-contracts"; import { IpfsService } from "../../ipfs/ipfs.service"; -import { MessageBatchDto } from "../dto"; +import { MessageBatchDto } from "../messageBatch.dto"; import { MessageBatchRepository } from "../messageBatch.repository"; import { MessageBatchService } from "../messageBatch.service"; diff --git a/apps/relayer/ts/messageBatch/__tests__/utils.ts b/apps/relayer/ts/messageBatch/__tests__/utils.ts index c717c8b12..76b77c40e 100644 --- a/apps/relayer/ts/messageBatch/__tests__/utils.ts +++ b/apps/relayer/ts/messageBatch/__tests__/utils.ts @@ -1,7 +1,7 @@ import { ZeroAddress } from "ethers"; import { Keypair } from "maci-domainobjs"; -import { MessageBatchDto } from "../dto"; +import { MessageBatchDto } from "../messageBatch.dto"; const keypair = new Keypair(); diff --git a/apps/relayer/ts/messageBatch/dto.ts b/apps/relayer/ts/messageBatch/messageBatch.dto.ts similarity index 100% rename from apps/relayer/ts/messageBatch/dto.ts rename to apps/relayer/ts/messageBatch/messageBatch.dto.ts diff --git a/apps/relayer/ts/messageBatch/messageBatch.repository.ts b/apps/relayer/ts/messageBatch/messageBatch.repository.ts index bac2e6efc..b2c2ed389 100644 --- a/apps/relayer/ts/messageBatch/messageBatch.repository.ts +++ b/apps/relayer/ts/messageBatch/messageBatch.repository.ts @@ -2,7 +2,7 @@ import { Injectable, Logger } from "@nestjs/common"; import { InjectModel } from "@nestjs/mongoose"; import { Model, RootFilterQuery } from "mongoose"; -import { MessageBatchDto } from "./dto"; +import { MessageBatchDto } from "./messageBatch.dto"; import { MESSAGE_BATCHES_LIMIT, MessageBatch } from "./messageBatch.schema"; /** diff --git a/apps/relayer/ts/messageBatch/messageBatch.service.ts b/apps/relayer/ts/messageBatch/messageBatch.service.ts index 64aa635ee..0f008725c 100644 --- a/apps/relayer/ts/messageBatch/messageBatch.service.ts +++ b/apps/relayer/ts/messageBatch/messageBatch.service.ts @@ -5,7 +5,7 @@ import uniqBy from "lodash/uniqBy"; import { MACI__factory as MACIFactory, Poll__factory as PollFactory } from "maci-contracts"; import { PubKey } from "maci-domainobjs"; -import type { MessageBatchDto } from "./dto"; +import type { MessageBatchDto } from "./messageBatch.dto"; import { IpfsService } from "../ipfs/ipfs.service"; diff --git a/package.json b/package.json index e18ab4dcd..60def844e 100644 --- a/package.json +++ b/package.json @@ -13,7 +13,9 @@ "clean": "lerna exec -- rm -rf node_modules build && rm -rf node_modules", "commit": "git cz", "download-zkeys:test": "ts-node ./.github/scripts/downloadZkeys.ts test ./packages/cli/zkeys", + "download-zkeys:test:relayer": "ts-node ./.github/scripts/downloadZkeys.ts test ./apps/relayer/zkeys", "download-zkeys:ceremony": "ts-node ./.github/scripts/downloadZkeys.ts prod ./packages/cli/zkeys", + "download-zkeys:ceremony:relayer": "ts-node ./.github/scripts/downloadZkeys.ts prod ./apps/relayer/zkeys", "prettier": "prettier -c .", "prettier:fix": "prettier -w .", "lint:ts": "eslint './**/**/*.ts' './**/**/*.tsx'", diff --git a/packages/circuits/circom/anon/poll.circom b/packages/circuits/circom/anon/poll.circom index 8f3e71e93..52530814e 100644 --- a/packages/circuits/circom/anon/poll.circom +++ b/packages/circuits/circom/anon/poll.circom @@ -66,6 +66,8 @@ template PollJoined(stateTreeDepth) { signal input pathIndices[stateTreeDepth]; // Poll State tree root which proves the user is joined signal input stateRoot; + // The actual tree depth (might be <= stateTreeDepth) Used in BinaryMerkleRoot + signal input actualStateTreeDepth; // User private to public key var derivedPubKey[2] = PrivToPubKey()(privKey); @@ -73,8 +75,9 @@ template PollJoined(stateTreeDepth) { var stateLeaf = PoseidonHasher(4)([derivedPubKey[0], derivedPubKey[1], voiceCreditsBalance, joinTimestamp]); // Inclusion proof - var stateLeafQip = MerkleTreeInclusionProof(stateTreeDepth)( + var stateLeafQip = BinaryMerkleRoot(stateTreeDepth)( stateLeaf, + actualStateTreeDepth, pathIndices, pathElements ); diff --git a/packages/circuits/ts/__tests__/PollJoined.test.ts b/packages/circuits/ts/__tests__/PollJoined.test.ts index ff658299e..2749f17d8 100644 --- a/packages/circuits/ts/__tests__/PollJoined.test.ts +++ b/packages/circuits/ts/__tests__/PollJoined.test.ts @@ -22,6 +22,7 @@ describe("Poll Joined circuit", function test() { "pathElements", "pathIndices", "stateRoot", + "actualStateTreeDepth", ]; let circuit: WitnessTester; @@ -55,7 +56,6 @@ describe("Poll Joined circuit", function test() { pollId = maciState.deployPoll(timestamp + BigInt(duration), treeDepths, messageBatchSize, coordinatorKeypair); poll = maciState.polls.get(pollId)!; - poll.updatePoll(BigInt(maciState.pubKeys.length)); // Join the poll const { privKey, pubKey: pollPubKey } = users[0]; @@ -84,6 +84,8 @@ describe("Poll Joined circuit", function test() { poll.publishMessage(message, ecdhKeypair.pubKey); + poll.updatePoll(BigInt(maciState.pubKeys.length)); + // Process messages poll.processMessages(pollId); }); diff --git a/packages/circuits/ts/types.ts b/packages/circuits/ts/types.ts index a3a41a82a..859f82d32 100644 --- a/packages/circuits/ts/types.ts +++ b/packages/circuits/ts/types.ts @@ -68,6 +68,7 @@ export interface IPollJoinedInputs { pathIndices: bigint[]; credits: bigint; stateRoot: bigint; + actualStateTreeDepth: bigint; } /** diff --git a/packages/cli/ts/commands/index.ts b/packages/cli/ts/commands/index.ts index d1ac5d93d..7a49291b0 100644 --- a/packages/cli/ts/commands/index.ts +++ b/packages/cli/ts/commands/index.ts @@ -1,7 +1,7 @@ export { deploy } from "./deploy"; export { deployPoll } from "./deployPoll"; export { getPoll } from "./poll"; -export { joinPoll, isJoinedUser } from "./joinPoll"; +export { joinPoll, isJoinedUser, generateAndVerifyProof } from "./joinPoll"; export { deployVkRegistryContract } from "./deployVkRegistry"; export { genKeyPair } from "./genKeyPair"; export { genMaciPubKey } from "./genPubKey"; diff --git a/packages/cli/ts/commands/joinPoll.ts b/packages/cli/ts/commands/joinPoll.ts index a6adb2e53..e10597b59 100644 --- a/packages/cli/ts/commands/joinPoll.ts +++ b/packages/cli/ts/commands/joinPoll.ts @@ -291,6 +291,9 @@ export const joinPoll = async ({ const pollJoiningVk = await extractVk(pollJoiningZkey); let pollStateIndex = ""; + let voiceCredits = ""; + let timestamp = ""; + let privateKeyNullifier = ""; let receipt: ContractTransactionReceipt | null = null; const sgData = sgDataArg || DEFAULT_SG_DATA; @@ -330,7 +333,7 @@ export const joinPoll = async ({ if (receipt?.logs) { const [log] = receipt.logs; const { args } = iface.parseLog(log as unknown as { topics: string[]; data: string }) || { args: [] }; - [, , , , , pollStateIndex] = args; + [, , voiceCredits, timestamp, privateKeyNullifier, pollStateIndex] = args; logGreen(quiet, success(`State index: ${pollStateIndex.toString()}`)); } else { logError("Unable to retrieve the transaction receipt"); @@ -341,6 +344,9 @@ export const joinPoll = async ({ return { pollStateIndex: pollStateIndex ? pollStateIndex.toString() : "", + voiceCredits: voiceCredits ? voiceCredits.toString() : "", + timestamp: timestamp ? timestamp.toString() : "", + nullifier: privateKeyNullifier ? privateKeyNullifier.toString() : "", hash: receipt!.hash, }; }; @@ -356,6 +362,7 @@ const parsePollJoinEvents = async ({ }: IParsePollJoinEventsArgs): Promise<{ pollStateIndex?: string; voiceCredits?: string; + timestamp?: string; }> => { // 1000 blocks at a time for (let block = startBlock; block <= currentBlock; block += BLOCKS_STEP) { @@ -374,6 +381,7 @@ const parsePollJoinEvents = async ({ return { pollStateIndex: event.args[5].toString(), voiceCredits: event.args[2].toString(), + timestamp: event.args[3].toString(), }; } } @@ -381,6 +389,7 @@ const parsePollJoinEvents = async ({ return { pollStateIndex: undefined, voiceCredits: undefined, + timestamp: undefined, }; }; diff --git a/packages/cli/ts/index.ts b/packages/cli/ts/index.ts index d4e3fb315..2f8ea3aed 100644 --- a/packages/cli/ts/index.ts +++ b/packages/cli/ts/index.ts @@ -832,6 +832,7 @@ export { verify, joinPoll, isJoinedUser, + generateAndVerifyProof, } from "./commands"; export type { diff --git a/packages/cli/ts/utils/interfaces.ts b/packages/cli/ts/utils/interfaces.ts index 50bfc4bce..a5ecaa852 100644 --- a/packages/cli/ts/utils/interfaces.ts +++ b/packages/cli/ts/utils/interfaces.ts @@ -484,6 +484,21 @@ export interface IJoinPollData { */ pollStateIndex: string; + /** + * Voice credits balance + */ + voiceCredits: string; + + /** + * Joining poll timestamp + */ + timestamp: string; + + /** + * Private key nullifier + */ + nullifier: string; + /** * The join poll transaction hash */ diff --git a/packages/contracts/contracts/Poll.sol b/packages/contracts/contracts/Poll.sol index e42e01e4a..620ea82f3 100644 --- a/packages/contracts/contracts/Poll.sol +++ b/packages/contracts/contracts/Poll.sol @@ -99,6 +99,10 @@ contract Poll is Params, Utilities, SnarkCommon, IPoll { /// @notice The Id of this poll uint256 public immutable pollId; + /// @notice The array of the poll state tree roots for each poll join + /// For the N'th poll join, the poll state tree root will be stored at the index N + uint256[] public pollStateRootsOnJoin; + error VotingPeriodOver(); error VotingPeriodNotOver(); error PollAlreadyInit(); @@ -193,6 +197,7 @@ contract Poll is Params, Utilities, SnarkCommon, IPoll { InternalLazyIMT._init(pollStateTree, extContracts.maci.stateTreeDepth()); InternalLazyIMT._insert(pollStateTree, BLANK_STATE_LEAF_HASH); + pollStateRootsOnJoin.push(BLANK_STATE_LEAF_HASH); emit PublishMessage(_message, _padKey); } @@ -353,7 +358,10 @@ contract Poll is Params, Utilities, SnarkCommon, IPoll { // Store user in the pollStateTree uint256 stateLeaf = hashStateLeaf(StateLeaf(_pubKey, voiceCreditBalance, block.timestamp)); - InternalLazyIMT._insert(pollStateTree, stateLeaf); + uint256 stateRoot = InternalLazyIMT._insert(pollStateTree, stateLeaf); + + // Store the current state tree root in the array + pollStateRootsOnJoin.push(stateRoot); uint256 pollStateIndex = pollStateTree.numberOfLeaves - 1; emit PollJoined(_pubKey.x, _pubKey.y, voiceCreditBalance, block.timestamp, _nullifier, pollStateIndex); @@ -385,14 +393,9 @@ contract Poll is Params, Utilities, SnarkCommon, IPoll { /// @notice Verify the proof for joined Poll /// @param _index Index of the MACI's stateRootOnSignUp when the user signed up - /// @param _pubKey Poll user's public key /// @param _proof The zk-SNARK proof /// @return isValid Whether the proof is valid - function verifyJoinedPollProof( - uint256 _index, - PubKey calldata _pubKey, - uint256[8] memory _proof - ) public view returns (bool isValid) { + function verifyJoinedPollProof(uint256 _index, uint256[8] memory _proof) public view returns (bool isValid) { // Get the verifying key from the VkRegistry VerifyingKey memory vk = extContracts.vkRegistry.getPollJoinedVk( extContracts.maci.stateTreeDepth(), @@ -400,7 +403,7 @@ contract Poll is Params, Utilities, SnarkCommon, IPoll { ); // Generate the circuit public input - uint256[] memory circuitPublicInputs = getPublicJoinedCircuitInputs(_index, _pubKey); + uint256[] memory circuitPublicInputs = getPublicJoinedCircuitInputs(_index); isValid = extContracts.verifier.verify(_proof, vk, circuitPublicInputs); } @@ -426,18 +429,11 @@ contract Poll is Params, Utilities, SnarkCommon, IPoll { /// @notice Get public circuit inputs for poll joined circuit /// @param _index Index of the MACI's stateRootOnSignUp when the user signed up - /// @param _pubKey Poll user's public key /// @return publicInputs Public circuit inputs - function getPublicJoinedCircuitInputs( - uint256 _index, - PubKey calldata _pubKey - ) public view returns (uint256[] memory publicInputs) { - publicInputs = new uint256[](4); + function getPublicJoinedCircuitInputs(uint256 _index) public view returns (uint256[] memory publicInputs) { + publicInputs = new uint256[](1); - publicInputs[0] = _pubKey.x; - publicInputs[1] = _pubKey.y; - publicInputs[2] = extContracts.maci.getStateRootOnIndexedSignUp(_index); - publicInputs[3] = pollId; + publicInputs[0] = pollStateRootsOnJoin[_index]; } /// @inheritdoc IPoll diff --git a/packages/contracts/contracts/trees/LazyIMT.sol b/packages/contracts/contracts/trees/LazyIMT.sol index 091f5b584..f8f3140ba 100644 --- a/packages/contracts/contracts/trees/LazyIMT.sol +++ b/packages/contracts/contracts/trees/LazyIMT.sol @@ -125,7 +125,7 @@ library InternalLazyIMT { /// @notice Inserts a leaf into the LazyIMT /// @param self The LazyIMTData /// @param leaf The leaf to insert - function _insert(LazyIMTData storage self, uint256 leaf) internal { + function _insert(LazyIMTData storage self, uint256 leaf) internal returns (uint256) { uint40 index = self.numberOfLeaves; self.numberOfLeaves = index + 1; @@ -143,6 +143,8 @@ library InternalLazyIMT { i++; } } + + return hash; } /// @notice Returns the root of the LazyIMT diff --git a/packages/core/ts/Poll.ts b/packages/core/ts/Poll.ts index 01f2d8b2b..340fdc0bd 100644 --- a/packages/core/ts/Poll.ts +++ b/packages/core/ts/Poll.ts @@ -496,18 +496,19 @@ export class Poll implements IPoll { voiceCreditsBalance, joinTimestamp, }: IJoinedCircuitArgs): IPollJoinedCircuitInputs => { - // copy a poll state tree - const pollStateTree = new IncrementalQuinTree(this.stateTreeDepth, blankStateLeafHash, STATE_TREE_ARITY, hash2); - - this.pollStateLeaves.forEach((stateLeaf) => { - pollStateTree.insert(stateLeaf.hash()); - }); - // calculate the path elements for the state tree given the original state tree - const { pathElements, pathIndices } = pollStateTree.genProof(Number(stateLeafIndex)); + const { pathElements, pathIndices } = this.pollStateTree!.genProof(Number(stateLeafIndex)); // Get poll state tree's root - const stateRoot = pollStateTree.root; + const stateRoot = this.pollStateTree!.root; + const elementsLength = pathIndices.length; + + for (let i = 0; i < this.stateTreeDepth; i += 1) { + if (i >= elementsLength) { + pathElements[i] = [0n]; + pathIndices[i] = 0; + } + } const circuitInputs = { privKey: maciPrivKey.asCircuitInputs(), @@ -515,6 +516,7 @@ export class Poll implements IPoll { voiceCreditsBalance: voiceCreditsBalance.toString(), joinTimestamp: joinTimestamp.toString(), pathIndices: pathIndices.map((item) => item.toString()), + actualStateTreeDepth: BigInt(this.actualStateTreeDepth), stateRoot, }; @@ -1432,6 +1434,7 @@ export class Poll implements IPoll { numBatchesProcessed: this.numBatchesProcessed, numSignups: this.numSignups.toString(), chainHash: this.chainHash.toString(), + pollNullifiers: [...this.pollNullifiers.keys()].map((nullifier) => nullifier.toString()), batchHashes: this.batchHashes.map((batchHash) => batchHash.toString()), }; } @@ -1456,6 +1459,7 @@ export class Poll implements IPoll { poll.numBatchesProcessed = json.numBatchesProcessed; poll.chainHash = BigInt(json.chainHash); poll.batchHashes = json.batchHashes.map((batchHash: string) => BigInt(batchHash)); + poll.pollNullifiers = new Map(json.pollNullifiers.map((nullifier) => [BigInt(nullifier), true])); // copy maci state poll.updatePoll(BigInt(json.numSignups)); diff --git a/packages/core/ts/utils/types.ts b/packages/core/ts/utils/types.ts index edb1c223d..8c1f93ea5 100644 --- a/packages/core/ts/utils/types.ts +++ b/packages/core/ts/utils/types.ts @@ -102,6 +102,7 @@ export interface IJsonPoll { numBatchesProcessed: number; numSignups: string; chainHash: string; + pollNullifiers: string[]; batchHashes: string[]; } @@ -178,6 +179,7 @@ export interface IPollJoinedCircuitInputs { pathElements: string[][]; pathIndices: string[]; stateRoot: string; + actualStateTreeDepth: string; } /** diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index a5f4df047..3ca5d8506 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -161,6 +161,12 @@ importers: lodash: specifier: ^4.17.21 version: 4.17.21 + maci-circuits: + specifier: workspace:^2.5.0 + version: link:../../packages/circuits + maci-cli: + specifier: workspace:^2.5.0 + version: link:../../packages/cli maci-contracts: specifier: workspace:^2.5.0 version: link:../../packages/contracts