Skip to content

Commit

Permalink
Compile prover fuctions as well (#2432)
Browse files Browse the repository at this point in the history
  • Loading branch information
chriseth authored Feb 5, 2025
1 parent 718f802 commit a79ff6d
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 16 deletions.
12 changes: 8 additions & 4 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::witgen::{

use super::{
processor::ProcessorResult,
prover_function_heuristics::ProverFunction,
variable::{Cell, Variable},
witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference},
};
Expand Down Expand Up @@ -50,7 +51,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
can_process: impl CanProcessCall<T>,
identity_id: u64,
known_args: &BitVec,
) -> Result<ProcessorResult<T>, String> {
) -> Result<(ProcessorResult<T>, Vec<ProverFunction<'a>>), String> {
let connection = self.machine_parts.connections[&identity_id];
assert_eq!(connection.right.expressions.len(), known_args.len());

Expand Down Expand Up @@ -119,7 +120,7 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
.iter()
.enumerate()
.filter_map(|(i, is_input)| (!is_input).then_some(Variable::Param(i)));
Processor::new(
let result = Processor::new(
self.fixed_data,
self,
identities,
Expand Down Expand Up @@ -154,7 +155,8 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
.take(10)
.format("\n ");
format!("Code generation failed: {shortened_error}\nRun with RUST_LOG=trace to see the code generated so far.")
})
})?;
Ok((result, prover_functions))
}
}

Expand Down Expand Up @@ -243,7 +245,9 @@ mod test {
.chain((0..num_outputs).map(|_| false)),
);

processor.generate_code(&mutable_state, connection_id, &known_values)
processor
.generate_code(&mutable_state, connection_id, &known_values)
.map(|(result, _)| result)
}

#[test]
Expand Down
52 changes: 50 additions & 2 deletions executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ use powdr_ast::{
analyzed::{PolyID, PolynomialType},
indent,
};
use powdr_jit_compiler::util_code::util_code;
use powdr_jit_compiler::{util_code::util_code, CodeGenerator, DefinitionFetcher};
use powdr_number::FieldElement;

