Skip to content

Commit

Permalink
Try zero, but only for simple send params. (#2475)
Browse files Browse the repository at this point in the history
  • Loading branch information
chriseth authored Feb 13, 2025
1 parent 27ce67a commit fb5bbd2
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 16 deletions.
30 changes: 30 additions & 0 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,4 +464,34 @@ machine_call(1, [Known(call_var(1, 1, 0))]);
SubM::a[1] = ((SubM::b[1] * 256) + SubM::c[1]);"
);
}

#[test]
fn unused_fixed_lookup() {
// Checks that irrelevant fixed lookups are still performed
// in the generated code.
let input = "
namespace Main(256);
col witness a, b, c;
[a, b, c] is [S.a, S.b, S.c];
namespace S(256);
col witness a, b, c, x, y;
let B: col = |i| i & 0xff;
a * (a - 1) = 0;
[ a * x ] in [ B ];
[ (a - 1) * y ] in [ B ];
a + b = c;
";
let code = format_code(&generate_for_block_machine(input, "S", 2, 1).unwrap().code);
assert_eq!(
code,
"S::a[0] = params[0];
S::b[0] = params[1];
S::c[0] = (S::a[0] + S::b[0]);
params[2] = S::c[0];
call_var(2, 0, 0) = 0;
call_var(3, 0, 0) = 0;
machine_call(2, [Known(call_var(2, 0, 0))]);
machine_call(3, [Known(call_var(3, 0, 0))]);"
);
}
}
106 changes: 90 additions & 16 deletions executor/src/witgen/jit/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use std::{
};

use itertools::Itertools;
use powdr_ast::analyzed::{PolyID, PolynomialIdentity, PolynomialType};
use powdr_ast::analyzed::{
AlgebraicExpression as Expression, PolyID, PolynomialIdentity, PolynomialType,
};
use powdr_number::FieldElement;

use crate::witgen::{
Expand Down Expand Up @@ -108,18 +110,9 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
match &id {
Identity::BusSend(bus_send) => {
// Create variable assignments for the arguments of bus send identities.
let arguments = &bus_send.selected_payload.expressions;
arguments
.iter()
.enumerate()
.map(move |(index, arg)| {
let var = Variable::MachineCallParam(MachineCallVariable {
identity_id: bus_send.identity_id,
row_offset: *row_offset,
index,
});
QueueItem::variable_assignment(arg, var, *row_offset)
})
machine_call_params(bus_send, *row_offset)
.zip(&bus_send.selected_payload.expressions)
.map(|(var, arg)| QueueItem::variable_assignment(arg, var, *row_offset))
.chain(std::iter::once(QueueItem::Identity(id, *row_offset)))
.collect_vec()
}
Expand All @@ -141,7 +134,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
branch_depth: usize,
) -> Result<ProcessorResult<T>, Error<'a, T, FixedEval>> {
if self
.process_until_no_progress(can_process.clone(), &mut witgen, &mut identity_queue)
.process_until_no_progress(can_process.clone(), &mut witgen, identity_queue.clone())
.is_err()
{
return Err(Error::conflicting_constraints(
Expand Down Expand Up @@ -171,7 +164,14 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
.collect_vec();

let incomplete_machine_calls = self.incomplete_machine_calls(&witgen);
if missing_variables.is_empty() && incomplete_machine_calls.is_empty() {
if missing_variables.is_empty()
&& self.try_fix_simple_sends(
&incomplete_machine_calls,
can_process.clone(),
&mut witgen,
identity_queue.clone(),
)
{
let range_constraints = self
.requested_range_constraints
.iter()
Expand Down Expand Up @@ -199,6 +199,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
.as_ref()
.map(|(_, rc)| (rc.range_width() >> self.max_branch_depth) > 0.into())
.unwrap_or(true);

if branch_depth >= self.max_branch_depth || no_viable_branch_variable {
let reason = if no_viable_branch_variable {
ErrorReason::NoBranchVariable
Expand Down Expand Up @@ -289,7 +290,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
&self,
can_process: impl CanProcessCall<T>,
witgen: &mut WitgenInference<'a, T, FixedEval>,
identity_queue: &mut IdentityQueue<'a, T>,
mut identity_queue: IdentityQueue<'a, T>,
) -> Result<(), affine_symbolic_expression::Error> {
while let Some(item) = identity_queue.next() {
let updated_vars = match &item {
Expand Down Expand Up @@ -455,6 +456,66 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> Processor<'a, T, FixedEv
}
}
}

/// If the only missing sends all only have a single argument, try to set those arguments
/// to zero.
fn try_fix_simple_sends(
&self,
incomplete_machine_calls: &[(&Identity<T>, i32)],
can_process: impl CanProcessCall<T>,
witgen: &mut WitgenInference<'a, T, FixedEval>,
mut identity_queue: IdentityQueue<'a, T>,
) -> bool {
let missing_sends_in_block = incomplete_machine_calls
.iter()
.filter(|(_, row)| 0 <= *row && *row < self.block_size as i32)
.map(|(id, row)| match id {
Identity::BusSend(bus_send) => (bus_send, *row),
_ => unreachable!(),
})
.collect_vec();
// If the send has more than one parameter, we do not want to touch it.
// Same if we do not know that the selector is 1.
if missing_sends_in_block.iter().any(|(bus_send, row)| {
bus_send.selected_payload.expressions.len() > 1
|| !witgen
.evaluate(&bus_send.selected_payload.selector, *row)
.and_then(|v| v.try_to_known().map(|v| v.is_known_one()))
.unwrap_or(false)
}) {
return false;
}
// Create a copy in case we fail.
let mut modified_witgen = witgen.clone();
// Now set all parameters to zero.
for (bus_send, row) in missing_sends_in_block {
let [param] = &machine_call_params(bus_send, row).collect_vec()[..] else {
unreachable!()
};
assert!(!witgen.is_known(param));
match modified_witgen.process_equation_on_row(
&Expression::Number(T::from(0)),
Some(param.clone()),
0.into(),
row,
) {
Err(_) => return false,
Ok(updated_vars) => {
identity_queue.variables_updated(updated_vars, None);
}
};
}
if self
.process_until_no_progress(can_process, &mut modified_witgen, identity_queue)
.is_ok()
&& self.incomplete_machine_calls(&modified_witgen).is_empty()
{
*witgen = modified_witgen;
true
} else {
false
}
}
}

fn is_machine_call<T>(identity: &Identity<T>) -> bool {
Expand All @@ -472,6 +533,19 @@ fn combine_range_constraints<T: FieldElement>(
.collect()
}

fn machine_call_params<T: FieldElement>(
bus_send: &BusSend<T>,
row_offset: i32,
) -> impl Iterator<Item = Variable> + '_ {
(0..bus_send.selected_payload.expressions.len()).map(move |index| {
Variable::MachineCallParam(MachineCallVariable {
identity_id: bus_send.identity_id,
row_offset,
index,
})
})
}

pub struct Error<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> {
pub reason: ErrorReason,
pub witgen: WitgenInference<'a, T, FixedEval>,
Expand Down

0 comments on commit fb5bbd2

Please sign in to comment.