diff --git a/executor/src/witgen/jit/block_machine_processor.rs b/executor/src/witgen/jit/block_machine_processor.rs index ba307e93d2..0d447a1037 100644 --- a/executor/src/witgen/jit/block_machine_processor.rs +++ b/executor/src/witgen/jit/block_machine_processor.rs @@ -13,6 +13,7 @@ use crate::witgen::{ use super::{ processor::ProcessorResult, + prover_function_heuristics::ProverFunction, variable::{Cell, Variable}, witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference}, }; @@ -50,7 +51,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { can_process: impl CanProcessCall, identity_id: u64, known_args: &BitVec, - ) -> Result, String> { + ) -> Result<(ProcessorResult, Vec>), String> { let connection = self.machine_parts.connections[&identity_id]; assert_eq!(connection.right.expressions.len(), known_args.len()); @@ -119,7 +120,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { .iter() .enumerate() .filter_map(|(i, is_input)| (!is_input).then_some(Variable::Param(i))); - Processor::new( + let result = Processor::new( self.fixed_data, self, identities, @@ -154,7 +155,8 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> { .take(10) .format("\n "); format!("Code generation failed: {shortened_error}\nRun with RUST_LOG=trace to see the code generated so far.") - }) + })?; + Ok((result, prover_functions)) } } @@ -243,7 +245,9 @@ mod test { .chain((0..num_outputs).map(|_| false)), ); - processor.generate_code(&mutable_state, connection_id, &known_values) + processor + .generate_code(&mutable_state, connection_id, &known_values) + .map(|(result, _)| result) } #[test] diff --git a/executor/src/witgen/jit/compiler.rs b/executor/src/witgen/jit/compiler.rs index ca90e7b23c..0e29b2a785 100644 --- a/executor/src/witgen/jit/compiler.rs +++ b/executor/src/witgen/jit/compiler.rs @@ -6,7 +6,7 @@ use powdr_ast::{ analyzed::{PolyID, PolynomialType}, indent, }; -use powdr_jit_compiler::util_code::util_code; +use powdr_jit_compiler::{util_code::util_code, CodeGenerator, DefinitionFetcher}; use powdr_number::FieldElement; use crate::witgen::{ @@ -14,6 +14,7 @@ use crate::witgen::{ finalizable_data::{ColumnLayout, CompactDataRef}, mutable_state::MutableState, }, + jit::prover_function_heuristics::ProverFunctionComputation, machines::{ profiling::{record_end, record_start}, LookupCell, @@ -23,6 +24,7 @@ use crate::witgen::{ use super::{ effect::{Assertion, BranchCondition, Effect, ProverFunctionCall}, + prover_function_heuristics::ProverFunction, symbolic_expression::{BinaryOperator, BitOperator, SymbolicExpression, UnaryOperator}, variable::Variable, }; @@ -87,17 +89,30 @@ extern "C" fn call_machine>( } /// Compile the given inferred effects into machine code and load it. -pub fn compile_effects( +pub fn compile_effects( + definitions: &D, column_layout: ColumnLayout, known_inputs: &[Variable], effects: &[Effect], + prover_functions: Vec>, ) -> Result, String> { let utils = util_code::()?; let interface = interface_code(column_layout); + let mut codegen = CodeGenerator::::new(definitions); + let prover_functions = prover_functions + .iter() + .map(|f| prover_function_code(f, &mut codegen)) + .collect::, _>>()? + .into_iter() + .format("\n"); + let prover_functions_dependents = codegen.generated_code(); let witgen_code = witgen_code(known_inputs, effects); let code = format!( "{utils}\n\ //-------------------------------\n\ + {prover_functions_dependents}\n\ + {prover_functions}\n\ + //-------------------------------\n\ {interface}\n\ //-------------------------------\n\ {witgen_code}" @@ -522,11 +537,35 @@ fn interface_code(column_layout: ColumnLayout) -> String { ) } +fn prover_function_code( + f: &ProverFunction<'_>, + codegen: &mut CodeGenerator<'_, T, D>, +) -> Result { + let code = match f.computation { + ProverFunctionComputation::ComputeFrom(code) => format!( + "({}).call(args.to_vec().into())", + codegen.generate_code_for_expression(code)? + ), + ProverFunctionComputation::ProvideIfUnknown(code) => { + format!("({}).call()", codegen.generate_code_for_expression(code)?) + } + }; + + let index = f.index; + Ok(format!( + "fn prover_function_{index}(i: u64, args: &[FieldElement]) -> FieldElement {{\n\ + let i: ibig::IBig = i.into();\n\ + {code} + }}" + )) +} + #[cfg(test)] mod tests { use std::ptr::null; + use powdr_ast::analyzed::FunctionValueDefinition; use pretty_assertions::assert_eq; use test_log::test; @@ -538,18 +577,27 @@ mod tests { use super::*; + struct NoDefinitions; + impl DefinitionFetcher for NoDefinitions { + fn get_definition(&self, _: &str) -> Option<&FunctionValueDefinition> { + None + } + } + fn compile_effects( column_count: usize, known_inputs: &[Variable], effects: &[Effect], ) -> Result, String> { super::compile_effects( + &NoDefinitions, ColumnLayout { column_count, first_column_id: 0, }, known_inputs, effects, + vec![], ) } diff --git a/executor/src/witgen/jit/function_cache.rs b/executor/src/witgen/jit/function_cache.rs index 460c9ca1ca..eefb1d71fd 100644 --- a/executor/src/witgen/jit/function_cache.rs +++ b/executor/src/witgen/jit/function_cache.rs @@ -117,10 +117,13 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> { cache_key.known_args ); - let ProcessorResult { - code, - range_constraints, - } = self + let ( + ProcessorResult { + code, + range_constraints, + }, + prover_functions, + ) = self .processor .generate_code(can_process, cache_key.identity_id, &cache_key.known_args) .map_err(|e| { @@ -177,7 +180,14 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> { .collect::>(); log::trace!("Compiling effects..."); - let function = compile_effects(self.column_layout.clone(), &known_inputs, &code).unwrap(); + let function = compile_effects( + self.fixed_data.analyzed, + self.column_layout.clone(), + &known_inputs, + &code, + prover_functions, + ) + .unwrap(); log::trace!("Compilation done."); Some(CacheEntry { diff --git a/executor/src/witgen/jit/interpreter.rs b/executor/src/witgen/jit/interpreter.rs index fd26936c9e..58a83f8b83 100644 --- a/executor/src/witgen/jit/interpreter.rs +++ b/executor/src/witgen/jit/interpreter.rs @@ -533,15 +533,15 @@ mod test { .chain((0..num_outputs).map(|_| false)), ); - let effects = processor + // TODO we cannot compile the prover functions here, but we can evaluate them. + let (result, _prover_functions) = processor .generate_code(&mutable_state, connection_id, &known_values) - .unwrap() - .code; + .unwrap(); let known_inputs = (0..12).map(Variable::Param).collect::>(); // generate interpreter - let interpreter = EffectsInterpreter::new(&known_inputs, &effects); + let interpreter = EffectsInterpreter::new(&known_inputs, &result.code); // call it let mut params = [GoldilocksField::default(); 16]; let mut param_lookups = params diff --git a/jit-compiler/src/codegen.rs b/jit-compiler/src/codegen.rs index 7253d98963..cdb6db84d3 100644 --- a/jit-compiler/src/codegen.rs +++ b/jit-compiler/src/codegen.rs @@ -80,6 +80,13 @@ impl<'a, T: FieldElement, Def: DefinitionFetcher> CodeGenerator<'a, T, Def> { Ok(self.symbol_reference(name, type_args)) } + /// Generates code for an isolated expression. This might request code generation + /// for referenced symbols, this the returned code is only valid code in connection with + /// the code returned by `self.generated_code`. + pub fn generate_code_for_expression(&mut self, e: &Expression) -> Result { + self.format_expr(e, 0) + } + /// Returns the concatenation of all successfully compiled symbols. pub fn generated_code(self) -> String { self.symbols diff --git a/jit-compiler/src/lib.rs b/jit-compiler/src/lib.rs index a97d778c84..2719356688 100644 --- a/jit-compiler/src/lib.rs +++ b/jit-compiler/src/lib.rs @@ -8,13 +8,13 @@ use std::{ sync::Arc, }; -use codegen::CodeGenerator; use compiler::{generate_glue_code, load_library}; use itertools::Itertools; use powdr_ast::analyzed::Analyzed; use powdr_number::FieldElement; +pub use codegen::{CodeGenerator, DefinitionFetcher}; pub use compiler::call_cargo; pub struct CompiledPIL {