use crate::witgen::{
data_structures::{
finalizable_data::{ColumnLayout, CompactDataRef},
mutable_state::MutableState,
},
jit::prover_function_heuristics::ProverFunctionComputation,
machines::{
profiling::{record_end, record_start},
LookupCell,
Expand All @@ -23,6 +24,7 @@ use crate::witgen::{

use super::{
effect::{Assertion, BranchCondition, Effect, ProverFunctionCall},
prover_function_heuristics::ProverFunction,
symbolic_expression::{BinaryOperator, BitOperator, SymbolicExpression, UnaryOperator},
variable::Variable,
};
Expand Down Expand Up @@ -87,17 +89,30 @@ extern "C" fn call_machine<T: FieldElement, Q: QueryCallback<T>>(
}

/// Compile the given inferred effects into machine code and load it.
pub fn compile_effects<T: FieldElement>(
pub fn compile_effects<T: FieldElement, D: DefinitionFetcher>(
definitions: &D,
column_layout: ColumnLayout,
known_inputs: &[Variable],
effects: &[Effect<T, Variable>],
prover_functions: Vec<ProverFunction<'_>>,
) -> Result<WitgenFunction<T>, String> {
let utils = util_code::<T>()?;
let interface = interface_code(column_layout);
let mut codegen = CodeGenerator::<T, _>::new(definitions);
let prover_functions = prover_functions
.iter()
.map(|f| prover_function_code(f, &mut codegen))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.format("\n");
let prover_functions_dependents = codegen.generated_code();
let witgen_code = witgen_code(known_inputs, effects);
let code = format!(
"{utils}\n\
//-------------------------------\n\
{prover_functions_dependents}\n\
{prover_functions}\n\
//-------------------------------\n\
{interface}\n\
//-------------------------------\n\
{witgen_code}"
Expand Down Expand Up @@ -522,11 +537,35 @@ fn interface_code(column_layout: ColumnLayout) -> String {
)
}

fn prover_function_code<T: FieldElement, D: DefinitionFetcher>(
f: &ProverFunction<'_>,
codegen: &mut CodeGenerator<'_, T, D>,
) -> Result<String, String> {
let code = match f.computation {
ProverFunctionComputation::ComputeFrom(code) => format!(
"({}).call(args.to_vec().into())",
codegen.generate_code_for_expression(code)?
),
ProverFunctionComputation::ProvideIfUnknown(code) => {
format!("({}).call()", codegen.generate_code_for_expression(code)?)
}
};

let index = f.index;
Ok(format!(
"fn prover_function_{index}(i: u64, args: &[FieldElement]) -> FieldElement {{\n\
let i: ibig::IBig = i.into();\n\
{code}
}}"
))
}

#[cfg(test)]
mod tests {

use std::ptr::null;

use powdr_ast::analyzed::FunctionValueDefinition;
use pretty_assertions::assert_eq;
use test_log::test;

Expand All @@ -538,18 +577,27 @@ mod tests {

use super::*;

struct NoDefinitions;
impl DefinitionFetcher for NoDefinitions {
fn get_definition(&self, _: &str) -> Option<&FunctionValueDefinition> {
None
}
}

fn compile_effects(
column_count: usize,
known_inputs: &[Variable],
effects: &[Effect<GoldilocksField, Variable>],
) -> Result<WitgenFunction<GoldilocksField>, String> {
super::compile_effects(
&NoDefinitions,
ColumnLayout {
column_count,
first_column_id: 0,
},
known_inputs,
effects,
vec![],
)
}

Expand Down
20 changes: 15 additions & 5 deletions executor/src/witgen/jit/function_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,13 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
cache_key.known_args
);

let ProcessorResult {
code,
range_constraints,
} = self
let (
ProcessorResult {
code,
range_constraints,
},
prover_functions,
) = self
.processor
.generate_code(can_process, cache_key.identity_id, &cache_key.known_args)
.map_err(|e| {
Expand Down Expand Up @@ -177,7 +180,14 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
.collect::<Vec<_>>();

log::trace!("Compiling effects...");
let function = compile_effects(self.column_layout.clone(), &known_inputs, &code).unwrap();
let function = compile_effects(
self.fixed_data.analyzed,
self.column_layout.clone(),
&known_inputs,
&code,
prover_functions,
)
.unwrap();
log::trace!("Compilation done.");

Some(CacheEntry {
Expand Down
8 changes: 4 additions & 4 deletions executor/src/witgen/jit/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,15 +533,15 @@ mod test {
.chain((0..num_outputs).map(|_| false)),
);

let effects = processor
// TODO we cannot compile the prover functions here, but we can evaluate them.
let (result, _prover_functions) = processor
.generate_code(&mutable_state, connection_id, &known_values)
.unwrap()
.code;
.unwrap();

let known_inputs = (0..12).map(Variable::Param).collect::<Vec<_>>();

// generate interpreter
let interpreter = EffectsInterpreter::new(&known_inputs, &effects);
let interpreter = EffectsInterpreter::new(&known_inputs, &result.code);
// call it
let mut params = [GoldilocksField::default(); 16];
let mut param_lookups = params
Expand Down
7 changes: 7 additions & 0 deletions jit-compiler/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ impl<'a, T: FieldElement, Def: DefinitionFetcher> CodeGenerator<'a, T, Def> {
Ok(self.symbol_reference(name, type_args))
}

/// Generates code for an isolated expression. This might request code generation
/// for referenced symbols, this the returned code is only valid code in connection with
/// the code returned by `self.generated_code`.
pub fn generate_code_for_expression(&mut self, e: &Expression) -> Result<String, String> {
self.format_expr(e, 0)
}

/// Returns the concatenation of all successfully compiled symbols.
pub fn generated_code(self) -> String {
self.symbols
Expand Down
2 changes: 1 addition & 1 deletion jit-compiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ use std::{
sync::Arc,
};

use codegen::CodeGenerator;
use compiler::{generate_glue_code, load_library};

use itertools::Itertools;
use powdr_ast::analyzed::Analyzed;
use powdr_number::FieldElement;

pub use codegen::{CodeGenerator, DefinitionFetcher};
pub use compiler::call_cargo;

pub struct CompiledPIL {
Expand Down

0 comments on commit a79ff6d

Please sign in to comment.