From 9392183bf3df0c886a7d79ee54913092e8a42af4 Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Fri, 9 Aug 2024 19:23:05 +0800 Subject: [PATCH 1/9] Implement new SBPIR instantiation --- examples/blake2f.rs | 17 +- examples/factorial.rs | 7 +- examples/fibo_with_padding.rs | 7 +- examples/fibonacci.rs | 7 +- examples/keccak.rs | 17 +- examples/mimc7.rs | 9 +- examples/poseidon.rs | 11 +- src/compiler/compiler.rs | 206 +++--- src/compiler/compiler_legacy.rs | 719 +++++++++++++++++++++ src/compiler/mod.rs | 7 +- src/frontend/dsl/circuit_context_legacy.rs | 209 ++++++ src/frontend/dsl/mod.rs | 33 +- src/frontend/dsl/sc.rs | 16 +- src/interpreter/mod.rs | 2 +- src/poly/mod.rs | 4 +- src/sbpir/mod.rs | 36 +- src/sbpir/sbpir_machine.rs | 73 +-- 17 files changed, 1145 insertions(+), 235 deletions(-) create mode 100644 src/compiler/compiler_legacy.rs create mode 100644 src/frontend/dsl/circuit_context_legacy.rs diff --git a/examples/blake2f.rs b/examples/blake2f.rs index 558461ce..0e2fdb46 100644 --- a/examples/blake2f.rs +++ b/examples/blake2f.rs @@ -1,10 +1,11 @@ use chiquito::{ frontend::dsl::{ cb::{eq, select, table}, + circuit_context_legacy::CircuitContextLegacy, lb::LookupTable, super_circuit, trace::DSLTraceGenerator, - CircuitContext, StepTypeSetupContext, StepTypeWGHandler, + StepTypeSetupContext, StepTypeWGHandler, }, plonkish::{ backend::halo2_legacy::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, @@ -127,7 +128,10 @@ pub fn split_to_4bits_values(vec_values: &[u64]) -> Vec(ctx: &mut CircuitContext, _: usize) -> LookupTable { +fn blake2f_iv_table( + ctx: &mut CircuitContextLegacy, + _: usize, +) -> LookupTable { let lookup_iv_row: Queriable = ctx.fixed("iv row"); let lookup_iv_value: Queriable = ctx.fixed("iv value"); @@ -144,7 +148,10 @@ fn blake2f_iv_table(ctx: &mut CircuitContext, _: usize) } // For range checking -fn blake2f_4bits_table(ctx: &mut CircuitContext, _: usize) -> LookupTable { +fn blake2f_4bits_table( + ctx: &mut CircuitContextLegacy, + _: usize, +) -> LookupTable { let lookup_4bits_row: Queriable = ctx.fixed("4bits row"); let lookup_4bits_value: Queriable = ctx.fixed("4bits value"); @@ -160,7 +167,7 @@ fn blake2f_4bits_table(ctx: &mut CircuitContext, _: usi } fn blake2f_xor_4bits_table( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, _: usize, ) -> LookupTable { let lookup_xor_row: Queriable = ctx.fixed("xor row"); @@ -526,7 +533,7 @@ fn g_setup( } fn blake2f_circuit( - ctx: &mut CircuitContext>, + ctx: &mut CircuitContextLegacy>, params: CircuitParams, ) { let v_vec: Vec> = (0..V_LEN) diff --git a/examples/factorial.rs b/examples/factorial.rs index 9bef03aa..2e0e6831 100644 --- a/examples/factorial.rs +++ b/examples/factorial.rs @@ -2,8 +2,9 @@ use std::hash::Hash; use chiquito::{ field::Field, - frontend::dsl::{circuit, trace::DSLTraceGenerator}, /* main function for constructing an AST - * circuit */ + frontend::dsl::{circuit_context_legacy::circuit_legacy, trace::DSLTraceGenerator}, /* main function for constructing + * an AST + * circuit */ plonkish::{ backend::halo2_legacy::{chiquito2Halo2, ChiquitoHalo2Circuit}, compiler::{ @@ -42,7 +43,7 @@ fn generate + Hash>() -> PlonkishCompilationResult("factorial", |ctx| { + let factorial_circuit = circuit_legacy::("factorial", |ctx| { let i = ctx.shared("i"); let x = ctx.forward("x"); diff --git a/examples/fibo_with_padding.rs b/examples/fibo_with_padding.rs index d93c5bbd..9c383eca 100644 --- a/examples/fibo_with_padding.rs +++ b/examples/fibo_with_padding.rs @@ -2,8 +2,9 @@ use std::hash::Hash; use chiquito::{ field::Field, - frontend::dsl::{circuit, trace::DSLTraceGenerator}, /* main function for constructing an AST - * circuit */ + frontend::dsl::{circuit_context_legacy::circuit_legacy, trace::DSLTraceGenerator}, /* main function for constructing + * an AST + * circuit */ plonkish::{ backend::halo2_legacy::{chiquito2Halo2, ChiquitoHalo2Circuit}, compiler::{ @@ -37,7 +38,7 @@ fn fibo_circuit + Hash>( sbpir::ExposeOffset::*, // for exposing witnesses }; - let fibo = circuit::("fibonacci", |ctx| { + let fibo = circuit_legacy::("fibonacci", |ctx| { // Example table for 7 rounds: // | step_type | a | b | c | n | // --------------------------------------- diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index a388425f..7812cae7 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -2,8 +2,9 @@ use std::hash::Hash; use chiquito::{ field::Field, - frontend::dsl::{circuit, trace::DSLTraceGenerator}, /* main function for constructing an AST - * circuit */ + frontend::dsl::{circuit_context_legacy::circuit_legacy, trace::DSLTraceGenerator}, /* main function for constructing + * an AST + * circuit */ plonkish::{ backend::{ halo2_legacy::{chiquito2Halo2, ChiquitoHalo2Circuit}, @@ -50,7 +51,7 @@ fn fibo_circuit + Hash>() -> FiboReturn { use chiquito::frontend::dsl::cb::*; // functions for constraint building - let fibo = circuit::("fibonacci", |ctx| { + let fibo = circuit_legacy::("fibonacci", |ctx| { // the following objects (forward signals, steptypes) are defined on the circuit-level // forward signals can have constraints across different steps diff --git a/examples/keccak.rs b/examples/keccak.rs index 5a22f07e..fb6c4802 100644 --- a/examples/keccak.rs +++ b/examples/keccak.rs @@ -1,6 +1,7 @@ use chiquito::{ frontend::dsl::{ - lb::LookupTable, super_circuit, trace::DSLTraceGenerator, CircuitContext, StepTypeWGHandler, + circuit_context_legacy::CircuitContextLegacy, lb::LookupTable, super_circuit, + trace::DSLTraceGenerator, StepTypeWGHandler, }, plonkish::{ backend::halo2_legacy::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, @@ -231,7 +232,7 @@ fn eval_keccak_f_to_bit_vec4>(value1: F, value2: } fn keccak_xor_table_batch2( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, lens: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -254,7 +255,7 @@ fn keccak_xor_table_batch2( } fn keccak_xor_table_batch3( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, lens: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -280,7 +281,7 @@ fn keccak_xor_table_batch3( } fn keccak_xor_table_batch4( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, lens: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -306,7 +307,7 @@ fn keccak_xor_table_batch4( } fn keccak_chi_table( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, lens: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -332,7 +333,7 @@ fn keccak_chi_table( } fn keccak_pack_table( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, _: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -362,7 +363,7 @@ fn keccak_pack_table( } fn keccak_round_constants_table( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, lens: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -722,7 +723,7 @@ fn eval_keccak_f_one_round + Eq + Hash>( } fn keccak_circuit + Eq + Hash>( - ctx: &mut CircuitContext>, + ctx: &mut CircuitContextLegacy>, param: CircuitParams, ) { use chiquito::frontend::dsl::cb::*; diff --git a/examples/mimc7.rs b/examples/mimc7.rs index 468d0868..9fed5b5d 100644 --- a/examples/mimc7.rs +++ b/examples/mimc7.rs @@ -6,7 +6,10 @@ use halo2_proofs::{ }; use chiquito::{ - frontend::dsl::{lb::LookupTable, super_circuit, trace::DSLTraceGenerator, CircuitContext}, + frontend::dsl::{ + circuit_context_legacy::CircuitContextLegacy, lb::LookupTable, super_circuit, + trace::DSLTraceGenerator, + }, plonkish::{ backend::halo2_legacy::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, compiler::{ @@ -23,7 +26,7 @@ use mimc7_constants::ROUND_CONSTANTS; pub const ROUNDS: usize = 91; fn mimc7_constants( - ctx: &mut CircuitContext>, + ctx: &mut CircuitContextLegacy>, _: (), ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -49,7 +52,7 @@ fn mimc7_constants( } fn mimc7_circuit( - ctx: &mut CircuitContext>, + ctx: &mut CircuitContextLegacy>, constants: LookupTable, ) { use chiquito::frontend::dsl::cb::*; diff --git a/examples/poseidon.rs b/examples/poseidon.rs index 56b5f3e3..d686b049 100644 --- a/examples/poseidon.rs +++ b/examples/poseidon.rs @@ -1,5 +1,8 @@ use chiquito::{ - frontend::dsl::{lb::LookupTable, super_circuit, trace::DSLTraceGenerator, CircuitContext}, + frontend::dsl::{ + circuit_context_legacy::CircuitContextLegacy, lb::LookupTable, super_circuit, + trace::DSLTraceGenerator, + }, plonkish::{ backend::halo2_legacy::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, compiler::{ @@ -49,7 +52,7 @@ struct CircuitParams { } fn poseidon_constants_table( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, param_t: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -75,7 +78,7 @@ fn poseidon_constants_table( } fn poseidon_matrix_table( - ctx: &mut CircuitContext, + ctx: &mut CircuitContextLegacy, param_t: usize, ) -> LookupTable { use chiquito::frontend::dsl::cb::*; @@ -97,7 +100,7 @@ fn poseidon_matrix_table( } fn poseidon_circuit( - ctx: &mut CircuitContext>>, + ctx: &mut CircuitContextLegacy>>, param: CircuitParams, ) { use chiquito::frontend::dsl::cb::*; diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index 44bb6d50..5d6c04c6 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -19,9 +19,8 @@ use crate::{ }, lang::TLDeclsParser, }, - plonkish::{self, compiler::PlonkishCompilationResult}, poly::{self, mielim::mi_elimination, reduce::reduce_degree, Expr}, - sbpir::{query::Queriable, InternalSignal, SBPIRLegacy, SBPIR}, + sbpir::{query::Queriable, InternalSignal, SBPIR}, wit_gen::{NullTraceGenerator, SymbolSignalMapping, TraceGenerator}, }; @@ -31,31 +30,12 @@ use super::{ Config, Message, Messages, }; +#[derive(Debug)] pub struct CompilerResult { pub messages: Vec, pub circuit: SBPIR, } -/// Contains the result of a single machine compilation (legacy). -#[derive(Debug)] -pub struct CompilerResultLegacy { - pub messages: Vec, - pub circuit: SBPIRLegacy, -} - -impl CompilerResultLegacy { - /// Compiles to the Plonkish IR, that then can be compiled to plonkish backends. - pub fn plonkish< - CM: plonkish::compiler::cell_manager::CellManager, - SSB: plonkish::compiler::step_selector::StepSelectorBuilder, - >( - &self, - config: plonkish::compiler::CompilerConfig, - ) -> PlonkishCompilationResult { - plonkish::compiler::compile(config, &self.circuit) - } -} - /// This compiler compiles from chiquito source code to the SBPIR. #[derive(Default)] pub(super) struct Compiler { @@ -80,41 +60,6 @@ impl Compiler { } } - /// Compile the source code containing a single machine (legacy). - pub(super) fn compile_legacy( - mut self, - source: &str, - debug_sym_ref_factory: &DebugSymRefFactory, - ) -> Result, Vec> { - let ast = self - .parse(source, debug_sym_ref_factory) - .map_err(|_| self.messages.clone())?; - assert!(ast.len() == 1, "Use `compile` to compile multiple machines"); - let ast = self.add_virtual(ast); - let symbols = self.semantic(&ast).map_err(|_| self.messages.clone())?; - let setup = Self::interpret(&ast, &symbols); - let setup = Self::map_consts(setup); - let circuit = self.build(&setup, &symbols); - let circuit = Self::mi_elim(circuit); - let circuit = if let Some(degree) = self.config.max_degree { - Self::reduce(circuit, degree) - } else { - circuit - }; - - let circuit = circuit.with_trace(InterpreterTraceGenerator::new( - ast, - symbols, - self.mapping, - self.config.max_steps, - )); - - Ok(CompilerResultLegacy { - messages: self.messages, - circuit, - }) - } - /// Compile the source code. pub(super) fn compile( mut self, @@ -126,12 +71,10 @@ impl Compiler { .map_err(|_| self.messages.clone())?; let ast = self.add_virtual(ast); let symbols = self.semantic(&ast).map_err(|_| self.messages.clone())?; - let setup = Self::interpret(&ast, &symbols); - let setup = Self::map_consts(setup); + let machine_setups = Self::interpret(&ast, &symbols); + let machine_setups = Self::map_consts(machine_setups); - let machine_id = setup.iter().next().unwrap().0; - - let circuit = self.build(&setup, &symbols); + let circuit = self.build(&machine_setups, &symbols); let circuit = Self::mi_elim(circuit); let circuit = if let Some(degree) = self.config.max_degree { Self::reduce(circuit, degree) @@ -139,19 +82,16 @@ impl Compiler { circuit }; - let circuit = circuit.with_trace(InterpreterTraceGenerator::new( + let circuit = circuit.with_trace(&InterpreterTraceGenerator::new( ast, symbols, self.mapping, self.config.max_steps, )); - // TODO perform real compilation for multiple machines - let sbpir = SBPIR::from_legacy(circuit, machine_id.as_str()); - Ok(CompilerResult { messages: self.messages, - circuit: sbpir, + circuit, }) } @@ -332,13 +272,11 @@ impl Compiler { } } - fn build( - &mut self, - setup: &Setup, - symbols: &SymTable, - ) -> SBPIRLegacy { - circuit::("circuit", |ctx| { - for (machine_id, machine) in setup { + fn build(&mut self, setup: &Setup, symbols: &SymTable) -> SBPIR { + let mut sbpir = SBPIR::default(); + + for (machine_id, machine) in setup { + let sbpir_machine = circuit::("circuit", |ctx| { self.add_forwards(ctx, symbols, machine_id); self.add_step_type_handlers(ctx, symbols, machine_id); @@ -369,42 +307,49 @@ impl Compiler { }, ); } - } - ctx.trace(|_, _| {}); - }) - .without_trace() + ctx.trace(|_, _| {}); + }) + .without_trace(); + + sbpir.add_machine(machine_id, sbpir_machine); + } + + sbpir } - fn mi_elim( - mut circuit: SBPIRLegacy, - ) -> SBPIRLegacy { - for (_, step_type) in circuit.step_types.iter_mut() { - let mut signal_factory = SignalFactory::default(); + fn mi_elim(mut circuit: SBPIR) -> SBPIR { + for machine in circuit.machines.values_mut() { + for (_, step_type) in machine.step_types.iter_mut() { + let mut signal_factory = SignalFactory::default(); - step_type.decomp_constraints(|expr| mi_elimination(expr.clone(), &mut signal_factory)); + step_type + .decomp_constraints(|expr| mi_elimination(expr.clone(), &mut signal_factory)); + } } circuit } fn reduce( - mut circuit: SBPIRLegacy, + mut circuit: SBPIR, degree: usize, - ) -> SBPIRLegacy { - for (_, step_type) in circuit.step_types.iter_mut() { - let mut signal_factory = SignalFactory::default(); + ) -> SBPIR { + for machine in circuit.machines.values_mut() { + for (_, step_type) in machine.step_types.iter_mut() { + let mut signal_factory = SignalFactory::default(); - step_type.decomp_constraints(|expr| { - reduce_degree(expr.clone(), degree, &mut signal_factory) - }); + step_type.decomp_constraints(|expr| { + reduce_degree(expr.clone(), degree, &mut signal_factory) + }); + } } circuit } #[allow(dead_code)] - fn cse(mut _circuit: SBPIRLegacy) -> SBPIRLegacy { + fn cse(mut _circuit: SBPIR) -> SBPIR { todo!() } @@ -677,17 +622,55 @@ impl poly::SignalFactory> for SignalFactory { mod test { use halo2_proofs::halo2curves::bn256::Fr; - use crate::{ - compiler::{compile_file_legacy, compile_legacy}, - parser::ast::debug_sym_factory::DebugSymRefFactory, - }; + use crate::{compiler::compile, parser::ast::debug_sym_factory::DebugSymRefFactory}; use super::Config; + // TODO rewrite the test after machines are able to call other machines #[test] - fn test_compiler_fibo() { + fn test_compiler_fibo_multiple_machines() { + // Source code containing two machines let circuit = " - machine fibo(signal n) (signal b: field) { + machine fibo1 (signal n) (signal b: field) { + // n and be are created automatically as shared + // signals + signal a: field, i; + + // there is always a state called initial + // input signals get bound to the signal + // in the initial state (first instance) + state initial { + signal c; + + i, a, b, c <== 1, 1, 1, 2; + + -> middle { + a', b', n' <== b, c, n; + } + } + + state middle { + signal c; + + c <== a + b; + + if i + 1 == n { + -> final { + i', b', n' <== i + 1, c, n; + } + } else { + -> middle { + i', a', b', n' <== i + 1, b, c, n; + } + } + } + + // There is always a state called final. + // Output signals get automatically bound to the signals + // with the same name in the final step (last instance). + // This state can be implicit if there are no constraints in it. + } + machine fibo2 (signal n) (signal b: field) { // n and be are created automatically as shared // signals signal a: field, i; @@ -729,35 +712,18 @@ mod test { "; let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); - let result = compile_legacy::( + let result = compile::( circuit, Config::default().max_degree(2), &debug_sym_ref_factory, ); match result { - Ok(result) => println!("{:#?}", result), + Ok(result) => { + assert_eq!(result.circuit.machines.len(), 2); + println!("{:#?}", result) + } Err(messages) => println!("{:#?}", messages), } } - - #[test] - fn test_compiler_fibo_file() { - let path = "test/circuit.chiquito"; - let result = compile_file_legacy::(path, Config::default().max_degree(2)); - assert!(result.is_ok()); - } - - #[test] - fn test_compiler_fibo_file_err() { - let path = "test/circuit_error.chiquito"; - let result = compile_file_legacy::(path, Config::default().max_degree(2)); - - assert!(result.is_err()); - - assert_eq!( - format!("{:?}", result.unwrap_err()), - r#"[SemErr { msg: "use of undeclared variable c", dsym: test/circuit_error.chiquito:24:39 }, SemErr { msg: "use of undeclared variable c", dsym: test/circuit_error.chiquito:28:46 }]"# - ) - } } diff --git a/src/compiler/compiler_legacy.rs b/src/compiler/compiler_legacy.rs new file mode 100644 index 00000000..09ff201c --- /dev/null +++ b/src/compiler/compiler_legacy.rs @@ -0,0 +1,719 @@ +use std::{collections::HashMap, hash::Hash, marker::PhantomData}; + +use num_bigint::BigInt; + +use crate::{ + field::Field, + frontend::dsl::{ + cb::{Constraint, Typing}, + circuit_context_legacy::{circuit_legacy, CircuitContextLegacy}, + StepTypeContext, + }, + interpreter::InterpreterTraceGenerator, + parser::{ + ast::{ + debug_sym_factory::DebugSymRefFactory, + expression::Expression, + statement::{Statement, TypedIdDecl}, + tl::TLDecl, + DebugSymRef, Identifiable, Identifier, + }, + lang::TLDeclsParser, + }, + plonkish::{self, compiler::PlonkishCompilationResult}, + poly::{self, mielim::mi_elimination, reduce::reduce_degree, Expr}, + sbpir::{query::Queriable, InternalSignal, SBPIRLegacy}, + wit_gen::{NullTraceGenerator, SymbolSignalMapping, TraceGenerator}, +}; + +use super::{ + semantic::{SymTable, SymbolCategory}, + setup_inter::{interpret, MachineSetup, Setup}, + Config, Message, Messages, +}; + +/// Contains the result of a single machine compilation (legacy). +#[derive(Debug)] +pub struct CompilerResultLegacy { + pub messages: Vec, + pub circuit: SBPIRLegacy, +} + +impl CompilerResultLegacy { + /// Compiles to the Plonkish IR, that then can be compiled to plonkish backends. + pub fn plonkish< + CM: plonkish::compiler::cell_manager::CellManager, + SSB: plonkish::compiler::step_selector::StepSelectorBuilder, + >( + &self, + config: plonkish::compiler::CompilerConfig, + ) -> PlonkishCompilationResult { + plonkish::compiler::compile(config, &self.circuit) + } +} + +/// This compiler compiles from chiquito source code to the SBPIR. +#[derive(Default)] +pub(super) struct CompilerLegacy { + pub(super) config: Config, + + messages: Vec, + + mapping: SymbolSignalMapping, + + _p: PhantomData, +} + +impl CompilerLegacy { + /// Creates a configured compiler. + pub fn new(mut config: Config) -> Self { + if config.max_steps == 0 { + config.max_steps = 1000; // TODO: organise this better + } + CompilerLegacy { + config, + ..CompilerLegacy::default() + } + } + + /// Compile the source code containing a single machine. + pub(super) fn compile( + mut self, + source: &str, + debug_sym_ref_factory: &DebugSymRefFactory, + ) -> Result, Vec> { + let ast = self + .parse(source, debug_sym_ref_factory) + .map_err(|_| self.messages.clone())?; + assert!(ast.len() == 1, "Use `compile` to compile multiple machines"); + let ast = self.add_virtual(ast); + let symbols = self.semantic(&ast).map_err(|_| self.messages.clone())?; + let setup = Self::interpret(&ast, &symbols); + let setup = Self::map_consts(setup); + let circuit = self.build(&setup, &symbols); + let circuit = Self::mi_elim(circuit); + let circuit = if let Some(degree) = self.config.max_degree { + Self::reduce(circuit, degree) + } else { + circuit + }; + + let circuit = circuit.with_trace(InterpreterTraceGenerator::new( + ast, + symbols, + self.mapping, + self.config.max_steps, + )); + + Ok(CompilerResultLegacy { + messages: self.messages, + circuit, + }) + } + + fn parse( + &mut self, + source: &str, + debug_sym_ref_factory: &DebugSymRefFactory, + ) -> Result>, ()> { + let result = TLDeclsParser::new().parse(debug_sym_ref_factory, source); + + match result { + Ok(ast) => Ok(ast), + Err(error) => { + self.messages.push(Message::ParseErr { + msg: error.to_string(), + }); + Err(()) + } + } + } + + fn add_virtual( + &mut self, + mut ast: Vec>, + ) -> Vec> { + for tldc in ast.iter_mut() { + match tldc { + TLDecl::MachineDecl { + dsym, + id: _, + input_params: _, + output_params, + block, + } => self.add_virtual_to_machine(dsym, output_params, block), + } + } + + ast + } + + fn add_virtual_to_machine( + &mut self, + dsym: &DebugSymRef, + output_params: &Vec>, + block: &mut Statement, + ) { + let dsym = DebugSymRef::into_virtual(dsym); + let output_params = Self::get_decls(output_params); + + if let Statement::Block(_, stmts) = block { + let mut has_final = false; + + for stmt in stmts.iter() { + if let Statement::StateDecl(_, id, _) = stmt + && id.name() == "final" + { + has_final = true + } + } + if !has_final { + stmts.push(Statement::StateDecl( + dsym.clone(), + Identifier::new("final", dsym.clone()), + Box::new(Statement::Block(dsym.clone(), vec![])), + )); + } + + let final_state = Self::find_state_mut("final", stmts).unwrap(); + + let mut padding_transitions = output_params + .iter() + .map(|output_signal| { + Statement::SignalAssignmentAssert( + dsym.clone(), + vec![output_signal.id.next()], + vec![Expression::Query::( + dsym.clone(), + output_signal.id.clone(), + )], + ) + }) + .collect::>(); + + padding_transitions.push(Statement::Transition( + dsym.clone(), + Identifier::new("__padding", dsym.clone()), + Box::new(Statement::Block(dsym.clone(), vec![])), + )); + + Self::add_virtual_to_state(final_state, padding_transitions.clone()); + + stmts.push(Statement::StateDecl( + dsym.clone(), + Identifier::new("__padding", dsym.clone()), + Box::new(Statement::Block(dsym.clone(), padding_transitions)), + )); + } // Semantic analyser must show an error in the else case + } + + fn find_state_mut>( + state_id: S, + stmts: &mut [Statement], + ) -> Option<&mut Statement> { + let state_id = state_id.into(); + let mut final_state: Option<&mut Statement> = None; + + for stmt in stmts.iter_mut() { + if let Statement::StateDecl(_, id, _) = stmt + && id.name() == state_id + { + final_state = Some(stmt) + } + } + + final_state + } + + fn add_virtual_to_state( + state: &mut Statement, + add_statements: Vec>, + ) { + if let Statement::StateDecl(_, _, final_state_stmts) = state { + if let Statement::Block(_, stmts) = final_state_stmts.as_mut() { + stmts.extend(add_statements) + } + } + } + + fn semantic(&mut self, ast: &[TLDecl]) -> Result { + let result = super::semantic::analyser::analyse(ast); + let has_errors = result.messages.has_errors(); + + self.messages.extend(result.messages); + + if has_errors { + Err(()) + } else { + Ok(result.symbols) + } + } + + fn interpret(ast: &[TLDecl], symbols: &SymTable) -> Setup { + interpret(ast, symbols) + } + + fn map_consts(setup: Setup) -> Setup { + setup + .iter() + .map(|(machine_id, machine)| { + let poly_constraints: HashMap>> = machine + .iter_states_poly_constraints() + .map(|(step_id, step)| { + let new_step: Vec> = + step.iter().map(|pi| Self::map_pi_consts(pi)).collect(); + + (step_id.clone(), new_step) + }) + .collect(); + + let new_machine: MachineSetup = + machine.replace_poly_constraints(poly_constraints); + (machine_id.clone(), new_machine) + }) + .collect() + } + + fn map_pi_consts(expr: &Expr) -> Expr { + use Expr::*; + match expr { + Const(v, _) => Const(F::from_big_int(v), ()), + Sum(ses, _) => Sum(ses.iter().map(|se| Self::map_pi_consts(se)).collect(), ()), + Mul(ses, _) => Mul(ses.iter().map(|se| Self::map_pi_consts(se)).collect(), ()), + Neg(se, _) => Neg(Box::new(Self::map_pi_consts(se)), ()), + Pow(se, exp, _) => Pow(Box::new(Self::map_pi_consts(se)), *exp, ()), + Query(q, _) => Query(q.clone(), ()), + Halo2Expr(_, _) => todo!(), + MI(se, _) => MI(Box::new(Self::map_pi_consts(se)), ()), + } + } + + fn build( + &mut self, + setup: &Setup, + symbols: &SymTable, + ) -> SBPIRLegacy { + circuit_legacy::("circuit", |ctx| { + for (machine_id, machine) in setup { + self.add_forwards(ctx, symbols, machine_id); + self.add_step_type_handlers(ctx, symbols, machine_id); + + ctx.pragma_num_steps(self.config.max_steps); + ctx.pragma_first_step(self.mapping.get_step_type_handler(machine_id, "initial")); + ctx.pragma_last_step(self.mapping.get_step_type_handler(machine_id, "__padding")); + + for state_id in machine.states() { + ctx.step_type_def( + self.mapping.get_step_type_handler(machine_id, state_id), + |ctx| { + self.add_internals(ctx, symbols, machine_id, state_id); + + ctx.setup(|ctx| { + let poly_constraints = + self.translate_queries(symbols, setup, machine_id, state_id); + poly_constraints.iter().for_each(|poly| { + let constraint = Constraint { + annotation: format!("{:?}", poly), + expr: poly.clone(), + typing: Typing::AntiBooly, + }; + ctx.constr(constraint); + }); + }); + + ctx.wg(|_, _: ()| {}) + }, + ); + } + } + + ctx.trace(|_, _| {}); + }) + .without_trace() + } + + fn mi_elim( + mut circuit: SBPIRLegacy, + ) -> SBPIRLegacy { + for (_, step_type) in circuit.step_types.iter_mut() { + let mut signal_factory = SignalFactory::default(); + + step_type.decomp_constraints(|expr| mi_elimination(expr.clone(), &mut signal_factory)); + } + + circuit + } + + fn reduce( + mut circuit: SBPIRLegacy, + degree: usize, + ) -> SBPIRLegacy { + for (_, step_type) in circuit.step_types.iter_mut() { + let mut signal_factory = SignalFactory::default(); + + step_type.decomp_constraints(|expr| { + reduce_degree(expr.clone(), degree, &mut signal_factory) + }); + } + + circuit + } + + #[allow(dead_code)] + fn cse(mut _circuit: SBPIRLegacy) -> SBPIRLegacy { + todo!() + } + + fn translate_queries( + &mut self, + symbols: &SymTable, + setup: &Setup, + machine_id: &str, + state_id: &str, + ) -> Vec, ()>> { + let exprs = setup + .get(machine_id) + .unwrap() + .get_poly_constraints(state_id) + .unwrap(); + + exprs + .iter() + .map(|expr| self.translate_queries_expr(symbols, machine_id, state_id, expr)) + .collect() + } + + fn translate_queries_expr( + &mut self, + symbols: &SymTable, + machine_id: &str, + state_id: &str, + expr: &Expr, + ) -> Expr, ()> { + use Expr::*; + match expr { + Const(v, _) => Const(*v, ()), + Sum(ses, _) => Sum( + ses.iter() + .map(|se| self.translate_queries_expr(symbols, machine_id, state_id, se)) + .collect(), + (), + ), + Mul(ses, _) => Mul( + ses.iter() + .map(|se| self.translate_queries_expr(symbols, machine_id, state_id, se)) + .collect(), + (), + ), + Neg(se, _) => Neg( + Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), + (), + ), + Pow(se, exp, _) => Pow( + Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), + *exp, + (), + ), + MI(se, _) => MI( + Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), + (), + ), + Halo2Expr(se, _) => Halo2Expr(se.clone(), ()), + Query(id, _) => Query(self.translate_query(symbols, machine_id, state_id, id), ()), + } + } + + fn translate_query( + &mut self, + symbols: &SymTable, + machine_id: &str, + state_id: &str, + id: &Identifier, + ) -> Queriable { + use super::semantic::{ScopeCategory, SymbolCategory::*}; + + let symbol = symbols + .find_symbol( + &[ + "/".to_string(), + machine_id.to_string(), + state_id.to_string(), + ], + id.name(), + ) + .unwrap_or_else(|| panic!("semantic analyser fail: undeclared id {}", id.name())); + + match symbol.symbol.category { + InputSignal | OutputSignal | InoutSignal => { + self.translate_forward_queriable(machine_id, id) + } + Signal => match symbol.scope_cat { + ScopeCategory::Machine => self.translate_forward_queriable(machine_id, id), + ScopeCategory::State => { + if id.rotation() != 0 { + unreachable!("semantic analyser should prevent this"); + } + let signal = self + .mapping + .get_internal(&format!("//{}/{}", machine_id, state_id), &id.name()); + + Queriable::Internal(signal) + } + + ScopeCategory::Global => unreachable!("no global signals"), + }, + + State => { + Queriable::StepTypeNext(self.mapping.get_step_type_handler(machine_id, &id.name())) + } + + _ => unreachable!("semantic analysis should prevent this"), + } + } + + fn translate_forward_queriable(&mut self, machine_id: &str, id: &Identifier) -> Queriable { + let forward = self.mapping.get_forward(machine_id, &id.name()); + let rot = if id.rotation() == 0 { + false + } else if id.rotation() == 1 { + true + } else { + unreachable!("semantic analyser should prevent this") + }; + + Queriable::Forward(forward, rot) + } + + fn get_all_internals( + &mut self, + symbols: &SymTable, + machine_id: &str, + state_id: &str, + ) -> Vec { + let symbols = symbols + .get_scope(&[ + "/".to_string(), + machine_id.to_string(), + state_id.to_string(), + ]) + .expect("scope not found") + .get_symbols(); + + symbols + .iter() + .filter(|(_, entry)| entry.category == SymbolCategory::Signal) + .map(|(id, _)| id) + .cloned() + .collect() + } + + fn add_internals( + &mut self, + ctx: &mut StepTypeContext, + symbols: &SymTable, + machine_id: &str, + state_id: &str, + ) { + let internal_ids = self.get_all_internals(symbols, machine_id, state_id); + let scope_name = format!("//{}/{}", machine_id, state_id); + + for internal_id in internal_ids { + let name = format!("{}:{}", &scope_name, internal_id); + + let queriable = ctx.internal(name.as_str()); + if let Queriable::Internal(signal) = queriable { + self.mapping + .symbol_uuid + .insert((scope_name.clone(), internal_id), signal.uuid()); + self.mapping.internal_signals.insert(signal.uuid(), signal); + } else { + unreachable!("ctx.internal returns not internal signal"); + } + } + } + + fn add_step_type_handlers>( + &mut self, + ctx: &mut CircuitContextLegacy, + symbols: &SymTable, + machine_id: &str, + ) { + let symbols = symbols + .get_scope(&["/".to_string(), machine_id.to_string()]) + .expect("scope not found") + .get_symbols(); + + let state_ids: Vec<_> = symbols + .iter() + .filter(|(_, entry)| entry.category == SymbolCategory::State) + .map(|(id, _)| id) + .cloned() + .collect(); + + for state_id in state_ids { + let scope_name = format!("//{}", machine_id); + let name = format!("{}:{}", scope_name, state_id); + + let handler = ctx.step_type(&name); + self.mapping + .step_type_handler + .insert(handler.uuid(), handler); + self.mapping + .symbol_uuid + .insert((scope_name, state_id), handler.uuid()); + } + } + + fn add_forwards>( + &mut self, + ctx: &mut CircuitContextLegacy, + symbols: &SymTable, + machine_id: &str, + ) { + let symbols = symbols + .get_scope(&["/".to_string(), machine_id.to_string()]) + .expect("scope not found") + .get_symbols(); + + let forward_ids: Vec<_> = symbols + .iter() + .filter(|(_, entry)| entry.is_signal()) + .map(|(id, _)| id) + .cloned() + .collect(); + + for forward_id in forward_ids { + let scope_name = format!("//{}", machine_id); + let name = format!("{}:{}", scope_name, forward_id); + + let queriable = ctx.forward(name.as_str()); + if let Queriable::Forward(signal, _) = queriable { + self.mapping + .symbol_uuid + .insert((scope_name, forward_id), signal.uuid()); + self.mapping.forward_signals.insert(signal.uuid(), signal); + } else { + unreachable!("ctx.internal returns not internal signal"); + } + } + } + + fn get_decls(stmts: &Vec>) -> Vec> { + let mut result: Vec> = vec![]; + + for stmt in stmts { + if let Statement::SignalDecl(_, ids) = stmt { + result.extend(ids.clone()) + } + } + + result + } +} + +// Basic signal factory. +#[derive(Default)] +struct SignalFactory { + count: u64, + _p: PhantomData, +} + +impl poly::SignalFactory> for SignalFactory { + fn create>(&mut self, annotation: S) -> Queriable { + self.count += 1; + Queriable::Internal(InternalSignal::new(format!( + "{}-{}", + annotation.into(), + self.count + ))) + } +} + +#[cfg(test)] +mod test { + use halo2_proofs::halo2curves::bn256::Fr; + + use crate::{ + compiler::{compile_file_legacy, compile_legacy}, + parser::ast::debug_sym_factory::DebugSymRefFactory, + }; + + use super::Config; + + #[test] + fn test_compiler_fibo() { + let circuit = " + machine fibo(signal n) (signal b: field) { + // n and be are created automatically as shared + // signals + signal a: field, i; + + // there is always a state called initial + // input signals get bound to the signal + // in the initial state (first instance) + state initial { + signal c; + + i, a, b, c <== 1, 1, 1, 2; + + -> middle { + a', b', n' <== b, c, n; + } + } + + state middle { + signal c; + + c <== a + b; + + if i + 1 == n { + -> final { + i', b', n' <== i + 1, c, n; + } + } else { + -> middle { + i', a', b', n' <== i + 1, b, c, n; + } + } + } + + // There is always a state called final. + // Output signals get automatically bound to the signals + // with the same name in the final step (last instance). + // This state can be implicit if there are no constraints in it. + } + "; + + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); + let result = compile_legacy::( + circuit, + Config::default().max_degree(2), + &debug_sym_ref_factory, + ); + + match result { + Ok(result) => println!("{:#?}", result), + Err(messages) => println!("{:#?}", messages), + } + } + + #[test] + fn test_compiler_fibo_file() { + let path = "test/circuit.chiquito"; + let result = compile_file_legacy::(path, Config::default().max_degree(2)); + assert!(result.is_ok()); + } + + #[test] + fn test_compiler_fibo_file_err() { + let path = "test/circuit_error.chiquito"; + let result = compile_file_legacy::(path, Config::default().max_degree(2)); + + assert!(result.is_err()); + + assert_eq!( + format!("{:?}", result.unwrap_err()), + r#"[SemErr { msg: "use of undeclared variable c", dsym: test/circuit_error.chiquito:24:39 }, SemErr { msg: "use of undeclared variable c", dsym: test/circuit_error.chiquito:28:46 }]"# + ) + } +} diff --git a/src/compiler/mod.rs b/src/compiler/mod.rs index e590396c..9a6bfacd 100644 --- a/src/compiler/mod.rs +++ b/src/compiler/mod.rs @@ -4,9 +4,9 @@ use std::{ io::{self, Read}, }; -use compiler::CompilerResult; +use compiler::{Compiler, CompilerResult}; +use compiler_legacy::{CompilerLegacy, CompilerResultLegacy}; -use self::compiler::{Compiler, CompilerResultLegacy}; use crate::{ field::Field, parser::ast::{debug_sym_factory::DebugSymRefFactory, DebugSymRef}, @@ -15,6 +15,7 @@ use crate::{ pub mod abepi; #[allow(clippy::module_inception)] pub mod compiler; +pub mod compiler_legacy; pub mod semantic; mod setup_inter; @@ -74,7 +75,7 @@ pub fn compile_legacy( config: Config, debug_sym_ref_factory: &DebugSymRefFactory, ) -> Result, Vec> { - Compiler::new(config).compile_legacy(source, debug_sym_ref_factory) + CompilerLegacy::new(config).compile(source, debug_sym_ref_factory) } /// Compiles chiquito source code file into a SBPIR for a single machine, also returns messages diff --git a/src/frontend/dsl/circuit_context_legacy.rs b/src/frontend/dsl/circuit_context_legacy.rs new file mode 100644 index 00000000..082387c9 --- /dev/null +++ b/src/frontend/dsl/circuit_context_legacy.rs @@ -0,0 +1,209 @@ +use crate::{ + field::Field, + sbpir::{query::Queriable, ExposeOffset, SBPIRLegacy}, + wit_gen::{FixedGenContext, StepInstance, TraceGenerator}, +}; + +use halo2_proofs::plonk::{Advice, Column as Halo2Column, Fixed}; + +use core::{fmt::Debug, hash::Hash}; + +use super::{ + lb::{LookupTable, LookupTableRegistry, LookupTableStore}, + trace::{DSLTraceGenerator, TraceContext}, + StepTypeContext, StepTypeDefInput, StepTypeHandler, StepTypeWGHandler, +}; + +#[derive(Debug, Default)] +/// A generic structure designed to handle the context of a circuit. +/// The struct contains a `Circuit` instance and implements methods to build the circuit, +/// add various components, and manipulate the circuit. +/// +/// ### Type parameters +/// `F` is the field of the circuit. +/// `TG` is the trace generator. +/// +/// (LEGACY) +pub struct CircuitContextLegacy = DSLTraceGenerator> { + pub(crate) circuit: SBPIRLegacy, + pub(crate) tables: LookupTableRegistry, +} + +impl> CircuitContextLegacy { + /// Adds a forward signal to the circuit with a name string and zero rotation and returns a + /// `Queriable` instance representing the added forward signal. + pub fn forward(&mut self, name: &str) -> Queriable { + Queriable::Forward(self.circuit.add_forward(name, 0), false) + } + + /// Adds a forward signal to the circuit with a name string and a specified phase and returns a + /// `Queriable` instance representing the added forward signal. + pub fn forward_with_phase(&mut self, name: &str, phase: usize) -> Queriable { + Queriable::Forward(self.circuit.add_forward(name, phase), false) + } + + /// Adds a shared signal to the circuit with a name string and zero rotation and returns a + /// `Queriable` instance representing the added shared signal. + pub fn shared(&mut self, name: &str) -> Queriable { + Queriable::Shared(self.circuit.add_shared(name, 0), 0) + } + + /// Adds a shared signal to the circuit with a name string and a specified phase and returns a + /// `Queriable` instance representing the added shared signal. + pub fn shared_with_phase(&mut self, name: &str, phase: usize) -> Queriable { + Queriable::Shared(self.circuit.add_shared(name, phase), 0) + } + + pub fn fixed(&mut self, name: &str) -> Queriable { + Queriable::Fixed(self.circuit.add_fixed(name), 0) + } + + /// Exposes the first step instance value of a forward signal as public. + pub fn expose(&mut self, queriable: Queriable, offset: ExposeOffset) { + self.circuit.expose(queriable, offset); + } + + /// Imports a halo2 advice column with a name string into the circuit and returns a + /// `Queriable` instance representing the imported column. + pub fn import_halo2_advice(&mut self, name: &str, column: Halo2Column) -> Queriable { + Queriable::Halo2AdviceQuery(self.circuit.add_halo2_advice(name, column), 0) + } + + /// Imports a halo2 fixed column with a name string into the circuit and returns a + /// `Queriable` instance representing the imported column. + pub fn import_halo2_fixed(&mut self, name: &str, column: Halo2Column) -> Queriable { + Queriable::Halo2FixedQuery(self.circuit.add_halo2_fixed(name, column), 0) + } + + /// Adds a new step type with the specified name to the circuit and returns a + /// `StepTypeHandler` instance. The `StepTypeHandler` instance can be used to define the + /// step type using the `step_type_def` function. + pub fn step_type(&mut self, name: &str) -> StepTypeHandler { + let handler = StepTypeHandler::new(name.to_string()); + + self.circuit.add_step_type(handler, name); + + handler + } + + /// Defines a step type using the provided `StepTypeHandler` and a function that takes a + /// mutable reference to a `StepTypeContext`. This function typically adds constraints to a + /// step type and defines witness generation. + pub fn step_type_def, R>( + &mut self, + step: S, + def: D, + ) -> StepTypeWGHandler + where + D: FnOnce(&mut StepTypeContext) -> StepTypeWGHandler, + R: Fn(&mut StepInstance, Args) + 'static, + { + let handler: StepTypeHandler = match step.into() { + StepTypeDefInput::Handler(h) => h, + StepTypeDefInput::String(name) => { + let handler = StepTypeHandler::new(name.to_string()); + + self.circuit.add_step_type(handler, name); + + handler + } + }; + + let mut context = StepTypeContext::::new( + handler.uuid(), + handler.annotation.to_string(), + self.tables.clone(), + ); + + let result = def(&mut context); + + self.circuit.add_step_type_def(context.step_type); + + result + } + + pub fn new_table(&self, table: LookupTableStore) -> LookupTable { + let uuid = table.uuid(); + self.tables.add(table); + + LookupTable { uuid } + } + + /// Enforce the type of the first step by adding a constraint to the circuit. Takes a + /// `StepTypeHandler` parameter that represents the step type. + pub fn pragma_first_step>(&mut self, step_type: STH) { + self.circuit.first_step = Some(step_type.into().uuid()); + } + + /// Enforce the type of the last step by adding a constraint to the circuit. Takes a + /// `StepTypeHandler` parameter that represents the step type. + pub fn pragma_last_step>(&mut self, step_type: STH) { + self.circuit.last_step = Some(step_type.into().uuid()); + } + + /// Enforce the number of step instances by adding a constraint to the circuit. Takes a `usize` + /// parameter that represents the total number of steps. + pub fn pragma_num_steps(&mut self, num_steps: usize) { + self.circuit.num_steps = num_steps; + } + + pub fn pragma_disable_q_enable(&mut self) { + self.circuit.q_enable = false; + } +} + +impl CircuitContextLegacy> { + /// Sets the trace function that builds the witness. The trace function is responsible for + /// adding step instances defined in `step_type_def`. The function is entirely left for + /// the user to implement and is Turing complete. Users typically use external parameters + /// of type `TraceArgs` to generate cell values for witness generation, and call the + /// `add` function to add step instances with witness values. + pub fn trace(&mut self, def: D) + where + D: Fn(&mut TraceContext, TraceArgs) + 'static, + { + self.circuit.set_trace(def); + } +} + +impl> CircuitContextLegacy { + /// Executes the fixed generation function provided by the user and sets the fixed assignments + /// for the circuit. The fixed generation function is responsible for assigning fixed values to + /// fixed columns. It is entirely left for the user to implement and is Turing complete. Users + /// typically generate cell values and call the `assign` function to fill the fixed columns. + pub fn fixed_gen(&mut self, def: D) + where + D: Fn(&mut FixedGenContext) + 'static, + { + if self.circuit.num_steps == 0 { + panic!("circuit must call pragma_num_steps before calling fixed_gen"); + } + let mut ctx = FixedGenContext::new(self.circuit.num_steps); + (def)(&mut ctx); + + let assignments = ctx.get_assignments(); + + self.circuit.set_fixed_assignments(assignments); + } +} + +/// Creates a `Circuit` instance by providing a name and a definition closure that is applied to a +/// mutable `CircuitContext`. The user customizes the definition closure by calling `CircuitContext` +/// functions. This is the main function that users call to define a Chiquito circuit (legacy). +pub fn circuit_legacy( + _name: &str, + mut def: D, +) -> SBPIRLegacy> +where + D: FnMut(&mut CircuitContextLegacy>), +{ + // TODO annotate circuit + let mut context = CircuitContextLegacy { + circuit: SBPIRLegacy::default(), + tables: LookupTableRegistry::default(), + }; + + def(&mut context); + + context.circuit +} diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index 6710b114..cc21d054 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -1,6 +1,8 @@ use crate::{ field::Field, - sbpir::{query::Queriable, ExposeOffset, SBPIRLegacy, StepType, StepTypeUUID, PIR}, + sbpir::{ + query::Queriable, sbpir_machine::SBPIRMachine, ExposeOffset, StepType, StepTypeUUID, PIR, + }, util::{uuid, UUID}, wit_gen::{FixedGenContext, StepInstance, TraceGenerator}, }; @@ -19,24 +21,17 @@ use self::{ pub use sc::*; pub mod cb; +pub mod circuit_context_legacy; pub mod lb; pub mod sc; pub mod trace; -#[derive(Debug, Default)] -/// A generic structure designed to handle the context of a circuit. -/// The struct contains a `Circuit` instance and implements methods to build the circuit, -/// add various components, and manipulate the circuit. -/// -/// ### Type parameters -/// `F` is the field of the circuit. -/// `TG` is the trace generator. -pub struct CircuitContext = DSLTraceGenerator> { - circuit: SBPIRLegacy, +pub struct CircuitContext = DSLTraceGenerator> { + circuit: SBPIRMachine, tables: LookupTableRegistry, } -impl> CircuitContext { +impl> CircuitContext { /// Adds a forward signal to the circuit with a name string and zero rotation and returns a /// `Queriable` instance representing the added forward signal. pub fn forward(&mut self, name: &str) -> Queriable { @@ -424,13 +419,13 @@ impl, Args) + 'static> StepTypeWGHandler( _name: &str, mut def: D, -) -> SBPIRLegacy> +) -> SBPIRMachine> where D: FnMut(&mut CircuitContext>), { // TODO annotate circuit let mut context = CircuitContext { - circuit: SBPIRLegacy::default(), + circuit: SBPIRMachine::default(), tables: LookupTableRegistry::default(), }; @@ -441,18 +436,22 @@ where #[cfg(test)] mod tests { + use circuit_context_legacy::CircuitContextLegacy; use halo2_proofs::halo2curves::bn256::Fr; - use crate::{sbpir::ForwardSignal, wit_gen::NullTraceGenerator}; + use crate::{ + sbpir::{ForwardSignal, SBPIRLegacy}, + wit_gen::NullTraceGenerator, + }; use super::*; - fn setup_circuit_context() -> CircuitContext + fn setup_circuit_context() -> CircuitContextLegacy where F: Default, TG: TraceGenerator, { - CircuitContext { + CircuitContextLegacy { circuit: SBPIRLegacy::default(), tables: Default::default(), } diff --git a/src/frontend/dsl/sc.rs b/src/frontend/dsl/sc.rs index 860010c7..b7cdf974 100644 --- a/src/frontend/dsl/sc.rs +++ b/src/frontend/dsl/sc.rs @@ -17,7 +17,9 @@ use crate::{ wit_gen::{NullTraceGenerator, TraceGenerator}, }; -use super::{lb::LookupTableRegistry, trace::DSLTraceGenerator, CircuitContext}; +use super::{ + circuit_context_legacy::CircuitContextLegacy, lb::LookupTableRegistry, trace::DSLTraceGenerator, +}; pub struct SuperCircuitContext { super_circuit: SuperCircuit, @@ -61,9 +63,9 @@ impl SuperCircuitContext { Exports, ) where - D: Fn(&mut CircuitContext>, Imports) -> Exports, + D: Fn(&mut CircuitContextLegacy>, Imports) -> Exports, { - let mut sub_circuit_context = CircuitContext { + let mut sub_circuit_context = CircuitContextLegacy { circuit: SBPIRLegacy::default(), tables: self.tables.clone(), }; @@ -148,6 +150,7 @@ mod tests { use halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}; use crate::{ + frontend::dsl::circuit_context_legacy::circuit_legacy, plonkish::compiler::{ cell_manager::SingleRowCellManager, config, step_selector::SimpleStepSelectorBuilder, }, @@ -180,7 +183,7 @@ mod tests { let mut ctx = SuperCircuitContext::::default(); fn simple_circuit( - ctx: &mut CircuitContext>, + ctx: &mut CircuitContextLegacy>, _: (), ) { use crate::frontend::dsl::cb::*; @@ -239,7 +242,7 @@ mod tests { let mut ctx = SuperCircuitContext::::default(); fn simple_circuit( - ctx: &mut CircuitContext>, + ctx: &mut CircuitContextLegacy>, _: (), ) { use crate::frontend::dsl::cb::*; @@ -296,10 +299,9 @@ mod tests { #[test] fn test_super_circuit_sub_circuit_with_ast() { - use crate::frontend::dsl::circuit; let mut ctx = SuperCircuitContext::::default(); - let simple_circuit_with_ast = circuit("simple circuit", |ctx| { + let simple_circuit_with_ast = circuit_legacy("simple circuit", |ctx| { use crate::frontend::dsl::cb::*; let x = ctx.forward("x"); diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index e6177239..7cd05010 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -255,7 +255,7 @@ pub fn run( } /// A trace generator that interprets chiquito source -#[derive(Default, Clone)] +#[derive(Debug, Default, Clone)] pub struct InterpreterTraceGenerator { program: Vec>, symbols: SymTable, diff --git a/src/poly/mod.rs b/src/poly/mod.rs index 2e5942f9..24fedf35 100644 --- a/src/poly/mod.rs +++ b/src/poly/mod.rs @@ -31,8 +31,8 @@ pub enum Expr { Pow(Box>, u32, M), Query(V, M), Halo2Expr(Expression, M), - - MI(Box>, M), // Multiplicative inverse, but MI(0) = 0 + /// Multiplicative inverse, but MI(0) = 0 + MI(Box>, M), } impl Expr { diff --git a/src/sbpir/mod.rs b/src/sbpir/mod.rs index 14f46c34..7c5076f5 100644 --- a/src/sbpir/mod.rs +++ b/src/sbpir/mod.rs @@ -269,23 +269,41 @@ impl> SBPIRLegacy { } } -pub struct SBPIR = DSLTraceGenerator> { - pub machines: HashMap>, +#[derive(Debug)] +pub struct SBPIR = DSLTraceGenerator> { + pub machines: HashMap>, pub identifiers: HashMap, } -impl> SBPIR { - pub(crate) fn from_legacy(circuit: SBPIRLegacy, machine_id: &str) -> SBPIR { - let mut machines = HashMap::new(); - let circuit_id = circuit.id; - machines.insert(circuit_id, SBPIRMachine::from_legacy(circuit)); - let mut identifiers = HashMap::new(); - identifiers.insert(machine_id.to_string(), circuit_id); +impl> SBPIR { + pub(crate) fn default() -> SBPIR { + let machines = HashMap::new(); + let identifiers = HashMap::new(); SBPIR { machines, identifiers, } } + + pub(crate) fn with_trace + Clone>( + &self, + // TODO does it have to be the same trace across all the machines? + trace: &TG2, + ) -> SBPIR { + let mut machines_with_trace = HashMap::new(); + for (name, machine) in self.machines.iter() { + let machine_with_trace = machine.with_trace(trace.clone()); + machines_with_trace.insert(name.clone(), machine_with_trace); + } + SBPIR { + machines: machines_with_trace, + identifiers: self.identifiers.clone(), + } + } + + pub(crate) fn add_machine(&mut self, name: &str, without_trace: SBPIRMachine) { + self.machines.insert(name.to_string(), without_trace); + } } pub type FixedGen = dyn Fn(&mut FixedGenContext) + 'static; diff --git a/src/sbpir/sbpir_machine.rs b/src/sbpir/sbpir_machine.rs index b5489b4d..0b79300c 100644 --- a/src/sbpir/sbpir_machine.rs +++ b/src/sbpir/sbpir_machine.rs @@ -16,9 +16,9 @@ use super::{ ImportedHalo2Fixed, SharedSignal, StepType, StepTypeUUID, }; -/// Circuit (Step-Based Polynomial Identity Representation) +/// Step-Based Polynomial Identity Representation (SBPIR) of a single machine. #[derive(Clone)] -pub struct SBPIRMachine = DSLTraceGenerator> { +pub struct SBPIRMachine = DSLTraceGenerator> { pub step_types: HashMap>, pub forward_signals: Vec, @@ -41,7 +41,7 @@ pub struct SBPIRMachine = DSLTraceGenerator> { pub id: UUID, } -impl> Debug for SBPIRMachine { +impl> Debug for SBPIRMachine { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Circuit") .field("step_types", &self.step_types) @@ -61,7 +61,7 @@ impl> Debug for SBPIRMachine { } } -impl> Default for SBPIRMachine { +impl> Default for SBPIRMachine { fn default() -> Self { Self { step_types: Default::default(), @@ -88,7 +88,7 @@ impl> Default for SBPIRMachine { } } -impl> SBPIRMachine { +impl> SBPIRMachine { pub fn add_forward>(&mut self, name: N, phase: usize) -> ForwardSignal { let name = name.into(); let signal = ForwardSignal::new_with_phase(phase, name.clone()); @@ -187,18 +187,18 @@ impl> SBPIRMachine { } } - pub fn without_trace(self) -> SBPIRMachine { + pub fn without_trace(&self) -> SBPIRMachine { SBPIRMachine { - step_types: self.step_types, - forward_signals: self.forward_signals, - shared_signals: self.shared_signals, - fixed_signals: self.fixed_signals, - halo2_advice: self.halo2_advice, - halo2_fixed: self.halo2_fixed, - exposed: self.exposed, - annotations: self.annotations, + step_types: self.step_types.clone(), + forward_signals: self.forward_signals.clone(), + shared_signals: self.shared_signals.clone(), + fixed_signals: self.fixed_signals.clone(), + halo2_advice: self.halo2_advice.clone(), + halo2_fixed: self.halo2_fixed.clone(), + exposed: self.exposed.clone(), + annotations: self.annotations.clone(), trace_generator: None, // Remove the trace. - fixed_assignments: self.fixed_assignments, + fixed_assignments: self.fixed_assignments.clone(), first_step: self.first_step, last_step: self.last_step, num_steps: self.num_steps, @@ -207,19 +207,18 @@ impl> SBPIRMachine { } } - #[allow(dead_code)] // TODO: Copy of the legacy SBPIR code. Remove if not used in the new compilation - pub(crate) fn with_trace>(self, trace: TG2) -> SBPIRMachine { + pub(crate) fn with_trace>(&self, clone: TG2) -> SBPIRMachine { SBPIRMachine { - trace_generator: Some(trace), // Change trace - step_types: self.step_types, - forward_signals: self.forward_signals, - shared_signals: self.shared_signals, - fixed_signals: self.fixed_signals, - halo2_advice: self.halo2_advice, - halo2_fixed: self.halo2_fixed, - exposed: self.exposed, - annotations: self.annotations, - fixed_assignments: self.fixed_assignments, + trace_generator: Some(clone), // Set trace + step_types: self.step_types.clone(), + forward_signals: self.forward_signals.clone(), + shared_signals: self.shared_signals.clone(), + fixed_signals: self.fixed_signals.clone(), + halo2_advice: self.halo2_advice.clone(), + halo2_fixed: self.halo2_fixed.clone(), + exposed: self.exposed.clone(), + annotations: self.annotations.clone(), + fixed_assignments: self.fixed_assignments.clone(), first_step: self.first_step, last_step: self.last_step, num_steps: self.num_steps, @@ -227,26 +226,6 @@ impl> SBPIRMachine { id: self.id, } } - - pub(crate) fn from_legacy(circuit: super::SBPIRLegacy) -> SBPIRMachine { - SBPIRMachine { - step_types: circuit.step_types, - forward_signals: circuit.forward_signals, - shared_signals: circuit.shared_signals, - fixed_signals: circuit.fixed_signals, - halo2_advice: circuit.halo2_advice, - halo2_fixed: circuit.halo2_fixed, - exposed: circuit.exposed, - annotations: circuit.annotations, - trace_generator: circuit.trace_generator, - fixed_assignments: circuit.fixed_assignments, - first_step: circuit.first_step, - last_step: circuit.last_step, - num_steps: circuit.num_steps, - q_enable: circuit.q_enable, - id: circuit.id, - } - } } impl SBPIRMachine> { From 96e83dc744496d33299988cc4ecd255d87d53a45 Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Fri, 9 Aug 2024 19:33:05 +0800 Subject: [PATCH 2/9] Restore CircuitContext docs --- src/frontend/dsl/mod.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index cc21d054..3bda965f 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -26,6 +26,14 @@ pub mod lb; pub mod sc; pub mod trace; +#[derive(Debug, Default)] +/// A generic structure designed to handle the context of a circuit. +/// The struct contains a `Circuit` instance and implements methods to build the circuit, +/// add various components, and manipulate the circuit. +/// +/// ### Type parameters +/// `F` is the field of the circuit. +/// `TG` is the trace generator. pub struct CircuitContext = DSLTraceGenerator> { circuit: SBPIRMachine, tables: LookupTableRegistry, From 5f892a0d975c5ee9d5206241c9a8eaaca9c3508f Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Tue, 20 Aug 2024 17:03:16 +0800 Subject: [PATCH 3/9] Refactor compiler to not use the "dsl" API --- src/compiler/compiler.rs | 479 +++++++++++++++++++------------- src/compiler/compiler_legacy.rs | 5 +- src/compiler/setup_inter.rs | 47 +++- src/frontend/dsl/mod.rs | 213 +------------- src/poly/mielim.rs | 4 +- src/sbpir/mod.rs | 66 ++++- src/sbpir/query.rs | 17 +- src/sbpir/sbpir_machine.rs | 4 + 8 files changed, 432 insertions(+), 403 deletions(-) diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index 5d6c04c6..e261870a 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -4,10 +4,7 @@ use num_bigint::BigInt; use crate::{ field::Field, - frontend::dsl::{ - cb::{Constraint, Typing}, - circuit, CircuitContext, StepTypeContext, - }, + frontend::dsl::StepTypeHandler, interpreter::InterpreterTraceGenerator, parser::{ ast::{ @@ -19,14 +16,16 @@ use crate::{ }, lang::TLDeclsParser, }, - poly::{self, mielim::mi_elimination, reduce::reduce_degree, Expr}, - sbpir::{query::Queriable, InternalSignal, SBPIR}, - wit_gen::{NullTraceGenerator, SymbolSignalMapping, TraceGenerator}, + poly::Expr, + sbpir::{ + query::Queriable, sbpir_machine::SBPIRMachine, Constraint, InternalSignal, StepType, SBPIR, + }, + wit_gen::{NullTraceGenerator, SymbolSignalMapping}, }; use super::{ semantic::{SymTable, SymbolCategory}, - setup_inter::{interpret, MachineSetup, Setup}, + setup_inter::{interpret, Setup}, Config, Message, Messages, }; @@ -72,12 +71,15 @@ impl Compiler { let ast = self.add_virtual(ast); let symbols = self.semantic(&ast).map_err(|_| self.messages.clone())?; let machine_setups = Self::interpret(&ast, &symbols); - let machine_setups = Self::map_consts(machine_setups); + let machine_setups = machine_setups + .iter() + .map(|(k, v)| (k.clone(), v.map_consts())) + .collect(); let circuit = self.build(&machine_setups, &symbols); - let circuit = Self::mi_elim(circuit); + let circuit = circuit.eliminate_mul_inv(); let circuit = if let Some(degree) = self.config.max_degree { - Self::reduce(circuit, degree) + circuit.reduce(degree) } else { circuit }; @@ -113,6 +115,7 @@ impl Compiler { } } + /// Adds "virtual" states to the AST (necessary to handle padding) fn add_virtual( &mut self, mut ast: Vec>, @@ -220,6 +223,8 @@ impl Compiler { } } + /// Semantic analysis of the AST + /// Returns the symbol table if successful fn semantic(&mut self, ast: &[TLDecl]) -> Result { let result = super::semantic::analyser::analyse(ast); let has_errors = result.messages.has_errors(); @@ -237,115 +242,57 @@ impl Compiler { interpret(ast, symbols) } - fn map_consts(setup: Setup) -> Setup { - setup - .iter() - .map(|(machine_id, machine)| { - let poly_constraints: HashMap>> = machine - .iter_states_poly_constraints() - .map(|(step_id, step)| { - let new_step: Vec> = - step.iter().map(|pi| Self::map_pi_consts(pi)).collect(); - - (step_id.clone(), new_step) - }) - .collect(); - - let new_machine: MachineSetup = - machine.replace_poly_constraints(poly_constraints); - (machine_id.clone(), new_machine) - }) - .collect() - } - - fn map_pi_consts(expr: &Expr) -> Expr { - use Expr::*; - match expr { - Const(v, _) => Const(F::from_big_int(v), ()), - Sum(ses, _) => Sum(ses.iter().map(|se| Self::map_pi_consts(se)).collect(), ()), - Mul(ses, _) => Mul(ses.iter().map(|se| Self::map_pi_consts(se)).collect(), ()), - Neg(se, _) => Neg(Box::new(Self::map_pi_consts(se)), ()), - Pow(se, exp, _) => Pow(Box::new(Self::map_pi_consts(se)), *exp, ()), - Query(q, _) => Query(q.clone(), ()), - Halo2Expr(_, _) => todo!(), - MI(se, _) => MI(Box::new(Self::map_pi_consts(se)), ()), - } - } - fn build(&mut self, setup: &Setup, symbols: &SymTable) -> SBPIR { let mut sbpir = SBPIR::default(); - for (machine_id, machine) in setup { - let sbpir_machine = circuit::("circuit", |ctx| { - self.add_forwards(ctx, symbols, machine_id); - self.add_step_type_handlers(ctx, symbols, machine_id); - - ctx.pragma_num_steps(self.config.max_steps); - ctx.pragma_first_step(self.mapping.get_step_type_handler(machine_id, "initial")); - ctx.pragma_last_step(self.mapping.get_step_type_handler(machine_id, "__padding")); - - for state_id in machine.states() { - ctx.step_type_def( - self.mapping.get_step_type_handler(machine_id, state_id), - |ctx| { - self.add_internals(ctx, symbols, machine_id, state_id); - - ctx.setup(|ctx| { - let poly_constraints = - self.translate_queries(symbols, setup, machine_id, state_id); - poly_constraints.iter().for_each(|poly| { - let constraint = Constraint { - annotation: format!("{:?}", poly), - expr: poly.clone(), - typing: Typing::AntiBooly, - }; - ctx.constr(constraint); - }); - }); - - ctx.wg(|_, _: ()| {}) - }, - ); - } - - ctx.trace(|_, _| {}); - }) - .without_trace(); - - sbpir.add_machine(machine_id, sbpir_machine); - } - - sbpir - } + for (machine_name, machine_setup) in setup { + let mut sbpir_machine = SBPIRMachine::default(); + self.add_forward_signals(&mut sbpir_machine, symbols, machine_name); + self.add_step_type_handlers(&mut sbpir_machine, symbols, machine_name); - fn mi_elim(mut circuit: SBPIR) -> SBPIR { - for machine in circuit.machines.values_mut() { - for (_, step_type) in machine.step_types.iter_mut() { - let mut signal_factory = SignalFactory::default(); + sbpir_machine.num_steps = self.config.max_steps; + sbpir_machine.first_step = Some( + self.mapping + .get_step_type_handler(machine_name, "initial") + .uuid(), + ); + sbpir_machine.last_step = Some( + self.mapping + .get_step_type_handler(machine_name, "__padding") + .uuid(), + ); + + for state_id in machine_setup.states() { + let handler = self.mapping.get_step_type_handler(machine_name, state_id); + + let mut step_type = StepType::new(handler.uuid(), handler.annotation.to_string()); + + self.add_internal_signals( + symbols, + machine_name, + &mut sbpir_machine, + &mut step_type, + state_id, + ); + + let poly_constraints = + self.translate_queries(symbols, setup, machine_name, state_id); + poly_constraints.iter().for_each(|poly| { + let constraint = Constraint { + annotation: format!("{:?}", poly), + expr: poly.clone(), + }; + + step_type.constraints.push(constraint) + }); - step_type - .decomp_constraints(|expr| mi_elimination(expr.clone(), &mut signal_factory)); + sbpir_machine.add_step_type_def(step_type); } - } - - circuit - } - - fn reduce( - mut circuit: SBPIR, - degree: usize, - ) -> SBPIR { - for machine in circuit.machines.values_mut() { - for (_, step_type) in machine.step_types.iter_mut() { - let mut signal_factory = SignalFactory::default(); - step_type.decomp_constraints(|expr| { - reduce_degree(expr.clone(), degree, &mut signal_factory) - }); - } + sbpir.machines.insert(machine_name.clone(), sbpir_machine); } - circuit + sbpir.without_trace() } #[allow(dead_code)] @@ -357,25 +304,25 @@ impl Compiler { &mut self, symbols: &SymTable, setup: &Setup, - machine_id: &str, + machine_name: &str, state_id: &str, ) -> Vec, ()>> { let exprs = setup - .get(machine_id) + .get(machine_name) .unwrap() .get_poly_constraints(state_id) .unwrap(); exprs .iter() - .map(|expr| self.translate_queries_expr(symbols, machine_id, state_id, expr)) + .map(|expr| self.translate_queries_expr(symbols, machine_name, state_id, expr)) .collect() } fn translate_queries_expr( &mut self, symbols: &SymTable, - machine_id: &str, + machine_name: &str, state_id: &str, expr: &Expr, ) -> Expr, ()> { @@ -384,38 +331,41 @@ impl Compiler { Const(v, _) => Const(*v, ()), Sum(ses, _) => Sum( ses.iter() - .map(|se| self.translate_queries_expr(symbols, machine_id, state_id, se)) + .map(|se| self.translate_queries_expr(symbols, machine_name, state_id, se)) .collect(), (), ), Mul(ses, _) => Mul( ses.iter() - .map(|se| self.translate_queries_expr(symbols, machine_id, state_id, se)) + .map(|se| self.translate_queries_expr(symbols, machine_name, state_id, se)) .collect(), (), ), Neg(se, _) => Neg( - Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), + Box::new(self.translate_queries_expr(symbols, machine_name, state_id, se.as_ref())), (), ), Pow(se, exp, _) => Pow( - Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), + Box::new(self.translate_queries_expr(symbols, machine_name, state_id, se.as_ref())), *exp, (), ), MI(se, _) => MI( - Box::new(self.translate_queries_expr(symbols, machine_id, state_id, se.as_ref())), + Box::new(self.translate_queries_expr(symbols, machine_name, state_id, se.as_ref())), (), ), Halo2Expr(se, _) => Halo2Expr(se.clone(), ()), - Query(id, _) => Query(self.translate_query(symbols, machine_id, state_id, id), ()), + Query(id, _) => Query( + self.translate_query(symbols, machine_name, state_id, id), + (), + ), } } fn translate_query( &mut self, symbols: &SymTable, - machine_id: &str, + machine_name: &str, state_id: &str, id: &Identifier, ) -> Queriable { @@ -425,7 +375,7 @@ impl Compiler { .find_symbol( &[ "/".to_string(), - machine_id.to_string(), + machine_name.to_string(), state_id.to_string(), ], id.name(), @@ -434,17 +384,17 @@ impl Compiler { match symbol.symbol.category { InputSignal | OutputSignal | InoutSignal => { - self.translate_forward_queriable(machine_id, id) + self.translate_forward_queriable(machine_name, id) } Signal => match symbol.scope_cat { - ScopeCategory::Machine => self.translate_forward_queriable(machine_id, id), + ScopeCategory::Machine => self.translate_forward_queriable(machine_name, id), ScopeCategory::State => { if id.rotation() != 0 { unreachable!("semantic analyser should prevent this"); } let signal = self .mapping - .get_internal(&format!("//{}/{}", machine_id, state_id), &id.name()); + .get_internal(&format!("//{}/{}", machine_name, state_id), &id.name()); Queriable::Internal(signal) } @@ -452,16 +402,16 @@ impl Compiler { ScopeCategory::Global => unreachable!("no global signals"), }, - State => { - Queriable::StepTypeNext(self.mapping.get_step_type_handler(machine_id, &id.name())) - } + State => Queriable::StepTypeNext( + self.mapping.get_step_type_handler(machine_name, &id.name()), + ), _ => unreachable!("semantic analysis should prevent this"), } } - fn translate_forward_queriable(&mut self, machine_id: &str, id: &Identifier) -> Queriable { - let forward = self.mapping.get_forward(machine_id, &id.name()); + fn translate_forward_queriable(&mut self, machine_name: &str, id: &Identifier) -> Queriable { + let forward = self.mapping.get_forward(machine_name, &id.name()); let rot = if id.rotation() == 0 { false } else if id.rotation() == 1 { @@ -476,13 +426,13 @@ impl Compiler { fn get_all_internals( &mut self, symbols: &SymTable, - machine_id: &str, + machine_name: &str, state_id: &str, ) -> Vec { let symbols = symbols .get_scope(&[ "/".to_string(), - machine_id.to_string(), + machine_name.to_string(), state_id.to_string(), ]) .expect("scope not found") @@ -496,41 +446,41 @@ impl Compiler { .collect() } - fn add_internals( + fn add_internal_signals( &mut self, - ctx: &mut StepTypeContext, symbols: &SymTable, - machine_id: &str, + machine_name: &str, + sbpir_machine: &mut SBPIRMachine, + step_type: &mut StepType, state_id: &str, ) { - let internal_ids = self.get_all_internals(symbols, machine_id, state_id); - let scope_name = format!("//{}/{}", machine_id, state_id); + let internal_ids = self.get_all_internals(symbols, machine_name, state_id); + let scope_name = format!("//{}/{}", machine_name, state_id); for internal_id in internal_ids { let name = format!("{}:{}", &scope_name, internal_id); + let signal = InternalSignal::new(name); - let queriable = ctx.internal(name.as_str()); - if let Queriable::Internal(signal) = queriable { - self.mapping - .symbol_uuid - .insert((scope_name.clone(), internal_id), signal.uuid()); - self.mapping.internal_signals.insert(signal.uuid(), signal); - } else { - unreachable!("ctx.internal returns not internal signal"); - } + sbpir_machine + .annotations + .insert(signal.uuid(), signal.annotation().to_string()); + + step_type.signals.push(signal); + + self.mapping + .symbol_uuid + .insert((scope_name.clone(), internal_id), signal.uuid()); + self.mapping.internal_signals.insert(signal.uuid(), signal); } } - fn add_step_type_handlers>( + fn add_step_type_handlers( &mut self, - ctx: &mut CircuitContext, + machine: &mut SBPIRMachine, symbols: &SymTable, - machine_id: &str, + machine_name: &str, ) { - let symbols = symbols - .get_scope(&["/".to_string(), machine_id.to_string()]) - .expect("scope not found") - .get_symbols(); + let symbols = get_symbols(symbols, machine_name); let state_ids: Vec<_> = symbols .iter() @@ -540,10 +490,13 @@ impl Compiler { .collect(); for state_id in state_ids { - let scope_name = format!("//{}", machine_id); + let scope_name = format!("//{}", machine_name); let name = format!("{}:{}", scope_name, state_id); - let handler = ctx.step_type(&name); + let handler = StepTypeHandler::new(name.to_string()); + + machine.add_step_type(handler, name); + self.mapping .step_type_handler .insert(handler.uuid(), handler); @@ -553,16 +506,13 @@ impl Compiler { } } - fn add_forwards>( + fn add_forward_signals( &mut self, - ctx: &mut CircuitContext, + machine: &mut SBPIRMachine, symbols: &SymTable, - machine_id: &str, + machine_name: &str, ) { - let symbols = symbols - .get_scope(&["/".to_string(), machine_id.to_string()]) - .expect("scope not found") - .get_symbols(); + let symbols = get_symbols(symbols, machine_name); let forward_ids: Vec<_> = symbols .iter() @@ -572,17 +522,16 @@ impl Compiler { .collect(); for forward_id in forward_ids { - let scope_name = format!("//{}", machine_id); + let scope_name = format!("//{}", machine_name); let name = format!("{}:{}", scope_name, forward_id); - - let queriable = ctx.forward(name.as_str()); + let queriable = Queriable::::Forward(machine.add_forward(name.as_str(), 0), false); if let Queriable::Forward(signal, _) = queriable { self.mapping .symbol_uuid .insert((scope_name, forward_id), signal.uuid()); self.mapping.forward_signals.insert(signal.uuid(), signal); } else { - unreachable!("ctx.internal returns not internal signal"); + unreachable!("Forward queriable should return a forward signal"); } } } @@ -600,29 +549,29 @@ impl Compiler { } } -// Basic signal factory. -#[derive(Default)] -struct SignalFactory { - count: u64, - _p: PhantomData, -} - -impl poly::SignalFactory> for SignalFactory { - fn create>(&mut self, annotation: S) -> Queriable { - self.count += 1; - Queriable::Internal(InternalSignal::new(format!( - "{}-{}", - annotation.into(), - self.count - ))) - } +fn get_symbols<'a>( + symbols: &'a SymTable, + machine_name: &'a str, +) -> &'a HashMap { + let symbols = symbols + .get_scope(&["/".to_string(), machine_name.to_string()]) + .expect("scope not found") + .get_symbols(); + symbols } #[cfg(test)] mod test { + use std::collections::HashMap; + use halo2_proofs::halo2curves::bn256::Fr; + use itertools::Itertools; - use crate::{compiler::compile, parser::ast::debug_sym_factory::DebugSymRefFactory}; + use crate::{ + compiler::{compile, compile_legacy}, + parser::ast::debug_sym_factory::DebugSymRefFactory, + wit_gen::TraceGenerator, + }; use super::Config; @@ -645,7 +594,7 @@ mod test { i, a, b, c <== 1, 1, 1, 2; -> middle { - a', b', n' <== b, c, n; + i', a', b', n' <== i + 1, b, c, n; } } @@ -684,7 +633,7 @@ mod test { i, a, b, c <== 1, 1, 1, 2; -> middle { - a', b', n' <== b, c, n; + i', a', b', n' <== i + 1, b, c, n; } } @@ -726,4 +675,162 @@ mod test { Err(messages) => println!("{:#?}", messages), } } + + #[test] + fn test_is_new_compiler_identical_to_legacy() { + let circuit = " + machine fibo(signal n) (signal b: field) { + // n and be are created automatically as shared + // signals + signal a: field, i; + + // there is always a state called initial + // input signals get bound to the signal + // in the initial state (first instance) + state initial { + signal c; + + i, a, b, c <== 1, 1, 1, 2; + + -> middle { + i', a', b', n' <== i + 1, b, c, n; + } + } + + state middle { + signal c; + + c <== a + b; + + if i + 1 == n { + -> final { + i', b', n' <== i + 1, c, n; + } + } else { + -> middle { + i', a', b', n' <== i + 1, b, c, n; + } + } + } + + // There is always a state called final. + // Output signals get automatically bound to the signals + // with the same name in the final step (last instance). + // This state can be implicit if there are no constraints in it. + } + "; + + let debug_sym_ref_factory = DebugSymRefFactory::new("", circuit); + let result = compile::( + circuit, + Config::default().max_degree(2), + &debug_sym_ref_factory, + ) + .unwrap(); + + let result_legacy = compile_legacy::( + circuit, + Config::default().max_degree(2), + &debug_sym_ref_factory, + ) + .unwrap(); + + let result = result.circuit.machines.get("fibo").unwrap(); + let result_legacy = result_legacy.circuit; + let exposed = &result.exposed; + let exposed_legacy = result_legacy.exposed; + + for exposed in exposed.iter().zip(exposed_legacy.iter()) { + assert_eq!(exposed.0 .0, exposed.1 .0); + assert_eq!(exposed.0 .1, exposed.1 .1); + } + // TODO investigate why new compiler produces extra annotations + // assert_eq!(result.annotations.len(), result_legacy.annotations.len()); + for val in result_legacy.annotations.values() { + assert!(result.annotations.values().contains(val)); + } + + assert_eq!( + result.forward_signals.len(), + result_legacy.forward_signals.len() + ); + for val in result_legacy.forward_signals.iter() { + assert!(result + .forward_signals + .iter() + .find(|x| x.annotation() == val.annotation() && x.phase() == val.phase()) + .is_some()); + } + + assert_eq!(result.shared_signals, result_legacy.shared_signals); + assert_eq!(result.fixed_signals, result_legacy.fixed_signals); + assert_eq!(result.halo2_advice, result_legacy.halo2_advice); + assert_eq!(result.halo2_fixed, result_legacy.halo2_fixed); + assert_eq!(result.step_types.len(), result_legacy.step_types.len()); + for step in result_legacy.step_types.values() { + let name = step.name(); + let step_new = result + .step_types + .iter() + .find(|x| x.1.name() == name) + .unwrap() + .1; + assert_eq!(step_new.signals.len(), step.signals.len()); + for signal in step.signals.iter() { + assert!(step_new + .signals + .iter() + .any(|x| x.annotation() == signal.annotation())); + } + assert_eq!(step_new.constraints.len(), step.constraints.len()); + for constraint in step.constraints.iter() { + assert!(step_new + .constraints + .iter() + .any(|x| x.annotation == constraint.annotation)); + } + assert_eq!(step_new.lookups.len() == 0, step.lookups.len() == 0); + assert_eq!( + step_new.auto_signals.len() == 0, + step.auto_signals.len() == 0 + ); + assert_eq!( + step_new.transition_constraints.len() == 0, + step.transition_constraints.len() == 0 + ); + // TODO investigate why new compiler produces extra annotations + // assert_eq!(step_new.annotations.len(), step.annotations.len()); + } + + assert_eq!( + result.first_step.is_some(), + result_legacy.first_step.is_some() + ); + assert_eq!( + result.last_step.is_some(), + result_legacy.last_step.is_some() + ); + assert_eq!(result.num_steps, result_legacy.num_steps); + assert_eq!(result.q_enable, result_legacy.q_enable); + + let tg_new = result.trace_generator.as_ref().unwrap(); + let tg_legacy = result_legacy.trace_generator.unwrap(); + + // Check if the witness values of the new compiler are the same as the legacy compiler + let res = tg_new.generate(HashMap::from([("n".to_string(), Fr::from(12))])); + let res_legacy = tg_legacy.generate(HashMap::from([("n".to_string(), Fr::from(12))])); + assert_eq!(res.step_instances.len(), res_legacy.step_instances.len()); + for (step, step_legacy) in res.step_instances.iter().zip(res_legacy.step_instances) { + assert_eq!(step.assignments.len(), step_legacy.assignments.len()); + for assignment in step.assignments.iter() { + let assignment_legacy = step_legacy + .assignments + .iter() + .find(|x| x.0.annotation() == assignment.0.annotation()) + .unwrap(); + assert_eq!(assignment.0.annotation(), assignment_legacy.0.annotation()); + assert!(assignment.1.eq(&assignment_legacy.1)); + } + } + } } diff --git a/src/compiler/compiler_legacy.rs b/src/compiler/compiler_legacy.rs index 09ff201c..66bfc4d6 100644 --- a/src/compiler/compiler_legacy.rs +++ b/src/compiler/compiler_legacy.rs @@ -129,6 +129,7 @@ impl CompilerLegacy { } } + /// Adds "virtual" states to the AST (necessary to handle padding) fn add_virtual( &mut self, mut ast: Vec>, @@ -236,6 +237,8 @@ impl CompilerLegacy { } } + /// Semantic analysis of the AST + /// Returns the symbol table if successful fn semantic(&mut self, ast: &[TLDecl]) -> Result { let result = super::semantic::analyser::analyse(ast); let has_errors = result.messages.has_errors(); @@ -258,7 +261,7 @@ impl CompilerLegacy { .iter() .map(|(machine_id, machine)| { let poly_constraints: HashMap>> = machine - .iter_states_poly_constraints() + .poly_constraints_iter() .map(|(step_id, step)| { let new_step: Vec> = step.iter().map(|pi| Self::map_pi_consts(pi)).collect(); diff --git a/src/compiler/setup_inter.rs b/src/compiler/setup_inter.rs index 53af8721..9b2b3e24 100644 --- a/src/compiler/setup_inter.rs +++ b/src/compiler/setup_inter.rs @@ -4,6 +4,7 @@ use itertools::Itertools; use num_bigint::BigInt; use crate::{ + field::Field, parser::ast::{ statement::{Statement, TypedIdDecl}, tl::TLDecl, @@ -46,6 +47,50 @@ impl Default for MachineSetup { } } } +impl MachineSetup { + pub(crate) fn map_consts(&self) -> MachineSetup { + let poly_constraints: HashMap>> = self + .poly_constraints_iter() + .map(|(step_id, step)| { + let new_step: Vec> = step + .iter() + .map(|pi| Self::convert_const_to_field(pi)) + .collect(); + + (step_id.clone(), new_step) + }) + .collect(); + + let new_machine: MachineSetup = self.replace_poly_constraints(poly_constraints); + new_machine + } + + fn convert_const_to_field( + expr: &Expr, + ) -> Expr { + use Expr::*; + match expr { + Const(v, _) => Const(F::from_big_int(v), ()), + Sum(ses, _) => Sum( + ses.iter() + .map(|se| Self::convert_const_to_field(se)) + .collect(), + (), + ), + Mul(ses, _) => Mul( + ses.iter() + .map(|se| Self::convert_const_to_field(se)) + .collect(), + (), + ), + Neg(se, _) => Neg(Box::new(Self::convert_const_to_field(se)), ()), + Pow(se, exp, _) => Pow(Box::new(Self::convert_const_to_field(se)), *exp, ()), + Query(q, _) => Query(q.clone(), ()), + Halo2Expr(_, _) => todo!(), + MI(se, _) => MI(Box::new(Self::convert_const_to_field(se)), ()), + } + } +} impl MachineSetup { fn new( @@ -88,7 +133,7 @@ impl MachineSetup { .extend(poly_constraints); } - pub(super) fn iter_states_poly_constraints( + pub(super) fn poly_constraints_iter( &self, ) -> std::collections::hash_map::Iter>> { self.poly_constraints.iter() diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index 3bda965f..77f91933 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -1,21 +1,15 @@ use crate::{ - field::Field, - sbpir::{ - query::Queriable, sbpir_machine::SBPIRMachine, ExposeOffset, StepType, StepTypeUUID, PIR, - }, + sbpir::{query::Queriable, StepType, StepTypeUUID, PIR}, util::{uuid, UUID}, - wit_gen::{FixedGenContext, StepInstance, TraceGenerator}, + wit_gen::StepInstance, }; -use halo2_proofs::plonk::{Advice, Column as Halo2Column, Fixed}; -use trace::{DSLTraceGenerator, TraceContext}; - use core::{fmt::Debug, hash::Hash}; use std::marker::PhantomData; use self::{ cb::{eq, Constraint, Typing}, - lb::{LookupBuilder, LookupTable, LookupTableRegistry, LookupTableStore}, + lb::{LookupBuilder, LookupTableRegistry}, }; pub use sc::*; @@ -26,177 +20,6 @@ pub mod lb; pub mod sc; pub mod trace; -#[derive(Debug, Default)] -/// A generic structure designed to handle the context of a circuit. -/// The struct contains a `Circuit` instance and implements methods to build the circuit, -/// add various components, and manipulate the circuit. -/// -/// ### Type parameters -/// `F` is the field of the circuit. -/// `TG` is the trace generator. -pub struct CircuitContext = DSLTraceGenerator> { - circuit: SBPIRMachine, - tables: LookupTableRegistry, -} - -impl> CircuitContext { - /// Adds a forward signal to the circuit with a name string and zero rotation and returns a - /// `Queriable` instance representing the added forward signal. - pub fn forward(&mut self, name: &str) -> Queriable { - Queriable::Forward(self.circuit.add_forward(name, 0), false) - } - - /// Adds a forward signal to the circuit with a name string and a specified phase and returns a - /// `Queriable` instance representing the added forward signal. - pub fn forward_with_phase(&mut self, name: &str, phase: usize) -> Queriable { - Queriable::Forward(self.circuit.add_forward(name, phase), false) - } - - /// Adds a shared signal to the circuit with a name string and zero rotation and returns a - /// `Queriable` instance representing the added shared signal. - pub fn shared(&mut self, name: &str) -> Queriable { - Queriable::Shared(self.circuit.add_shared(name, 0), 0) - } - - /// Adds a shared signal to the circuit with a name string and a specified phase and returns a - /// `Queriable` instance representing the added shared signal. - pub fn shared_with_phase(&mut self, name: &str, phase: usize) -> Queriable { - Queriable::Shared(self.circuit.add_shared(name, phase), 0) - } - - pub fn fixed(&mut self, name: &str) -> Queriable { - Queriable::Fixed(self.circuit.add_fixed(name), 0) - } - - /// Exposes the first step instance value of a forward signal as public. - pub fn expose(&mut self, queriable: Queriable, offset: ExposeOffset) { - self.circuit.expose(queriable, offset); - } - - /// Imports a halo2 advice column with a name string into the circuit and returns a - /// `Queriable` instance representing the imported column. - pub fn import_halo2_advice(&mut self, name: &str, column: Halo2Column) -> Queriable { - Queriable::Halo2AdviceQuery(self.circuit.add_halo2_advice(name, column), 0) - } - - /// Imports a halo2 fixed column with a name string into the circuit and returns a - /// `Queriable` instance representing the imported column. - pub fn import_halo2_fixed(&mut self, name: &str, column: Halo2Column) -> Queriable { - Queriable::Halo2FixedQuery(self.circuit.add_halo2_fixed(name, column), 0) - } - - /// Adds a new step type with the specified name to the circuit and returns a - /// `StepTypeHandler` instance. The `StepTypeHandler` instance can be used to define the - /// step type using the `step_type_def` function. - pub fn step_type(&mut self, name: &str) -> StepTypeHandler { - let handler = StepTypeHandler::new(name.to_string()); - - self.circuit.add_step_type(handler, name); - - handler - } - - /// Defines a step type using the provided `StepTypeHandler` and a function that takes a - /// mutable reference to a `StepTypeContext`. This function typically adds constraints to a - /// step type and defines witness generation. - pub fn step_type_def, R>( - &mut self, - step: S, - def: D, - ) -> StepTypeWGHandler - where - D: FnOnce(&mut StepTypeContext) -> StepTypeWGHandler, - R: Fn(&mut StepInstance, Args) + 'static, - { - let handler: StepTypeHandler = match step.into() { - StepTypeDefInput::Handler(h) => h, - StepTypeDefInput::String(name) => { - let handler = StepTypeHandler::new(name.to_string()); - - self.circuit.add_step_type(handler, name); - - handler - } - }; - - let mut context = StepTypeContext::::new( - handler.uuid(), - handler.annotation.to_string(), - self.tables.clone(), - ); - - let result = def(&mut context); - - self.circuit.add_step_type_def(context.step_type); - - result - } - - pub fn new_table(&self, table: LookupTableStore) -> LookupTable { - let uuid = table.uuid(); - self.tables.add(table); - - LookupTable { uuid } - } - - /// Enforce the type of the first step by adding a constraint to the circuit. Takes a - /// `StepTypeHandler` parameter that represents the step type. - pub fn pragma_first_step>(&mut self, step_type: STH) { - self.circuit.first_step = Some(step_type.into().uuid()); - } - - /// Enforce the type of the last step by adding a constraint to the circuit. Takes a - /// `StepTypeHandler` parameter that represents the step type. - pub fn pragma_last_step>(&mut self, step_type: STH) { - self.circuit.last_step = Some(step_type.into().uuid()); - } - - /// Enforce the number of step instances by adding a constraint to the circuit. Takes a `usize` - /// parameter that represents the total number of steps. - pub fn pragma_num_steps(&mut self, num_steps: usize) { - self.circuit.num_steps = num_steps; - } - - pub fn pragma_disable_q_enable(&mut self) { - self.circuit.q_enable = false; - } -} - -impl CircuitContext> { - /// Sets the trace function that builds the witness. The trace function is responsible for - /// adding step instances defined in `step_type_def`. The function is entirely left for - /// the user to implement and is Turing complete. Users typically use external parameters - /// of type `TraceArgs` to generate cell values for witness generation, and call the - /// `add` function to add step instances with witness values. - pub fn trace(&mut self, def: D) - where - D: Fn(&mut TraceContext, TraceArgs) + 'static, - { - self.circuit.set_trace(def); - } -} - -impl> CircuitContext { - /// Executes the fixed generation function provided by the user and sets the fixed assignments - /// for the circuit. The fixed generation function is responsible for assigning fixed values to - /// fixed columns. It is entirely left for the user to implement and is Turing complete. Users - /// typically generate cell values and call the `assign` function to fill the fixed columns. - pub fn fixed_gen(&mut self, def: D) - where - D: Fn(&mut FixedGenContext) + 'static, - { - if self.circuit.num_steps == 0 { - panic!("circuit must call pragma_num_steps before calling fixed_gen"); - } - let mut ctx = FixedGenContext::new(self.circuit.num_steps); - (def)(&mut ctx); - - let assignments = ctx.get_assignments(); - - self.circuit.set_fixed_assignments(assignments); - } -} - pub enum StepTypeDefInput { Handler(StepTypeHandler), String(&'static str), @@ -358,7 +181,7 @@ pub struct StepTypeHandler { } impl StepTypeHandler { - fn new(annotation: String) -> Self { + pub(crate) fn new(annotation: String) -> Self { Self { id: uuid(), annotation: Box::leak(annotation.into_boxed_str()), @@ -419,37 +242,15 @@ impl, Args) + 'static> StepTypeWGHandler( - _name: &str, - mut def: D, -) -> SBPIRMachine> -where - D: FnMut(&mut CircuitContext>), -{ - // TODO annotate circuit - let mut context = CircuitContext { - circuit: SBPIRMachine::default(), - tables: LookupTableRegistry::default(), - }; - - def(&mut context); - - context.circuit -} - #[cfg(test)] mod tests { use circuit_context_legacy::CircuitContextLegacy; use halo2_proofs::halo2curves::bn256::Fr; + use trace::DSLTraceGenerator; use crate::{ - sbpir::{ForwardSignal, SBPIRLegacy}, - wit_gen::NullTraceGenerator, + sbpir::{ExposeOffset, ForwardSignal, SBPIRLegacy}, + wit_gen::{NullTraceGenerator, TraceGenerator}, }; use super::*; diff --git a/src/poly/mielim.rs b/src/poly/mielim.rs index 96e7cfeb..a9e1e888 100644 --- a/src/poly/mielim.rs +++ b/src/poly/mielim.rs @@ -3,8 +3,8 @@ use std::{fmt::Debug, hash::Hash}; use super::{ConstrDecomp, Expr, SignalFactory}; use crate::field::Field; -/// This function eliminates MI operators from the PI expression, by creating new signals that are -/// constraint to the MI sub-expressions. +/// This function eliminates MI operators from the PI expression by creating new signals that are +/// constrained to the MI sub-expressions. pub fn mi_elimination>( constr: Expr, signal_factory: &mut SF, diff --git a/src/sbpir/mod.rs b/src/sbpir/mod.rs index 7c5076f5..0335345f 100644 --- a/src/sbpir/mod.rs +++ b/src/sbpir/mod.rs @@ -1,7 +1,7 @@ pub mod query; pub mod sbpir_machine; -use std::{collections::HashMap, fmt::Debug, hash::Hash, rc::Rc}; +use std::{collections::HashMap, fmt::Debug, hash::Hash, marker::PhantomData, rc::Rc}; use crate::{ field::Field, @@ -9,7 +9,7 @@ use crate::{ trace::{DSLTraceGenerator, TraceContext}, StepTypeHandler, }, - poly::{ConstrDecomp, Expr}, + poly::{self, mielim::mi_elimination, reduce::reduce_degree, ConstrDecomp, Expr}, util::{uuid, UUID}, wit_gen::{FixedAssignment, FixedGenContext, NullTraceGenerator, TraceGenerator}, }; @@ -275,7 +275,7 @@ pub struct SBPIR = DSLTraceGenerator> { pub identifiers: HashMap, } -impl> SBPIR { +impl> SBPIR { pub(crate) fn default() -> SBPIR { let machines = HashMap::new(); let identifiers = HashMap::new(); @@ -301,8 +301,62 @@ impl> SBPIR { } } - pub(crate) fn add_machine(&mut self, name: &str, without_trace: SBPIRMachine) { - self.machines.insert(name.to_string(), without_trace); + pub(crate) fn without_trace(&self) -> SBPIR { + let mut machines_without_trace = HashMap::new(); + for (name, machine) in self.machines.iter() { + let machine_without_trace = machine.without_trace(); + machines_without_trace.insert(name.clone(), machine_without_trace); + } + SBPIR { + machines: machines_without_trace, + identifiers: self.identifiers.clone(), + } + } + + /// Eliminate multiplicative inverses + pub(crate) fn eliminate_mul_inv(mut self) -> SBPIR { + for machine in self.machines.values_mut() { + for (_, step_type) in machine.step_types.iter_mut() { + let mut signal_factory = SignalFactory::default(); + + step_type + .decomp_constraints(|expr| mi_elimination(expr.clone(), &mut signal_factory)); + } + } + + self + } + + pub(crate) fn reduce(mut self, degree: usize) -> SBPIR { + for machine in self.machines.values_mut() { + for (_, step_type) in machine.step_types.iter_mut() { + let mut signal_factory = SignalFactory::default(); + + step_type.decomp_constraints(|expr| { + reduce_degree(expr.clone(), degree, &mut signal_factory) + }); + } + } + + self + } +} + +// Basic signal factory. +#[derive(Default)] +struct SignalFactory { + count: u64, + _p: PhantomData, +} + +impl poly::SignalFactory> for SignalFactory { + fn create>(&mut self, annotation: S) -> Queriable { + self.count += 1; + Queriable::Internal(InternalSignal::new(format!( + "{}-{}", + annotation.into(), + self.count + ))) } } @@ -667,7 +721,7 @@ impl FixedSignal { } } -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum ExposeOffset { First, Last, diff --git a/src/sbpir/query.rs b/src/sbpir/query.rs index 0f0b727d..ab6d7efb 100644 --- a/src/sbpir/query.rs +++ b/src/sbpir/query.rs @@ -21,11 +21,26 @@ use super::PIR; #[derive(Clone, Copy, PartialEq, Eq, Hash)] pub enum Queriable { Internal(InternalSignal), + /// Forward signal + /// - `ForwardSignal` is the signal to be queried + /// - `bool` is the rotation state of the signal (true if rotated) Forward(ForwardSignal, bool), + /// Shared signal + /// - `SharedSignal` is the signal to be queried + /// - `i32` is the rotation value Shared(SharedSignal, i32), + /// Fixed signal + /// - `FixedSignal` is the signal to be queried + /// - `i32` is the rotation value Fixed(FixedSignal, i32), StepTypeNext(StepTypeHandler), + /// Imported Halo2 advice query + /// - `ImportedHalo2Advice` is the signal to be queried + /// - `i32` is the rotation value Halo2AdviceQuery(ImportedHalo2Advice, i32), + /// Imported Halo2 fixed query + /// - `ImportedHalo2Fixed` is the signal to be queried + /// - `i32` is the rotation value Halo2FixedQuery(ImportedHalo2Fixed, i32), #[allow(non_camel_case_types)] _unaccessible(PhantomData), @@ -38,7 +53,7 @@ impl Debug for Queriable { } impl Queriable { - /// Call `next` function on a `Querible` forward signal to build constraints for forward + /// Call `next` function on a `Queriable` forward signal to build constraints for forward /// signal with rotation. Cannot be called on an internal signal and must be used within a /// `transition` constraint. Returns a new `Queriable` forward signal with rotation. pub fn next(&self) -> Queriable { diff --git a/src/sbpir/sbpir_machine.rs b/src/sbpir/sbpir_machine.rs index 0b79300c..3a4a6493 100644 --- a/src/sbpir/sbpir_machine.rs +++ b/src/sbpir/sbpir_machine.rs @@ -22,6 +22,7 @@ pub struct SBPIRMachine = DSLTraceGenerator> pub step_types: HashMap>, pub forward_signals: Vec, + // TODO currently not used pub shared_signals: Vec, pub fixed_signals: Vec, pub halo2_advice: Vec, @@ -31,6 +32,7 @@ pub struct SBPIRMachine = DSLTraceGenerator> pub annotations: HashMap, pub trace_generator: Option, + // TODO currently not used pub fixed_assignments: Option>, pub first_step: Option, @@ -109,6 +111,7 @@ impl> SBPIRMachine { signal } + // TODO currently not used pub fn add_fixed>(&mut self, name: N) -> FixedSignal { let name = name.into(); let signal = FixedSignal::new(name.clone()); @@ -119,6 +122,7 @@ impl> SBPIRMachine { signal } + // TODO currently not used pub fn expose(&mut self, signal: Queriable, offset: ExposeOffset) { match signal { Queriable::Forward(..) | Queriable::Shared(..) => { From ab26d1b60588a0825c87630c275bbd3fd0d164a4 Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Tue, 20 Aug 2024 17:06:39 +0800 Subject: [PATCH 4/9] Fix clippy lints --- src/compiler/compiler.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index e261870a..77f114b9 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -758,8 +758,7 @@ mod test { assert!(result .forward_signals .iter() - .find(|x| x.annotation() == val.annotation() && x.phase() == val.phase()) - .is_some()); + .any(|x| x.annotation() == val.annotation() && x.phase() == val.phase())); } assert_eq!(result.shared_signals, result_legacy.shared_signals); @@ -789,14 +788,14 @@ mod test { .iter() .any(|x| x.annotation == constraint.annotation)); } - assert_eq!(step_new.lookups.len() == 0, step.lookups.len() == 0); + assert_eq!(step_new.lookups.is_empty(), step.lookups.is_empty()); assert_eq!( - step_new.auto_signals.len() == 0, - step.auto_signals.len() == 0 + step_new.auto_signals.is_empty(), + step.auto_signals.is_empty() ); assert_eq!( - step_new.transition_constraints.len() == 0, - step.transition_constraints.len() == 0 + step_new.transition_constraints.is_empty(), + step.transition_constraints.is_empty() ); // TODO investigate why new compiler produces extra annotations // assert_eq!(step_new.annotations.len(), step.annotations.len()); @@ -829,7 +828,7 @@ mod test { .find(|x| x.0.annotation() == assignment.0.annotation()) .unwrap(); assert_eq!(assignment.0.annotation(), assignment_legacy.0.annotation()); - assert!(assignment.1.eq(&assignment_legacy.1)); + assert!(assignment.1.eq(assignment_legacy.1)); } } } From 56324e5f36b78d0d7766f5eff8878051431e8d74 Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Tue, 20 Aug 2024 17:50:43 +0800 Subject: [PATCH 5/9] Remove unnecessary annotation --- src/compiler/compiler.rs | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index 77f114b9..394adbda 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -17,9 +17,7 @@ use crate::{ lang::TLDeclsParser, }, poly::Expr, - sbpir::{ - query::Queriable, sbpir_machine::SBPIRMachine, Constraint, InternalSignal, StepType, SBPIR, - }, + sbpir::{query::Queriable, sbpir_machine::SBPIRMachine, Constraint, StepType, SBPIR}, wit_gen::{NullTraceGenerator, SymbolSignalMapping}, }; @@ -267,13 +265,7 @@ impl Compiler { let mut step_type = StepType::new(handler.uuid(), handler.annotation.to_string()); - self.add_internal_signals( - symbols, - machine_name, - &mut sbpir_machine, - &mut step_type, - state_id, - ); + self.add_internal_signals(symbols, machine_name, &mut step_type, state_id); let poly_constraints = self.translate_queries(symbols, setup, machine_name, state_id); @@ -450,7 +442,6 @@ impl Compiler { &mut self, symbols: &SymTable, machine_name: &str, - sbpir_machine: &mut SBPIRMachine, step_type: &mut StepType, state_id: &str, ) { @@ -459,13 +450,7 @@ impl Compiler { for internal_id in internal_ids { let name = format!("{}:{}", &scope_name, internal_id); - let signal = InternalSignal::new(name); - - sbpir_machine - .annotations - .insert(signal.uuid(), signal.annotation().to_string()); - - step_type.signals.push(signal); + let signal = step_type.add_signal(name.as_str()); self.mapping .symbol_uuid @@ -744,8 +729,7 @@ mod test { assert_eq!(exposed.0 .0, exposed.1 .0); assert_eq!(exposed.0 .1, exposed.1 .1); } - // TODO investigate why new compiler produces extra annotations - // assert_eq!(result.annotations.len(), result_legacy.annotations.len()); + assert_eq!(result.annotations.len(), result_legacy.annotations.len()); for val in result_legacy.annotations.values() { assert!(result.annotations.values().contains(val)); } @@ -797,8 +781,7 @@ mod test { step_new.transition_constraints.is_empty(), step.transition_constraints.is_empty() ); - // TODO investigate why new compiler produces extra annotations - // assert_eq!(step_new.annotations.len(), step.annotations.len()); + assert_eq!(step_new.annotations.len(), step.annotations.len()); } assert_eq!( From 74541b95267f9dfe2a544232fd73f5363542e713 Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Tue, 20 Aug 2024 17:59:00 +0800 Subject: [PATCH 6/9] Restructure changes to reduce the diff --- examples/blake2f.rs | 3 +- examples/factorial.rs | 6 +- examples/fibo_with_padding.rs | 6 +- examples/fibonacci.rs | 6 +- examples/keccak.rs | 4 +- examples/mimc7.rs | 3 +- examples/poseidon.rs | 3 +- src/compiler/compiler_legacy.rs | 3 +- src/frontend/dsl/circuit_context_legacy.rs | 209 --------------------- src/frontend/dsl/mod.rs | 204 +++++++++++++++++++- src/frontend/dsl/sc.rs | 6 +- 11 files changed, 217 insertions(+), 236 deletions(-) delete mode 100644 src/frontend/dsl/circuit_context_legacy.rs diff --git a/examples/blake2f.rs b/examples/blake2f.rs index 0e2fdb46..187c3c45 100644 --- a/examples/blake2f.rs +++ b/examples/blake2f.rs @@ -1,11 +1,10 @@ use chiquito::{ frontend::dsl::{ cb::{eq, select, table}, - circuit_context_legacy::CircuitContextLegacy, lb::LookupTable, super_circuit, trace::DSLTraceGenerator, - StepTypeSetupContext, StepTypeWGHandler, + CircuitContextLegacy, StepTypeSetupContext, StepTypeWGHandler, }, plonkish::{ backend::halo2_legacy::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, diff --git a/examples/factorial.rs b/examples/factorial.rs index 2e0e6831..37196e11 100644 --- a/examples/factorial.rs +++ b/examples/factorial.rs @@ -2,9 +2,9 @@ use std::hash::Hash; use chiquito::{ field::Field, - frontend::dsl::{circuit_context_legacy::circuit_legacy, trace::DSLTraceGenerator}, /* main function for constructing - * an AST - * circuit */ + frontend::dsl::{circuit_legacy, trace::DSLTraceGenerator}, /* main function for constructing + * an AST + * circuit */ plonkish::{ backend::halo2_legacy::{chiquito2Halo2, ChiquitoHalo2Circuit}, compiler::{ diff --git a/examples/fibo_with_padding.rs b/examples/fibo_with_padding.rs index 9c383eca..2c6df8b3 100644 --- a/examples/fibo_with_padding.rs +++ b/examples/fibo_with_padding.rs @@ -2,9 +2,9 @@ use std::hash::Hash; use chiquito::{ field::Field, - frontend::dsl::{circuit_context_legacy::circuit_legacy, trace::DSLTraceGenerator}, /* main function for constructing - * an AST - * circuit */ + frontend::dsl::{circuit_legacy, trace::DSLTraceGenerator}, /* main function for constructing + * an AST + * circuit */ plonkish::{ backend::halo2_legacy::{chiquito2Halo2, ChiquitoHalo2Circuit}, compiler::{ diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index 4cf87607..e83c3fd4 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -2,9 +2,9 @@ use std::hash::Hash; use chiquito::{ field::Field, - frontend::dsl::{circuit_context_legacy::circuit_legacy, trace::DSLTraceGenerator}, /* main function for constructing - * an AST - * circuit */ + frontend::dsl::{circuit_legacy, trace::DSLTraceGenerator}, /* main function for constructing + * an AST + * circuit */ plonkish::{ backend::{ halo2_legacy::{chiquito2Halo2, ChiquitoHalo2Circuit}, diff --git a/examples/keccak.rs b/examples/keccak.rs index fb6c4802..10d66a37 100644 --- a/examples/keccak.rs +++ b/examples/keccak.rs @@ -1,7 +1,7 @@ use chiquito::{ frontend::dsl::{ - circuit_context_legacy::CircuitContextLegacy, lb::LookupTable, super_circuit, - trace::DSLTraceGenerator, StepTypeWGHandler, + lb::LookupTable, super_circuit, trace::DSLTraceGenerator, CircuitContextLegacy, + StepTypeWGHandler, }, plonkish::{ backend::halo2_legacy::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, diff --git a/examples/mimc7.rs b/examples/mimc7.rs index 9fed5b5d..2dab9976 100644 --- a/examples/mimc7.rs +++ b/examples/mimc7.rs @@ -7,8 +7,7 @@ use halo2_proofs::{ use chiquito::{ frontend::dsl::{ - circuit_context_legacy::CircuitContextLegacy, lb::LookupTable, super_circuit, - trace::DSLTraceGenerator, + lb::LookupTable, super_circuit, trace::DSLTraceGenerator, CircuitContextLegacy, }, plonkish::{ backend::halo2_legacy::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, diff --git a/examples/poseidon.rs b/examples/poseidon.rs index d686b049..43e7174d 100644 --- a/examples/poseidon.rs +++ b/examples/poseidon.rs @@ -1,7 +1,6 @@ use chiquito::{ frontend::dsl::{ - circuit_context_legacy::CircuitContextLegacy, lb::LookupTable, super_circuit, - trace::DSLTraceGenerator, + lb::LookupTable, super_circuit, trace::DSLTraceGenerator, CircuitContextLegacy, }, plonkish::{ backend::halo2_legacy::{chiquitoSuperCircuit2Halo2, ChiquitoHalo2SuperCircuit}, diff --git a/src/compiler/compiler_legacy.rs b/src/compiler/compiler_legacy.rs index 66bfc4d6..c7bed275 100644 --- a/src/compiler/compiler_legacy.rs +++ b/src/compiler/compiler_legacy.rs @@ -6,8 +6,7 @@ use crate::{ field::Field, frontend::dsl::{ cb::{Constraint, Typing}, - circuit_context_legacy::{circuit_legacy, CircuitContextLegacy}, - StepTypeContext, + circuit_legacy, CircuitContextLegacy, StepTypeContext, }, interpreter::InterpreterTraceGenerator, parser::{ diff --git a/src/frontend/dsl/circuit_context_legacy.rs b/src/frontend/dsl/circuit_context_legacy.rs deleted file mode 100644 index 082387c9..00000000 --- a/src/frontend/dsl/circuit_context_legacy.rs +++ /dev/null @@ -1,209 +0,0 @@ -use crate::{ - field::Field, - sbpir::{query::Queriable, ExposeOffset, SBPIRLegacy}, - wit_gen::{FixedGenContext, StepInstance, TraceGenerator}, -}; - -use halo2_proofs::plonk::{Advice, Column as Halo2Column, Fixed}; - -use core::{fmt::Debug, hash::Hash}; - -use super::{ - lb::{LookupTable, LookupTableRegistry, LookupTableStore}, - trace::{DSLTraceGenerator, TraceContext}, - StepTypeContext, StepTypeDefInput, StepTypeHandler, StepTypeWGHandler, -}; - -#[derive(Debug, Default)] -/// A generic structure designed to handle the context of a circuit. -/// The struct contains a `Circuit` instance and implements methods to build the circuit, -/// add various components, and manipulate the circuit. -/// -/// ### Type parameters -/// `F` is the field of the circuit. -/// `TG` is the trace generator. -/// -/// (LEGACY) -pub struct CircuitContextLegacy = DSLTraceGenerator> { - pub(crate) circuit: SBPIRLegacy, - pub(crate) tables: LookupTableRegistry, -} - -impl> CircuitContextLegacy { - /// Adds a forward signal to the circuit with a name string and zero rotation and returns a - /// `Queriable` instance representing the added forward signal. - pub fn forward(&mut self, name: &str) -> Queriable { - Queriable::Forward(self.circuit.add_forward(name, 0), false) - } - - /// Adds a forward signal to the circuit with a name string and a specified phase and returns a - /// `Queriable` instance representing the added forward signal. - pub fn forward_with_phase(&mut self, name: &str, phase: usize) -> Queriable { - Queriable::Forward(self.circuit.add_forward(name, phase), false) - } - - /// Adds a shared signal to the circuit with a name string and zero rotation and returns a - /// `Queriable` instance representing the added shared signal. - pub fn shared(&mut self, name: &str) -> Queriable { - Queriable::Shared(self.circuit.add_shared(name, 0), 0) - } - - /// Adds a shared signal to the circuit with a name string and a specified phase and returns a - /// `Queriable` instance representing the added shared signal. - pub fn shared_with_phase(&mut self, name: &str, phase: usize) -> Queriable { - Queriable::Shared(self.circuit.add_shared(name, phase), 0) - } - - pub fn fixed(&mut self, name: &str) -> Queriable { - Queriable::Fixed(self.circuit.add_fixed(name), 0) - } - - /// Exposes the first step instance value of a forward signal as public. - pub fn expose(&mut self, queriable: Queriable, offset: ExposeOffset) { - self.circuit.expose(queriable, offset); - } - - /// Imports a halo2 advice column with a name string into the circuit and returns a - /// `Queriable` instance representing the imported column. - pub fn import_halo2_advice(&mut self, name: &str, column: Halo2Column) -> Queriable { - Queriable::Halo2AdviceQuery(self.circuit.add_halo2_advice(name, column), 0) - } - - /// Imports a halo2 fixed column with a name string into the circuit and returns a - /// `Queriable` instance representing the imported column. - pub fn import_halo2_fixed(&mut self, name: &str, column: Halo2Column) -> Queriable { - Queriable::Halo2FixedQuery(self.circuit.add_halo2_fixed(name, column), 0) - } - - /// Adds a new step type with the specified name to the circuit and returns a - /// `StepTypeHandler` instance. The `StepTypeHandler` instance can be used to define the - /// step type using the `step_type_def` function. - pub fn step_type(&mut self, name: &str) -> StepTypeHandler { - let handler = StepTypeHandler::new(name.to_string()); - - self.circuit.add_step_type(handler, name); - - handler - } - - /// Defines a step type using the provided `StepTypeHandler` and a function that takes a - /// mutable reference to a `StepTypeContext`. This function typically adds constraints to a - /// step type and defines witness generation. - pub fn step_type_def, R>( - &mut self, - step: S, - def: D, - ) -> StepTypeWGHandler - where - D: FnOnce(&mut StepTypeContext) -> StepTypeWGHandler, - R: Fn(&mut StepInstance, Args) + 'static, - { - let handler: StepTypeHandler = match step.into() { - StepTypeDefInput::Handler(h) => h, - StepTypeDefInput::String(name) => { - let handler = StepTypeHandler::new(name.to_string()); - - self.circuit.add_step_type(handler, name); - - handler - } - }; - - let mut context = StepTypeContext::::new( - handler.uuid(), - handler.annotation.to_string(), - self.tables.clone(), - ); - - let result = def(&mut context); - - self.circuit.add_step_type_def(context.step_type); - - result - } - - pub fn new_table(&self, table: LookupTableStore) -> LookupTable { - let uuid = table.uuid(); - self.tables.add(table); - - LookupTable { uuid } - } - - /// Enforce the type of the first step by adding a constraint to the circuit. Takes a - /// `StepTypeHandler` parameter that represents the step type. - pub fn pragma_first_step>(&mut self, step_type: STH) { - self.circuit.first_step = Some(step_type.into().uuid()); - } - - /// Enforce the type of the last step by adding a constraint to the circuit. Takes a - /// `StepTypeHandler` parameter that represents the step type. - pub fn pragma_last_step>(&mut self, step_type: STH) { - self.circuit.last_step = Some(step_type.into().uuid()); - } - - /// Enforce the number of step instances by adding a constraint to the circuit. Takes a `usize` - /// parameter that represents the total number of steps. - pub fn pragma_num_steps(&mut self, num_steps: usize) { - self.circuit.num_steps = num_steps; - } - - pub fn pragma_disable_q_enable(&mut self) { - self.circuit.q_enable = false; - } -} - -impl CircuitContextLegacy> { - /// Sets the trace function that builds the witness. The trace function is responsible for - /// adding step instances defined in `step_type_def`. The function is entirely left for - /// the user to implement and is Turing complete. Users typically use external parameters - /// of type `TraceArgs` to generate cell values for witness generation, and call the - /// `add` function to add step instances with witness values. - pub fn trace(&mut self, def: D) - where - D: Fn(&mut TraceContext, TraceArgs) + 'static, - { - self.circuit.set_trace(def); - } -} - -impl> CircuitContextLegacy { - /// Executes the fixed generation function provided by the user and sets the fixed assignments - /// for the circuit. The fixed generation function is responsible for assigning fixed values to - /// fixed columns. It is entirely left for the user to implement and is Turing complete. Users - /// typically generate cell values and call the `assign` function to fill the fixed columns. - pub fn fixed_gen(&mut self, def: D) - where - D: Fn(&mut FixedGenContext) + 'static, - { - if self.circuit.num_steps == 0 { - panic!("circuit must call pragma_num_steps before calling fixed_gen"); - } - let mut ctx = FixedGenContext::new(self.circuit.num_steps); - (def)(&mut ctx); - - let assignments = ctx.get_assignments(); - - self.circuit.set_fixed_assignments(assignments); - } -} - -/// Creates a `Circuit` instance by providing a name and a definition closure that is applied to a -/// mutable `CircuitContext`. The user customizes the definition closure by calling `CircuitContext` -/// functions. This is the main function that users call to define a Chiquito circuit (legacy). -pub fn circuit_legacy( - _name: &str, - mut def: D, -) -> SBPIRLegacy> -where - D: FnMut(&mut CircuitContextLegacy>), -{ - // TODO annotate circuit - let mut context = CircuitContextLegacy { - circuit: SBPIRLegacy::default(), - tables: LookupTableRegistry::default(), - }; - - def(&mut context); - - context.circuit -} diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index 77f91933..b6ee4f96 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -1,7 +1,8 @@ use crate::{ - sbpir::{query::Queriable, StepType, StepTypeUUID, PIR}, + field::Field, + sbpir::{query::Queriable, ExposeOffset, SBPIRLegacy, StepType, StepTypeUUID, PIR}, util::{uuid, UUID}, - wit_gen::StepInstance, + wit_gen::{FixedGenContext, StepInstance, TraceGenerator}, }; use core::{fmt::Debug, hash::Hash}; @@ -12,14 +13,186 @@ use self::{ lb::{LookupBuilder, LookupTableRegistry}, }; +use halo2_proofs::plonk::{Advice, Column as Halo2Column, Fixed}; +use lb::{LookupTable, LookupTableStore}; pub use sc::*; +use trace::{DSLTraceGenerator, TraceContext}; pub mod cb; -pub mod circuit_context_legacy; pub mod lb; pub mod sc; pub mod trace; +#[derive(Debug, Default)] +/// A generic structure designed to handle the context of a circuit. +/// The struct contains a `Circuit` instance and implements methods to build the circuit, +/// add various components, and manipulate the circuit. +/// +/// ### Type parameters +/// `F` is the field of the circuit. +/// `TG` is the trace generator. +pub struct CircuitContextLegacy = DSLTraceGenerator> { + circuit: SBPIRLegacy, + tables: LookupTableRegistry, +} + +impl> CircuitContextLegacy { + /// Adds a forward signal to the circuit with a name string and zero rotation and returns a + /// `Queriable` instance representing the added forward signal. + pub fn forward(&mut self, name: &str) -> Queriable { + Queriable::Forward(self.circuit.add_forward(name, 0), false) + } + + /// Adds a forward signal to the circuit with a name string and a specified phase and returns a + /// `Queriable` instance representing the added forward signal. + pub fn forward_with_phase(&mut self, name: &str, phase: usize) -> Queriable { + Queriable::Forward(self.circuit.add_forward(name, phase), false) + } + + /// Adds a shared signal to the circuit with a name string and zero rotation and returns a + /// `Queriable` instance representing the added shared signal. + pub fn shared(&mut self, name: &str) -> Queriable { + Queriable::Shared(self.circuit.add_shared(name, 0), 0) + } + + /// Adds a shared signal to the circuit with a name string and a specified phase and returns a + /// `Queriable` instance representing the added shared signal. + pub fn shared_with_phase(&mut self, name: &str, phase: usize) -> Queriable { + Queriable::Shared(self.circuit.add_shared(name, phase), 0) + } + + pub fn fixed(&mut self, name: &str) -> Queriable { + Queriable::Fixed(self.circuit.add_fixed(name), 0) + } + + /// Exposes the first step instance value of a forward signal as public. + pub fn expose(&mut self, queriable: Queriable, offset: ExposeOffset) { + self.circuit.expose(queriable, offset); + } + + /// Imports a halo2 advice column with a name string into the circuit and returns a + /// `Queriable` instance representing the imported column. + pub fn import_halo2_advice(&mut self, name: &str, column: Halo2Column) -> Queriable { + Queriable::Halo2AdviceQuery(self.circuit.add_halo2_advice(name, column), 0) + } + + /// Imports a halo2 fixed column with a name string into the circuit and returns a + /// `Queriable` instance representing the imported column. + pub fn import_halo2_fixed(&mut self, name: &str, column: Halo2Column) -> Queriable { + Queriable::Halo2FixedQuery(self.circuit.add_halo2_fixed(name, column), 0) + } + + /// Adds a new step type with the specified name to the circuit and returns a + /// `StepTypeHandler` instance. The `StepTypeHandler` instance can be used to define the + /// step type using the `step_type_def` function. + pub fn step_type(&mut self, name: &str) -> StepTypeHandler { + let handler = StepTypeHandler::new(name.to_string()); + + self.circuit.add_step_type(handler, name); + + handler + } + + /// Defines a step type using the provided `StepTypeHandler` and a function that takes a + /// mutable reference to a `StepTypeContext`. This function typically adds constraints to a + /// step type and defines witness generation. + pub fn step_type_def, R>( + &mut self, + step: S, + def: D, + ) -> StepTypeWGHandler + where + D: FnOnce(&mut StepTypeContext) -> StepTypeWGHandler, + R: Fn(&mut StepInstance, Args) + 'static, + { + let handler: StepTypeHandler = match step.into() { + StepTypeDefInput::Handler(h) => h, + StepTypeDefInput::String(name) => { + let handler = StepTypeHandler::new(name.to_string()); + + self.circuit.add_step_type(handler, name); + + handler + } + }; + + let mut context = StepTypeContext::::new( + handler.uuid(), + handler.annotation.to_string(), + self.tables.clone(), + ); + + let result = def(&mut context); + + self.circuit.add_step_type_def(context.step_type); + + result + } + + pub fn new_table(&self, table: LookupTableStore) -> LookupTable { + let uuid = table.uuid(); + self.tables.add(table); + + LookupTable { uuid } + } + + /// Enforce the type of the first step by adding a constraint to the circuit. Takes a + /// `StepTypeHandler` parameter that represents the step type. + pub fn pragma_first_step>(&mut self, step_type: STH) { + self.circuit.first_step = Some(step_type.into().uuid()); + } + + /// Enforce the type of the last step by adding a constraint to the circuit. Takes a + /// `StepTypeHandler` parameter that represents the step type. + pub fn pragma_last_step>(&mut self, step_type: STH) { + self.circuit.last_step = Some(step_type.into().uuid()); + } + + /// Enforce the number of step instances by adding a constraint to the circuit. Takes a `usize` + /// parameter that represents the total number of steps. + pub fn pragma_num_steps(&mut self, num_steps: usize) { + self.circuit.num_steps = num_steps; + } + + pub fn pragma_disable_q_enable(&mut self) { + self.circuit.q_enable = false; + } +} + +impl CircuitContextLegacy> { + /// Sets the trace function that builds the witness. The trace function is responsible for + /// adding step instances defined in `step_type_def`. The function is entirely left for + /// the user to implement and is Turing complete. Users typically use external parameters + /// of type `TraceArgs` to generate cell values for witness generation, and call the + /// `add` function to add step instances with witness values. + pub fn trace(&mut self, def: D) + where + D: Fn(&mut TraceContext, TraceArgs) + 'static, + { + self.circuit.set_trace(def); + } +} + +impl> CircuitContextLegacy { + /// Executes the fixed generation function provided by the user and sets the fixed assignments + /// for the circuit. The fixed generation function is responsible for assigning fixed values to + /// fixed columns. It is entirely left for the user to implement and is Turing complete. Users + /// typically generate cell values and call the `assign` function to fill the fixed columns. + pub fn fixed_gen(&mut self, def: D) + where + D: Fn(&mut FixedGenContext) + 'static, + { + if self.circuit.num_steps == 0 { + panic!("circuit must call pragma_num_steps before calling fixed_gen"); + } + let mut ctx = FixedGenContext::new(self.circuit.num_steps); + (def)(&mut ctx); + + let assignments = ctx.get_assignments(); + + self.circuit.set_fixed_assignments(assignments); + } +} pub enum StepTypeDefInput { Handler(StepTypeHandler), String(&'static str), @@ -242,9 +415,32 @@ impl, Args) + 'static> StepTypeWGHandler( + _name: &str, + mut def: D, +) -> SBPIRLegacy> +where + D: FnMut(&mut CircuitContextLegacy>), +{ + // TODO annotate circuit + let mut context = CircuitContextLegacy { + circuit: SBPIRLegacy::default(), + tables: LookupTableRegistry::default(), + }; + + def(&mut context); + + context.circuit +} + #[cfg(test)] mod tests { - use circuit_context_legacy::CircuitContextLegacy; use halo2_proofs::halo2curves::bn256::Fr; use trace::DSLTraceGenerator; diff --git a/src/frontend/dsl/sc.rs b/src/frontend/dsl/sc.rs index b7cdf974..31275d45 100644 --- a/src/frontend/dsl/sc.rs +++ b/src/frontend/dsl/sc.rs @@ -17,9 +17,7 @@ use crate::{ wit_gen::{NullTraceGenerator, TraceGenerator}, }; -use super::{ - circuit_context_legacy::CircuitContextLegacy, lb::LookupTableRegistry, trace::DSLTraceGenerator, -}; +use super::{lb::LookupTableRegistry, trace::DSLTraceGenerator, CircuitContextLegacy}; pub struct SuperCircuitContext { super_circuit: SuperCircuit, @@ -150,7 +148,7 @@ mod tests { use halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}; use crate::{ - frontend::dsl::circuit_context_legacy::circuit_legacy, + frontend::dsl::circuit_legacy, plonkish::compiler::{ cell_manager::SingleRowCellManager, config, step_selector::SimpleStepSelectorBuilder, }, From e2c8bfe4b06020a60aad841809a4f136f4720777 Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Tue, 20 Aug 2024 18:28:25 +0800 Subject: [PATCH 7/9] Remove StepTypeHandler import in compiler --- src/compiler/compiler.rs | 9 +++------ src/sbpir/sbpir_machine.rs | 7 +++++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index 394adbda..9c7dd362 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -4,7 +4,6 @@ use num_bigint::BigInt; use crate::{ field::Field, - frontend::dsl::StepTypeHandler, interpreter::InterpreterTraceGenerator, parser::{ ast::{ @@ -38,9 +37,9 @@ pub struct CompilerResult { pub(super) struct Compiler { pub(super) config: Config, - messages: Vec, + pub(super) messages: Vec, - mapping: SymbolSignalMapping, + pub(super) mapping: SymbolSignalMapping, _p: PhantomData, } @@ -478,9 +477,7 @@ impl Compiler { let scope_name = format!("//{}", machine_name); let name = format!("{}:{}", scope_name, state_id); - let handler = StepTypeHandler::new(name.to_string()); - - machine.add_step_type(handler, name); + let handler = machine.add_step_type(name); self.mapping .step_type_handler diff --git a/src/sbpir/sbpir_machine.rs b/src/sbpir/sbpir_machine.rs index 3a4a6493..2b21104b 100644 --- a/src/sbpir/sbpir_machine.rs +++ b/src/sbpir/sbpir_machine.rs @@ -171,8 +171,11 @@ impl> SBPIRMachine { advice } - pub fn add_step_type>(&mut self, handler: StepTypeHandler, name: N) { - self.annotations.insert(handler.uuid(), name.into()); + pub fn add_step_type>(&mut self, name: N) -> StepTypeHandler { + let annotation = name.into(); + let handler = StepTypeHandler::new(annotation.clone()); + self.annotations.insert(handler.uuid(), annotation); + handler } pub fn add_step_type_def(&mut self, step: StepType) -> StepTypeUUID { From d02995fccebb0ea35a1bf3b61f0568da58f2d5ae Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Tue, 20 Aug 2024 18:55:00 +0800 Subject: [PATCH 8/9] Refactor `build` function --- src/compiler/compiler.rs | 64 ++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index 9c7dd362..ab8751bf 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -22,7 +22,7 @@ use crate::{ use super::{ semantic::{SymTable, SymbolCategory}, - setup_inter::{interpret, Setup}, + setup_inter::{interpret, MachineSetup, Setup}, Config, Message, Messages, }; @@ -260,22 +260,8 @@ impl Compiler { ); for state_id in machine_setup.states() { - let handler = self.mapping.get_step_type_handler(machine_name, state_id); - - let mut step_type = StepType::new(handler.uuid(), handler.annotation.to_string()); - - self.add_internal_signals(symbols, machine_name, &mut step_type, state_id); - - let poly_constraints = - self.translate_queries(symbols, setup, machine_name, state_id); - poly_constraints.iter().for_each(|poly| { - let constraint = Constraint { - annotation: format!("{:?}", poly), - expr: poly.clone(), - }; - - step_type.constraints.push(constraint) - }); + let step_type = + self.create_step_type(symbols, machine_name, machine_setup, state_id); sbpir_machine.add_step_type_def(step_type); } @@ -291,22 +277,26 @@ impl Compiler { todo!() } - fn translate_queries( + /// Translate the queries to constraints + fn queries_into_constraints( &mut self, symbols: &SymTable, - setup: &Setup, + setup: &MachineSetup, machine_name: &str, state_id: &str, - ) -> Vec, ()>> { - let exprs = setup - .get(machine_name) - .unwrap() - .get_poly_constraints(state_id) - .unwrap(); + ) -> Vec> { + let exprs = setup.get_poly_constraints(state_id).unwrap(); exprs .iter() - .map(|expr| self.translate_queries_expr(symbols, machine_name, state_id, expr)) + .map(|expr| { + let translate_queries_expr = + self.translate_queries_expr(symbols, machine_name, state_id, expr); + Constraint { + annotation: format!("{:?}", translate_queries_expr), + expr: translate_queries_expr.clone(), + } + }) .collect() } @@ -437,6 +427,28 @@ impl Compiler { .collect() } + fn create_step_type( + &mut self, + symbols: &SymTable, + machine_name: &str, + machine_setup: &MachineSetup, + state_id: &str, + ) -> StepType { + let handler = self.mapping.get_step_type_handler(machine_name, state_id); + + let mut step_type: StepType = + StepType::new(handler.uuid(), handler.annotation.to_string()); + + self.add_internal_signals(symbols, machine_name, &mut step_type, state_id); + + let poly_constraints = + self.queries_into_constraints(symbols, machine_setup, machine_name, state_id); + + step_type.constraints = poly_constraints.clone(); + + step_type + } + fn add_internal_signals( &mut self, symbols: &SymTable, From 05a9046a7f10f8a6bc4cf5e53c01abdcc0d146cd Mon Sep 17 00:00:00 2001 From: Alex Kuzmin Date: Thu, 22 Aug 2024 19:35:21 +0800 Subject: [PATCH 9/9] Restore unit tests --- src/compiler/compiler.rs | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/compiler/compiler.rs b/src/compiler/compiler.rs index ab8751bf..f4d80d20 100644 --- a/src/compiler/compiler.rs +++ b/src/compiler/compiler.rs @@ -562,7 +562,7 @@ mod test { use itertools::Itertools; use crate::{ - compiler::{compile, compile_legacy}, + compiler::{compile, compile_file, compile_legacy}, parser::ast::debug_sym_factory::DebugSymRefFactory, wit_gen::TraceGenerator, }; @@ -824,4 +824,24 @@ mod test { } } } + + #[test] + fn test_compiler_fibo_file() { + let path = "test/circuit.chiquito"; + let result = compile_file::(path, Config::default().max_degree(2)); + assert!(result.is_ok()); + } + + #[test] + fn test_compiler_fibo_file_err() { + let path = "test/circuit_error.chiquito"; + let result = compile_file::(path, Config::default().max_degree(2)); + + assert!(result.is_err()); + + assert_eq!( + format!("{:?}", result.unwrap_err()), + r#"[SemErr { msg: "use of undeclared variable c", dsym: test/circuit_error.chiquito:24:39 }, SemErr { msg: "use of undeclared variable c", dsym: test/circuit_error.chiquito:28:46 }]"# + ) + } }