diff --git a/examples/factorial.py b/examples/factorial.py index 6963a3f9..8648a87b 100644 --- a/examples/factorial.py +++ b/examples/factorial.py @@ -106,7 +106,7 @@ def test_zero(self): ) assert last_assignments[0] == 0 # i assert last_assignments[1] == 1 # x - factorial.halo2_mock_prover(factorial_witness) + factorial.halo2_mock_prover(factorial_witness, "examples/ptau/hermez-raw-11") def test_basic(self): factorial = Factorial() @@ -116,7 +116,7 @@ def test_basic(self): ) assert last_assignments[0] == 7 # i assert last_assignments[1] == 5040 # x - factorial.halo2_mock_prover(factorial_witness) + factorial.halo2_mock_prover(factorial_witness, "examples/ptau/hermez-raw-11") if __name__ == "__main__": diff --git a/examples/fibonacci.py b/examples/fibonacci.py index be0047b0..371e7414 100644 --- a/examples/fibonacci.py +++ b/examples/fibonacci.py @@ -80,9 +80,9 @@ def trace(self, n): fibo = Fibonacci() fibo_witness = fibo.gen_witness(7) fibo.halo2_mock_prover( - fibo_witness + fibo_witness, "examples/ptau/hermez-raw-11" ) # 2^k specifies the number of PLONKish table rows in Halo2 another_fibo_witness = fibo.gen_witness(4) -fibo.halo2_mock_prover(another_fibo_witness) +fibo.halo2_mock_prover(another_fibo_witness, "examples/ptau/hermez-raw-11") fibo.to_pil(fibo_witness, "FiboCircuit") diff --git a/examples/mimc7.py b/examples/mimc7.py index a23e78c1..68c70c9c 100644 --- a/examples/mimc7.py +++ b/examples/mimc7.py @@ -167,4 +167,4 @@ def mapping(self, x_in_value, k_value): mimc7_super_witness = mimc7.gen_witness(F(1), F(2)) # for key, value in mimc7_super_witness.items(): # print(f"{key}: {str(value)}") -mimc7.halo2_mock_prover(mimc7_super_witness) +mimc7.halo2_mock_prover(mimc7_super_witness, "examples/ptau/hermez-raw-11") diff --git a/examples/poseidon.py b/examples/poseidon.py index 1d81233b..5f8cd5df 100644 --- a/examples/poseidon.py +++ b/examples/poseidon.py @@ -18,14 +18,10 @@ def setup(self): self.pragma_num_steps(self.lens) - self.table = self.new_table( - table() - .add(self.row) - .add(self.value) - ) + self.table = self.new_table(table().add(self.row).add(self.value)) def fixed_gen(self): - for i, round_key in enumerate(self.constants[0:self.lens]): + for i, round_key in enumerate(self.constants[0 : self.lens]): self.assign(i, self.row, F(i)) self.assign(i, self.value, F(round_key)) @@ -41,14 +37,10 @@ def setup(self): self.pragma_num_steps(self.lens) - self.table = self.new_table( - table() - .add(self.row) - .add(self.value) - ) + self.table = self.new_table(table().add(self.row).add(self.value)) def fixed_gen(self): - for i, round_key in enumerate(self.matrix[0:self.lens]): + for i, round_key in enumerate(self.matrix[0 : self.lens]): self.assign(i, self.row, F(i)) self.assign(i, self.value, F(round_key)) @@ -70,19 +62,11 @@ def setup(self): else: self.constr(eq(inputs[i - 1] + constants[i], x_vec[i])) - self.add_lookup( - self.circuit.constants_table - .apply(i) - .apply(constants[i]) - ) + self.add_lookup(self.circuit.constants_table.apply(i).apply(constants[i])) self.constr( eq( - x_vec[i] - * x_vec[i] - * x_vec[i] - * x_vec[i] - * x_vec[i], + x_vec[i] * x_vec[i] * x_vec[i] * x_vec[i] * x_vec[i], sboxs[i], ) ) @@ -91,9 +75,9 @@ def setup(self): m_offset = i * param_t for j in range(0, param_t): self.add_lookup( - self.circuit.matrix_table - .apply(m_offset + j) - .apply(matrix[m_offset + j]) + self.circuit.matrix_table.apply(m_offset + j).apply( + matrix[m_offset + j] + ) ) lc = sboxs[0] * matrix[m_offset] @@ -115,7 +99,9 @@ def wg(self, round_values): self.assign(signal, F(value)) for i in range(0, self.circuit.param_t): - self.assign(self.circuit.constants[i], F(round_values["constant_values"][i])) + self.assign( + self.circuit.constants[i], F(round_values["constant_values"][i]) + ) if i < len(round_values["input_values"]): self.assign(self.circuit.inputs[i], F(round_values["input_values"][i])) else: @@ -139,27 +125,25 @@ def setup(self): for i in range(0, param_t): self.constr(eq(inputs[i] + constants[i], x_vec[i])) - self.constr(eq( - x_vec[i] - * x_vec[i] - * x_vec[i] - * x_vec[i] - * x_vec[i], - sboxs[i], - )) + self.constr( + eq( + x_vec[i] * x_vec[i] * x_vec[i] * x_vec[i] * x_vec[i], + sboxs[i], + ) + ) self.add_lookup( - self.circuit.constants_table - .apply(self.circuit.round * param_t + i) - .apply(constants[i]), + self.circuit.constants_table.apply( + self.circuit.round * param_t + i + ).apply(constants[i]), ) for i in range(0, param_t): m_offset = i * param_t for j in range(0, param_t): self.add_lookup( - self.circuit.matrix_table - .apply(m_offset + j) - .apply(matrix[m_offset + j]), + self.circuit.matrix_table.apply(m_offset + j).apply( + matrix[m_offset + j] + ), ) lc = sboxs[0] * matrix[m_offset] @@ -169,19 +153,16 @@ def setup(self): self.constr(eq(lc, outs[i])) self.transition(eq(outs[i], inputs[i].next())) - self.transition( - eq(self.circuit.round + 1, self.circuit.round.next()) - ) + self.transition(eq(self.circuit.round + 1, self.circuit.round.next())) def wg(self, round_values): - for signal, value in zip( - self.circuit.matrix, - round_values["matrix_values"] - ): + for signal, value in zip(self.circuit.matrix, round_values["matrix_values"]): self.assign(signal, F(value)) for i in range(0, self.circuit.param_t): - self.assign(self.circuit.constants[i], F(round_values["constant_values"][i])) + self.assign( + self.circuit.constants[i], F(round_values["constant_values"][i]) + ) self.assign(self.circuit.inputs[i], F(round_values["input_values"][i])) self.assign(self.circuit.x_vec[i], F(round_values["x_values"][i])) self.assign(self.circuit.sboxs[i], F(round_values["sbox_values"][i])) @@ -205,19 +186,15 @@ def setup(self): for i in range(0, param_t): self.constr(eq(inputs[i] + constants[i], x_vec[i])) self.add_lookup( - self.circuit.constants_table - .apply(self.circuit.round * param_t + i) - .apply(constants[i]), + self.circuit.constants_table.apply( + self.circuit.round * param_t + i + ).apply(constants[i]), ) for i in range(0, param_c): self.constr( eq( - x_vec[i] - * x_vec[i] - * x_vec[i] - * x_vec[i] - * x_vec[i], + x_vec[i] * x_vec[i] * x_vec[i] * x_vec[i] * x_vec[i], sboxs[i], ) ) @@ -226,9 +203,9 @@ def setup(self): m_offset = i * param_t for j in range(0, param_t): self.add_lookup( - self.circuit.matrix_table - .apply(m_offset + j) - .apply(matrix[m_offset + j]), + self.circuit.matrix_table.apply(m_offset + j).apply( + matrix[m_offset + j] + ), ) lc = sboxs[0] * matrix[m_offset] @@ -240,24 +217,21 @@ def setup(self): self.constr(eq(lc, outs[i])) self.transition(eq(outs[i], inputs[i].next())) - self.transition( - eq(self.circuit.round + 1, self.circuit.round.next()) - ) + self.transition(eq(self.circuit.round + 1, self.circuit.round.next())) def wg(self, round_values): - for signal, value in zip( - self.circuit.matrix, - round_values["matrix_values"] - ): + for signal, value in zip(self.circuit.matrix, round_values["matrix_values"]): self.assign(signal, F(value)) for i in range(0, self.circuit.param_t): - self.assign(self.circuit.constants[i], F(round_values["constant_values"][i])) + self.assign( + self.circuit.constants[i], F(round_values["constant_values"][i]) + ) self.assign(self.circuit.inputs[i], F(round_values["input_values"][i])) self.assign(self.circuit.outs[i], F(round_values["out_values"][i])) self.assign(self.circuit.x_vec[i], F(round_values["x_values"][i])) - for i, sbox in enumerate(self.circuit.sboxs[0:self.circuit.param_c]): + for i, sbox in enumerate(self.circuit.sboxs[0 : self.circuit.param_c]): self.assign(sbox, F(round_values["sbox_values"][i])) self.assign(self.circuit.round, F(round_values["round"])) @@ -276,30 +250,28 @@ def setup(self): for i in range(0, param_t): self.constr(eq(inputs[i] + constants[i], x_vec[i])) - self.constr(eq( - x_vec[i] - * x_vec[i] - * x_vec[i] - * x_vec[i] - * x_vec[i], - sboxs[i], - )) + self.constr( + eq( + x_vec[i] * x_vec[i] * x_vec[i] * x_vec[i] * x_vec[i], + sboxs[i], + ) + ) self.add_lookup( - self.circuit.constants_table - .apply(self.circuit.round * param_t + i) - .apply(constants[i]), + self.circuit.constants_table.apply( + self.circuit.round * param_t + i + ).apply(constants[i]), ) for i in range(0, param_t): m_offset = i * param_t for j in range(0, param_t): self.add_lookup( - self.circuit.matrix_table - .apply(m_offset + j) - .apply(matrix[m_offset + j]) + self.circuit.matrix_table.apply(m_offset + j).apply( + matrix[m_offset + j] + ) ) - for i, out in enumerate(outs[0:self.circuit.lens["n_outputs"]]): + for i, out in enumerate(outs[0 : self.circuit.lens["n_outputs"]]): m_offset = i * param_t lc = sboxs[0] * matrix[m_offset] for s in range(1, param_t): @@ -307,22 +279,18 @@ def setup(self): self.constr(eq(lc, out)) def wg(self, round_values): - for signal, value in zip( - self.circuit.matrix, - round_values["matrix_values"] - ): + for signal, value in zip(self.circuit.matrix, round_values["matrix_values"]): self.assign(signal, F(value)) for i in range(0, self.circuit.param_t): - self.assign(self.circuit.constants[i], F(round_values["constant_values"][i])) + self.assign( + self.circuit.constants[i], F(round_values["constant_values"][i]) + ) self.assign(self.circuit.inputs[i], F(round_values["input_values"][i])) self.assign(self.circuit.x_vec[i], F(round_values["x_values"][i])) self.assign(self.circuit.sboxs[i], F(round_values["sbox_values"][i])) - for signal, value in zip( - self.circuit.outs, - round_values["out_values"] - ): + for signal, value in zip(self.circuit.outs, round_values["out_values"]): self.assign(signal, F(value)) self.assign(self.circuit.round, F(round_values["round"])) @@ -340,8 +308,13 @@ def setup(self): assert self.lens["n_inputs"] < self.param_t assert self.lens["n_outputs"] < self.param_t - self.matrix = [self.forward("matrix_" + str(i)) for i in range(0, self.param_t * self.param_t)] - self.constants = [self.forward("constant_" + str(i)) for i in range(0, self.param_t)] + self.matrix = [ + self.forward("matrix_" + str(i)) + for i in range(0, self.param_t * self.param_t) + ] + self.constants = [ + self.forward("constant_" + str(i)) for i in range(0, self.param_t) + ] self.inputs = [self.forward("input_" + str(i)) for i in range(0, self.param_t)] self.outs = [self.forward("output_" + str(i)) for i in range(0, self.param_t)] self.sboxs = [self.forward("sbox_" + str(i)) for i in range(0, self.param_t)] @@ -349,10 +322,18 @@ def setup(self): self.round = self.forward("round") - self.step_first_round = self.step_type(PoseidonStepFirstRound(self, "step_first_round")) - self.step_full_round = self.step_type(PoseidonStepFullRound(self, "step_full_round")) - self.step_partial_round = self.step_type(PoseidonStepPartialRound(self, "step_partial_round")) - self.step_last_round = self.step_type(PoseidonStepLastRound(self, "step_last_round")) + self.step_first_round = self.step_type( + PoseidonStepFirstRound(self, "step_first_round") + ) + self.step_full_round = self.step_type( + PoseidonStepFullRound(self, "step_full_round") + ) + self.step_partial_round = self.step_type( + PoseidonStepPartialRound(self, "step_partial_round") + ) + self.step_last_round = self.step_type( + PoseidonStepLastRound(self, "step_last_round") + ) self.pragma_first_step(self.step_first_round) self.pragma_last_step(self.step_last_round) @@ -382,7 +363,7 @@ def trace(self, values): sbox_values = [] for x_value in x_values: - sbox_values.append(x_value ** 5) + sbox_values.append(x_value**5) outputs = [] for i in range(0, param_t): @@ -408,10 +389,9 @@ def trace(self, values): for i in range(1, int(param_t / 2) + 1): x_values = [ - inputs[j] + constant_values[i * param_t + j] - for j in range(0, param_t) + inputs[j] + constant_values[i * param_t + j] for j in range(0, param_t) ] - sbox_values = [x_value ** 5 for x_value in x_values] + sbox_values = [x_value**5 for x_value in x_values] def method(j): m_offset = j * param_t @@ -420,12 +400,10 @@ def method(j): out_value += sbox_values[s] * matrix_values[m_offset + s] return out_value - outputs = [ - method(j) for j in range(0, param_t) - ] + outputs = [method(j) for j in range(0, param_t)] round_values = { "input_values": inputs, - "constant_values": constant_values[i * param_t:(i + 1) * param_t], + "constant_values": constant_values[i * param_t : (i + 1) * param_t], "matrix_values": matrix_values, "x_values": x_values, "sbox_values": sbox_values, @@ -440,7 +418,7 @@ def method(j): inputs[j] + constant_values[j + int(i + param_f / 2) * param_t] for j in range(0, param_t) ] - sbox_values = [x_value ** 5 for x_value in x_values] + sbox_values = [x_value**5 for x_value in x_values] def method(t): m_offset = t * param_t @@ -451,13 +429,13 @@ def method(t): out_value += x_values[k] * matrix_values[m_offset + k] return out_value - outputs = [ - method(j) for j in range(0, param_t) - ] + outputs = [method(j) for j in range(0, param_t)] round_values = { "input_values": inputs, - "constant_values": constant_values[int(i + param_f / 2) * param_t:int(i + param_f / 2 + 1) * param_t], + "constant_values": constant_values[ + int(i + param_f / 2) * param_t : int(i + param_f / 2 + 1) * param_t + ], "matrix_values": matrix_values, "x_values": x_values, "sbox_values": sbox_values, @@ -469,10 +447,11 @@ def method(t): for i in range(0, int(param_f / 2) - 1): x_values = [ - inputs[j] + constant_values[(i + int(param_f / 2) + param_p) * param_t + j] + inputs[j] + + constant_values[(i + int(param_f / 2) + param_p) * param_t + j] for j in range(0, param_t) ] - sbox_values = [x_value ** 5 for x_value in x_values] + sbox_values = [x_value**5 for x_value in x_values] def method(j): m_offset = j * param_t @@ -481,12 +460,14 @@ def method(j): out_value += sbox_values[s] * matrix_values[m_offset + s] return out_value - outputs = [ - method(j) for j in range(0, param_t) - ] + outputs = [method(j) for j in range(0, param_t)] round_values = { "input_values": inputs, - "constant_values": constant_values[(i + int(param_f / 2) + param_p) * param_t:(i + int(param_f / 2) + param_p + 1) * param_t], + "constant_values": constant_values[ + (i + int(param_f / 2) + param_p) + * param_t : (i + int(param_f / 2) + param_p + 1) + * param_t + ], "matrix_values": matrix_values, "x_values": x_values, "sbox_values": sbox_values, @@ -500,7 +481,7 @@ def method(j): inputs[i] + constant_values[i + (param_p + param_f - 1) * param_t] for i in range(0, param_t) ] - sbox_values = [x_value ** 5 for x_value in x_values] + sbox_values = [x_value**5 for x_value in x_values] def method(i): m_offset = i * param_t @@ -509,15 +490,14 @@ def method(i): out_value += sbox_values[s] * matrix_values[m_offset + s] return out_value - outputs = [ - method(i) - for i in range(0, values["n_outputs"]) - ] + outputs = [method(i) for i in range(0, values["n_outputs"])] print("[poseidon hash] outputs = ", outputs) round_values = { "input_values": inputs, - "constant_values": constant_values[(param_p + param_f - 1) * param_t:(param_p + param_f) * param_t], + "constant_values": constant_values[ + (param_p + param_f - 1) * param_t : (param_p + param_f) * param_t + ], "matrix_values": matrix_values, "x_values": x_values, "sbox_values": sbox_values, @@ -533,9 +513,7 @@ def setup(self): self.constants_circuit = self.sub_circuit( PoseidonConstants(self, n_inputs=n_inputs) ) - self.matrix_circuit = self.sub_circuit( - PoseidonMatrix(self, n_inputs=n_inputs) - ) + self.matrix_circuit = self.sub_circuit(PoseidonMatrix(self, n_inputs=n_inputs)) self.poseidon_circuit = self.sub_circuit( PoseidonCircuit( self, @@ -553,9 +531,7 @@ class Examples: def test_basic(self): # Arrange values = { - "inputs": [ - 1, 1, 1, 1, 1, 1 - ], + "inputs": [1, 1, 1, 1, 1, 1], "n_outputs": 1, } lens = { @@ -567,7 +543,7 @@ def test_basic(self): # Act poseidon = PoseidonSuperCircuit(lens=lens) witness = poseidon.gen_witness(values) - poseidon.halo2_mock_prover(witness) + poseidon.halo2_mock_prover(witness, "examples/ptau/hermez-raw-11") # Assert circuit_trace = list(witness.values())[0] @@ -583,7 +559,8 @@ def test_basic(self): if __name__ == "__main__": x = Examples() for method in [ - method for method in dir(x) + method + for method in dir(x) if callable(getattr(x, method)) if not method.startswith("_") ]: diff --git a/src/frontend/python/chiquito/dsl.py b/src/frontend/python/chiquito/dsl.py index a8b25146..9c8e71ab 100644 --- a/src/frontend/python/chiquito/dsl.py +++ b/src/frontend/python/chiquito/dsl.py @@ -75,7 +75,9 @@ def gen_witness(self: SuperCircuit, *args: Any) -> Dict[int, TraceWitness]: ) # so that we can generate different witness mapping in the next gen_witness() call return super_witness - def halo2_mock_prover(self: SuperCircuit, super_witness: Dict[int, TraceWitness]): + def halo2_mock_prover( + self: SuperCircuit, super_witness: Dict[int, TraceWitness], params_path: str + ): witness_json = {} for rust_id, witness in super_witness.items(): if rust_id not in self.ast.sub_circuits: @@ -84,7 +86,9 @@ def halo2_mock_prover(self: SuperCircuit, super_witness: Dict[int, TraceWitness] ) witness_json[rust_id] = witness.get_witness_json() rust_chiquito.super_circuit_halo2_mock_prover( - list(self.ast.sub_circuits.keys()), witness_json + list(self.ast.sub_circuits.keys()), + witness_json, + params_path, ) @@ -216,19 +220,19 @@ def gen_witness(self: Circuit, *args) -> TraceWitness: def get_ast_json(self: Circuit) -> str: return json.dumps(self.ast, cls=CustomEncoder, indent=4) - def halo2_mock_prover(self: Circuit, witness: TraceWitness): + def halo2_mock_prover(self: Circuit, witness: TraceWitness, params_path: str): if self.rust_id == 0: ast_json: str = self.get_ast_json() - self.rust_id: int = rust_chiquito.ast_to_halo2(ast_json) + self.rust_id: int = rust_chiquito.ast_to_plonkish(ast_json) witness_json: str = witness.get_witness_json() - rust_chiquito.halo2_mock_prover(witness_json, self.rust_id) + rust_chiquito.halo2_mock_prover(witness_json, self.rust_id, params_path) def to_pil( self: Circuit, witness: TraceWitness, circuit_name: str = "Circuit" ) -> str: if self.rust_id == 0: ast_json: str = self.get_ast_json() - self.rust_id: int = rust_chiquito.ast_to_halo2(ast_json) + self.rust_id: int = rust_chiquito.ast_to_plonkish(ast_json) witness_json: str = witness.get_witness_json() rust_chiquito.to_pil(witness_json, self.rust_id, circuit_name) diff --git a/src/frontend/python/mod.rs b/src/frontend/python/mod.rs index 5323a838..c08c6ef5 100644 --- a/src/frontend/python/mod.rs +++ b/src/frontend/python/mod.rs @@ -7,7 +7,7 @@ use crate::{ frontend::dsl::{StepTypeHandler, SuperCircuitContext}, pil::backend::powdr_pil::chiquito2Pil, plonkish::{ - backend::halo2::{chiquito2Halo2, halo2_verify, ChiquitoHalo2, Halo2Provable}, + backend::halo2::{halo2_verify, Halo2Provable}, compiler::{ cell_manager::SingleRowCellManager, compile, config, step_selector::SimpleStepSelectorBuilder, PlonkishCompilationResult, @@ -30,7 +30,7 @@ use std::{cell::RefCell, collections::HashMap, fmt}; type CircuitMapStore = ( SBPIR, - ChiquitoHalo2, + Option>, Option>, ); type CircuitMap = RefCell>; @@ -51,22 +51,25 @@ impl TraceGenerator for PythonTraceGenerator { } /// Parses JSON into `ast::Circuit` and compile. Generates a Rust UUID. Inserts tuple of -/// (`ast::Circuit`, `ChiquitoHalo2`, `AssignmentGenerator`, _) to `CIRCUIT_MAP` with the Rust UUID -/// as the key. Return the Rust UUID to Python. The last field of the tuple, `TraceWitness`, is left -/// as None, for `chiquito_add_witness_to_rust_id` to insert. -pub fn chiquito_ast_to_halo2(ast_json: &str) -> UUID { +/// (`ast::Circuit`, `PlonkishCompilationResult`, `AssignmentGenerator`, _) to `CIRCUIT_MAP` with +/// the Rust UUID as the key. Return the Rust UUID to Python. The last field of the tuple, +/// `TraceWitness`, is left as None, for `chiquito_add_witness_to_rust_id` to insert. +pub fn chiquito_ast_to_plonkish(ast_json: &str) -> UUID { let circuit: SBPIR = serde_json::from_str(ast_json).expect("Json deserialization to Circuit failed."); let config = config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}); let plonkish = compile(config, &circuit); - let chiquito_halo2 = chiquito2Halo2(plonkish.circuit); let uuid = uuid(); CIRCUIT_MAP.with(|circuit_map| { circuit_map.borrow_mut().insert( uuid, - (circuit, chiquito_halo2, plonkish.assignment_generator), + ( + circuit, + Some(plonkish.clone()), + plonkish.assignment_generator, + ), ); }); @@ -83,9 +86,7 @@ pub fn chiquito_ast_map_store(ast_json: &str) -> UUID { let uuid = uuid(); CIRCUIT_MAP.with(|circuit_map| { - circuit_map - .borrow_mut() - .insert(uuid, (circuit, ChiquitoHalo2::default(), None)); + circuit_map.borrow_mut().insert(uuid, (circuit, None, None)); }); uuid @@ -173,27 +174,18 @@ fn rust_id_to_halo2(uuid: UUID) -> CircuitMapStore { }) } -/// Runs `MockProver` for a single circuit given JSON of `TraceWitness` and `rust_id` of the +/// Runs the Halo2 prover for a single circuit given JSON of `TraceWitness` and `rust_id` of the /// circuit. -pub fn chiquito_halo2_mock_prover(witness_json: &str, rust_id: UUID) { +pub fn chiquito_halo2_prover(witness_json: &str, rust_id: UUID, params_path: &str) { let trace_witness: TraceWitness = serde_json::from_str(witness_json).expect("Json deserialization to TraceWitness failed."); - let (_, compiled, assignment_generator) = rust_id_to_halo2(rust_id); - - let mut plonkish = PlonkishCompilationResult { - circuit: compiled.plonkish_ir, - assignment_generator, - }; - - let params_path = "examples/ptau/hermez-raw-11"; - let halo2_prover = plonkish.create_halo2_prover(params_path); + let (_, plonkish, assignment_generator) = rust_id_to_halo2(rust_id); - let (proof, instance) = halo2_prover.generate_proof( - plonkish - .assignment_generator - .unwrap() - .generate(trace_witness), - ); + let halo2_prover = plonkish + .expect("Plonkish compilation is missing") + .create_halo2_prover(params_path); + let (proof, instance) = + halo2_prover.generate_proof(assignment_generator.unwrap().generate(trace_witness)); let result = halo2_verify( proof, @@ -1876,8 +1868,8 @@ fn convert_and_print_trace_witness(json: &PyString) { } #[pyfunction] -fn ast_to_halo2(json: &PyString) -> u128 { - let uuid = chiquito_ast_to_halo2(json.to_str().expect("PyString conversion failed.")); +fn ast_to_plonkish(json: &PyString) -> u128 { + let uuid = chiquito_ast_to_plonkish(json.to_str().expect("PyString conversion failed.")); uuid } @@ -1885,9 +1877,9 @@ fn ast_to_halo2(json: &PyString) -> u128 { #[pyfunction] fn to_pil(witness_json: &PyString, rust_id: &PyLong, circuit_name: &PyString) -> String { let pil = chiquito_ast_to_pil( - witness_json.to_str().expect("PyString convertion failed."), - rust_id.extract().expect("PyLong convertion failed."), - circuit_name.to_str().expect("PyString convertion failed."), + witness_json.to_str().expect("PyString conversion failed."), + rust_id.extract().expect("PyLong conversion failed."), + circuit_name.to_str().expect("PyString conversion failed."), ); println!("{}", pil); @@ -1902,10 +1894,11 @@ fn ast_map_store(json: &PyString) -> u128 { } #[pyfunction] -fn halo2_mock_prover(witness_json: &PyString, rust_id: &PyLong) { - chiquito_halo2_mock_prover( +fn halo2_mock_prover(witness_json: &PyString, rust_id: &PyLong, params_path: &str) { + chiquito_halo2_prover( witness_json.to_str().expect("PyString conversion failed."), rust_id.extract().expect("PyLong conversion failed."), + params_path, ); } @@ -1946,7 +1939,7 @@ fn super_circuit_halo2_mock_prover(rust_ids: &PyList, super_witness: &PyDict) { fn rust_chiquito(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(convert_and_print_ast, m)?)?; m.add_function(wrap_pyfunction!(convert_and_print_trace_witness, m)?)?; - m.add_function(wrap_pyfunction!(ast_to_halo2, m)?)?; + m.add_function(wrap_pyfunction!(ast_to_plonkish, m)?)?; m.add_function(wrap_pyfunction!(to_pil, m)?)?; m.add_function(wrap_pyfunction!(ast_map_store, m)?)?; m.add_function(wrap_pyfunction!(halo2_mock_prover, m)?)?; diff --git a/src/plonkish/backend/halo2.rs b/src/plonkish/backend/halo2.rs index 773b4090..aef4a75c 100644 --- a/src/plonkish/backend/halo2.rs +++ b/src/plonkish/backend/halo2.rs @@ -113,7 +113,7 @@ pub trait Halo2WitnessGenerator { ) -> Vec>>; } -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug)] pub struct ChiquitoHalo2> { pub debug: bool, @@ -721,7 +721,7 @@ trait Halo2Compilable> { fn halo2_compile(&mut self) -> (WG, CompiledCircuit, u32); } -impl + Default> Halo2Compilable, ChiquitoHalo2> +impl + Clone + Default> Halo2Compilable, ChiquitoHalo2> for PlonkishCompilationResult { fn halo2_compile(&mut self) -> (ChiquitoHalo2, CompiledCircuit, u32) { @@ -737,7 +737,7 @@ impl Halo2Compilable, ChiquitoHalo2SuperCircui let compiled = self .get_sub_circuits() .iter() - .map(|c| chiquito2Halo2((*c).clone())) + .map(|c| ChiquitoHalo2::new(c.clone())) .collect(); let mut circuit = ChiquitoHalo2SuperCircuit::new(compiled); @@ -796,11 +796,3 @@ fn to_halo2_advice( _ => panic!("jarll wrong phase"), } } - -/// LEGACY -#[allow(non_snake_case)] -pub(crate) fn chiquito2Halo2 + Hash>( - circuit: Circuit, -) -> ChiquitoHalo2 { - ChiquitoHalo2::new(circuit) -} diff --git a/src/plonkish/compiler/mod.rs b/src/plonkish/compiler/mod.rs index 50a14d05..7d656f9d 100644 --- a/src/plonkish/compiler/mod.rs +++ b/src/plonkish/compiler/mod.rs @@ -54,7 +54,8 @@ pub fn compile< } } -pub struct PlonkishCompilationResult { +#[derive(Clone)] +pub struct PlonkishCompilationResult> { pub circuit: Circuit, pub assignment_generator: Option>, } diff --git a/src/plonkish/ir/assignments.rs b/src/plonkish/ir/assignments.rs index afc1cae7..711f29f2 100644 --- a/src/plonkish/ir/assignments.rs +++ b/src/plonkish/ir/assignments.rs @@ -94,7 +94,7 @@ pub struct AssignmentGenerator> { ir_id: UUID, } -impl Clone for AssignmentGenerator +impl Clone for AssignmentGenerator where TG: TraceGenerator + Clone, {