Skip to content

Commit

Permalink
Process prover functions (#2422)
Browse files Browse the repository at this point in the history
Depends on #2417 and #2423
  • Loading branch information
chriseth authored Feb 4, 2025
1 parent 98ffa98 commit a8091a5
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 10 deletions.
29 changes: 28 additions & 1 deletion executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,12 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
.with_block_shape_check()
.with_block_size(self.block_size)
.with_requested_range_constraints((0..known_args.len()).map(Variable::Param))
.with_prover_functions(prover_functions)
.with_prover_functions(
prover_functions
.iter()
.flat_map(|f| (0..self.block_size).map(move |row| (f.clone(), row as i32)))
.collect_vec()
)
.generate_code(can_process, witgen)
.map_err(|e| {
let err_str = e.to_string_with_variable_formatter(|var| match var {
Expand Down Expand Up @@ -368,4 +373,26 @@ params[3] = main_binary::C[3];"
let input = read_to_string("../test_data/pil/poseidon_gl.pil").unwrap();
generate_for_block_machine(&input, "main_poseidon", 12, 4).unwrap();
}

#[test]
fn simple_prover_function() {
let input = "
namespace std::prover;
let compute_from: expr, int, expr[], (fe[] -> fe) -> () = query |dest_col, row, input_cols, f| {};
namespace Main(256);
col witness a, b;
[a, b] is [Sub.a, Sub.b];
namespace Sub(256);
col witness a, b;
(a - 20) * (b + 3) = 1;
query |i| std::prover::compute_from(b, i, [a], |values| 20);
";
let code = generate_for_block_machine(input, "Sub", 1, 1).unwrap().code;
assert_eq!(
format_code(&code),
"Sub::a[0] = params[0];
Sub::b[0] = prover_function_0(0, [Sub::a[0]]);
params[1] = Sub::b[0];"
);
}
}
21 changes: 20 additions & 1 deletion executor/src/witgen/jit/function_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use powdr_number::{FieldElement, KnownField};

use crate::witgen::{
data_structures::finalizable_data::{ColumnLayout, CompactDataRef},
jit::{effect::format_code, processor::ProcessorResult},
jit::{
effect::{format_code, Effect},
processor::ProcessorResult,
},
machines::{
profiling::{record_end, record_start},
LookupCell, MachineParts,
Expand Down Expand Up @@ -149,6 +152,22 @@ impl<'a, T: FieldElement> FunctionCache<'a, T> {
);
}

// TODO remove this once code generation for prover functions is working.
if code
.iter()
.flat_map(|e| -> Box<dyn Iterator<Item = &Effect<_, _>>> {
if let Effect::Branch(_, first, second) = e {
Box::new(first.iter().chain(second))
} else {
Box::new(std::iter::once(e))
}
})
.any(|e| matches!(e, Effect::ProverFunctionCall { .. }))
{
log::debug!("Inferred code contains call to prover function, which is not yet implemented. Using runtime solving instead.");
return None;
}

log::trace!("Generated code ({} steps)", code.len());
let known_inputs = cache_key
.known_args
Expand Down
42 changes: 35 additions & 7 deletions executor/src/witgen/jit/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub struct Processor<'a, T: FieldElement, FixedEval> {
identities: Vec<(&'a Identity<T>, i32)>,
/// The prover functions, i.e. helpers to compute certain values that
/// we cannot easily determine.
prover_functions: Vec<ProverFunction<'a>>,
prover_functions: Vec<(ProverFunction<'a>, i32)>,
/// The size of a block.
block_size: usize,
/// If the processor should check for correctly stackable block shapes.
Expand Down Expand Up @@ -97,7 +97,10 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
self
}

pub fn with_prover_functions(mut self, prover_functions: Vec<ProverFunction<'a>>) -> Self {
pub fn with_prover_functions(
mut self,
prover_functions: Vec<(ProverFunction<'a>, i32)>,
) -> Self {
assert!(self.prover_functions.is_empty());
self.prover_functions = prover_functions;
self
Expand Down Expand Up @@ -277,12 +280,37 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
witgen: &mut WitgenInference<'a, T, FixedEval>,
identity_queue: &mut IdentityQueue<'a, T>,
) -> Result<(), affine_symbolic_expression::Error> {
while let Some((identity, row_offset)) = identity_queue.next() {
let updated_vars =
witgen.process_identity(can_process.clone(), identity, row_offset)?;
identity_queue.variables_updated(updated_vars, Some((identity, row_offset)));
loop {
let identity = identity_queue.next();
let updated_vars = match identity {
Some((identity, row_offset)) => {
witgen.process_identity(can_process.clone(), identity, row_offset)
}
None => self.process_prover_functions(witgen),
}?;
if updated_vars.is_empty() && identity.is_none() {
// No identities to process and prover functions did not make any progress,
// we are done.
return Ok(());
}
identity_queue.variables_updated(updated_vars, identity);
}
Ok(())
}

/// Tries to process all prover functions until the first one is able to update a variable.
/// Returns the updated variables.
fn process_prover_functions(
&self,
witgen: &mut WitgenInference<'a, T, FixedEval>,
) -> Result<Vec<Variable>, affine_symbolic_expression::Error> {
for (prover_function, row_offset) in &self.prover_functions {
let updated_vars = witgen.process_prover_function(prover_function, *row_offset)?;
if !updated_vars.is_empty() {
return Ok(updated_vars);
}
}

Ok(vec![])
}

/// If any machine call could not be completed, that's bad because machine calls typically have side effects.
Expand Down
2 changes: 2 additions & 0 deletions executor/src/witgen/jit/prover_function_heuristics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub trait TryColumnByName: Copy {
}

#[allow(unused)]
#[derive(Clone)]
pub struct ProverFunction<'a> {
pub index: usize,
pub target_column: AlgebraicReference,
Expand All @@ -22,6 +23,7 @@ pub struct ProverFunction<'a> {
}

#[allow(unused)]
#[derive(Clone)]
pub enum ProverFunctionComputation<'a> {
/// The expression `f` in `query |i| std::prover::provide_if_unknown(Y, i, f)`,
/// where f: (-> fe)
Expand Down
6 changes: 5 additions & 1 deletion executor/src/witgen/jit/single_step_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ impl<'a, T: FieldElement> SingleStepProcessor<'a, T> {
let witgen =
WitgenInference::new(self.fixed_data, self, known_variables, complete_identities);

let prover_functions = decode_prover_functions(&self.machine_parts, self.fixed_data)?;
let prover_functions = decode_prover_functions(&self.machine_parts, self.fixed_data)?
.into_iter()
// Process prover functions only on the next row.
.map(|f| (f, 1))
.collect_vec();

Processor::new(
self.fixed_data,
Expand Down
36 changes: 36 additions & 0 deletions executor/src/witgen/jit/witgen_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::witgen::{
use super::{
affine_symbolic_expression::{AffineSymbolicExpression, Error, ProcessResult},
effect::{BranchCondition, Effect, ProverFunctionCall},
prover_function_heuristics::ProverFunction,
symbolic_expression::SymbolicExpression,
variable::{Cell, MachineCallVariable, Variable},
};
Expand Down Expand Up @@ -196,6 +197,41 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
self.ingest_effects(result, Some((id.id(), row_offset)))
}

/// Process a prover function on a row, i.e. determine if we can execute it and if it will
/// help us to compute the value of a previously unknown variable.
/// Returns the list of updated variables.
pub fn process_prover_function(
&mut self,
prover_function: &ProverFunction<'a>,
row_offset: i32,
) -> Result<Vec<Variable>, Error> {
let target = Variable::from_reference(&prover_function.target_column, row_offset);
if !self.is_known(&target) {
let inputs = prover_function
.input_columns
.iter()
.map(|c| Variable::from_reference(c, row_offset))
.collect::<Vec<_>>();
if inputs.iter().all(|v| self.is_known(v)) {
let effect = Effect::ProverFunctionCall(ProverFunctionCall {
target,
function_index: prover_function.index,
row_offset,
inputs,
});
return self.ingest_effects(
ProcessResult {
effects: vec![effect],
complete: true,
},
None,
);
}
}

Ok(vec![])
}

/// Process the constraint that the expression evaluated at the given offset equals the given value.
/// This does not have to be solvable right away, but is always processed as soon as we have progress.
/// Note that all variables in the expression can be unknown and their status can also change over time.
Expand Down

0 comments on commit a8091a5

Please sign in to comment.