Skip to content

Commit

Permalink
Split assignments (#2463)
Browse files Browse the repository at this point in the history
Split the `Assignment` struct into two structs and eliminate the
`VariableOrValue` enum.

This simplifies some interfaces or at least makes them more orthogonal.

The QueueItem is turned into a public data structure in turn.
  • Loading branch information
chriseth authored Feb 11, 2025
1 parent fd973e9 commit 4683e58
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 256 deletions.
10 changes: 10 additions & 0 deletions executor/src/witgen/jit/affine_symbolic_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ pub struct AffineSymbolicExpression<T: FieldElement, V> {
range_constraints: BTreeMap<V, RangeConstraint<T>>,
}

impl<T: FieldElement, V> Default for AffineSymbolicExpression<T, V> {
fn default() -> Self {
Self {
coefficients: Default::default(),
offset: T::zero().into(),
range_constraints: Default::default(),
}
}
}

/// Display for affine symbolic expressions, for informational purposes only.
impl<T: FieldElement, V: Display> Display for AffineSymbolicExpression<T, V> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
Expand Down
40 changes: 20 additions & 20 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use powdr_number::FieldElement;

use crate::witgen::{
jit::{
processor::Processor, prover_function_heuristics::decode_prover_functions,
witgen_inference::Assignment,
identity_queue::QueueItem, processor::Processor,
prover_function_heuristics::decode_prover_functions,
},
machines::MachineParts,
FixedData,
Expand Down Expand Up @@ -69,42 +69,43 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {

let prover_functions = decode_prover_functions(&self.machine_parts, self.fixed_data)?;

let mut queue_items = vec![];

// In the latch row, set the RHS selector to 1.
let mut assignments = vec![];
let selector = &connection.right.selector;
assignments.push(Assignment::assign_constant(
queue_items.push(QueueItem::constant_assignment(
selector,
self.latch_row as i32,
T::one(),
self.latch_row as i32,
));

if let Some((index, value)) = known_concrete {
// Set the known argument to the concrete value.
assignments.push(Assignment::assign_constant(
queue_items.push(QueueItem::constant_assignment(
&connection.right.expressions[index],
self.latch_row as i32,
value,
self.latch_row as i32,
));
}

// Set all other selectors to 0 in the latch row.
for other_connection in self.machine_parts.connections.values() {
let other_selector = &other_connection.right.selector;
if other_selector != selector {
assignments.push(Assignment::assign_constant(
queue_items.push(QueueItem::constant_assignment(
other_selector,
self.latch_row as i32,
T::zero(),
self.latch_row as i32,
));
}
}

// For each argument, connect the expression on the RHS with the formal parameter.
for (index, expr) in connection.right.expressions.iter().enumerate() {
assignments.push(Assignment::assign_variable(
queue_items.push(QueueItem::variable_assignment(
expr,
self.latch_row as i32,
Variable::Param(index),
self.latch_row as i32,
));
}

Expand Down Expand Up @@ -142,6 +143,11 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
.collect_vec()
});

// Add the prover functions
queue_items.extend(prover_functions.iter().flat_map(|f| {
(0..self.block_size).map(move |row| QueueItem::ProverFunction(f.clone(), row as i32))
}));

let requested_known = known_args
.iter()
.enumerate()
Expand All @@ -150,19 +156,13 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
self.fixed_data,
self,
identities,
assignments,
queue_items,
requested_known,
BLOCK_MACHINE_MAX_BRANCH_DEPTH,
)
.with_block_shape_check()
.with_block_size(self.block_size)
.with_requested_range_constraints((0..known_args.len()).map(Variable::Param))
.with_prover_functions(
prover_functions
.iter()
.flat_map(|f| (0..self.block_size).map(move |row| (f.clone(), row as i32)))
.collect_vec()
)
.generate_code(can_process, witgen)
.map_err(|e| {
let err_str = e.to_string_with_variable_formatter(|var| match var {
Expand Down Expand Up @@ -340,10 +340,10 @@ params[2] = Add::c[0];"
assert_eq!(c_rc, &RangeConstraint::from_mask(0xffffffffu64));
assert_eq!(
format_code(&result.code),
"main_binary::operation_id[3] = params[0];
"main_binary::sel[0][3] = 1;
main_binary::operation_id[3] = params[0];
main_binary::A[3] = params[1];
main_binary::B[3] = params[2];
main_binary::sel[0][3] = 1;
main_binary::operation_id[2] = main_binary::operation_id[3];
main_binary::operation_id[1] = main_binary::operation_id[2];
main_binary::operation_id[0] = main_binary::operation_id[1];
Expand Down
39 changes: 32 additions & 7 deletions executor/src/witgen/jit/debug_formatter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,39 @@ impl<T: FieldElement, FixedEval: FixedEvaluator<T>> DebugFormatter<'_, T, FixedE
fn format_identities(&self) -> String {
self.identities
.iter()
.filter(|(id, row)| !self.witgen.is_complete(id, *row))
.sorted_by_key(|(id, row)| (row, id.id()))
.map(|(id, row)| {
format!(
"--------------[ identity {} on row {row}: ]--------------\n{}",
id.id(),
self.format_identity(id, *row)
)
.flat_map(|(id, row)| {
let (skip, conflicting) = match &id {
Identity::BusSend(..) => (self.witgen.is_complete_call(id, *row), false),
Identity::Polynomial(PolynomialIdentity { expression, .. }) => {
let value = self
.witgen
.evaluate(expression, *row)
.and_then(|v| v.try_to_known().cloned());
let conflict = value
.as_ref()
.and_then(|v| v.try_to_number().map(|n| n != 0.into()))
.unwrap_or(false);
// We can skip the identity if it does not have unknown variables
// but only if there is no conflict.
(value.is_some() && !conflict, conflict)
}
Identity::Connect(..) => (false, false),
};
if skip {
None
} else {
Some(format!(
"{}--------------[ identity {} on row {row}: ]--------------\n{}",
if conflicting {
"--------------[ !!! CONFLICT !!! ]--------------\n"
} else {
""
},
id.id(),
self.format_identity(id, *row)
))
}
})
.join("\n")
}
Expand Down
97 changes: 61 additions & 36 deletions executor/src/witgen/jit/identity_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ use crate::witgen::{
data_structures::identity::Identity, jit::variable::MachineCallVariable, FixedData,
};

use super::{
prover_function_heuristics::ProverFunction,
variable::Variable,
witgen_inference::{Assignment, VariableOrValue},
};
use super::{prover_function_heuristics::ProverFunction, variable::Variable};

/// Keeps track of identities that still need to be processed and
/// updates this list based on the occurrence of updated variables
Expand All @@ -35,20 +31,9 @@ pub struct IdentityQueue<'a, T: FieldElement> {
impl<'a, T: FieldElement> IdentityQueue<'a, T> {
pub fn new(
fixed_data: &'a FixedData<'a, T>,
identities: &[(&'a Identity<T>, i32)],
assignments: &[Assignment<'a, T>],
prover_functions: &[(ProverFunction<'a, T>, i32)],
items: impl IntoIterator<Item = QueueItem<'a, T>>,
) -> Self {
let queue: BTreeSet<_> = identities
.iter()
.map(|(id, row)| QueueItem::Identity(id, *row))
.chain(assignments.iter().map(|a| QueueItem::Assignment(a.clone())))
.chain(
prover_functions
.iter()
.map(|(p, row)| QueueItem::ProverFunction(p.clone(), *row)),
)
.collect();
let queue: BTreeSet<_> = items.into_iter().collect();
let mut references = ReferencesComputer::new(fixed_data);
let occurrences = Rc::new(
queue
Expand Down Expand Up @@ -93,25 +78,50 @@ impl<'a, T: FieldElement> IdentityQueue<'a, T> {
#[derive(Clone)]
pub enum QueueItem<'a, T: FieldElement> {
Identity(&'a Identity<T>, i32),
Assignment(Assignment<'a, T>),
VariableAssignment(VariableAssignment<'a, T>),
ConstantAssignment(ConstantAssignment<'a, T>),
ProverFunction(ProverFunction<'a, T>, i32),
}

/// Sorts identities by row and then by ID, preceded by assignments.
impl<T: FieldElement> Ord for QueueItem<'_, T> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
match (self, other) {
(QueueItem::Identity(id1, row1), QueueItem::Identity(id2, row2)) => {
(row1, id1.id()).cmp(&(row2, id2.id()))
}
(QueueItem::Assignment(a1), QueueItem::Assignment(a2)) => a1.cmp(a2),
(QueueItem::VariableAssignment(a1), QueueItem::VariableAssignment(a2)) => a1.cmp(a2),
(QueueItem::ConstantAssignment(a1), QueueItem::ConstantAssignment(a2)) => a1.cmp(a2),
(QueueItem::ProverFunction(p1, row1), QueueItem::ProverFunction(p2, row2)) => {
(row1, p1.index).cmp(&(row2, p2.index))
}
(QueueItem::Assignment(..), _) => std::cmp::Ordering::Less,
(QueueItem::Identity(..), QueueItem::Assignment(..)) => std::cmp::Ordering::Greater,
(QueueItem::Identity(..), QueueItem::ProverFunction(..)) => std::cmp::Ordering::Less,
(QueueItem::ProverFunction(..), _) => std::cmp::Ordering::Greater,
(a, b) => a.order().cmp(&b.order()),
}
}
}

impl<'a, T: FieldElement> QueueItem<'a, T> {
pub fn constant_assignment(lhs: &'a Expression<T>, rhs: T, row_offset: i32) -> Self {
QueueItem::ConstantAssignment(ConstantAssignment {
lhs,
row_offset,
rhs,
})
}

pub fn variable_assignment(lhs: &'a Expression<T>, rhs: Variable, row_offset: i32) -> Self {
QueueItem::VariableAssignment(VariableAssignment {
lhs,
row_offset,
rhs,
})
}

fn order(&self) -> u32 {
match self {
QueueItem::ConstantAssignment(..) => 0,
QueueItem::VariableAssignment(..) => 1,
QueueItem::Identity(..) => 2,
QueueItem::ProverFunction(..) => 3,
}
}
}
Expand All @@ -130,6 +140,24 @@ impl<T: FieldElement> PartialEq for QueueItem<'_, T> {

impl<T: FieldElement> Eq for QueueItem<'_, T> {}

/// An equality constraint between an algebraic expression evaluated
/// on a certain row offset and a variable.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct VariableAssignment<'a, T: FieldElement> {
pub lhs: &'a Expression<T>,
pub row_offset: i32,
pub rhs: Variable,
}

/// An equality constraint between an algebraic expression evaluated
/// on a certain row offset and a constant.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct ConstantAssignment<'a, T: FieldElement> {
pub lhs: &'a Expression<T>,
pub row_offset: i32,
pub rhs: T,
}

/// Utility to compute the variables that occur in a queue item.
/// Follows intermediate column references and employs caches.
struct ReferencesComputer<'a, T: FieldElement> {
Expand Down Expand Up @@ -171,17 +199,14 @@ impl<'a, T: FieldElement> ReferencesComputer<'a, T> {
),
Identity::Connect(..) => Box::new(std::iter::empty()),
},
QueueItem::Assignment(a) => {
let vars_in_rhs = match &a.rhs {
VariableOrValue::Variable(v) => Some(v.clone()),
VariableOrValue::Value(_) => None,
};
Box::new(
self.variables_in_expression(a.lhs, a.row_offset)
.into_iter()
.chain(vars_in_rhs),
)
}
QueueItem::ConstantAssignment(a) => Box::new(
self.variables_in_expression(a.lhs, a.row_offset)
.into_iter(),
),
QueueItem::VariableAssignment(a) => Box::new(
std::iter::once(a.rhs.clone())
.chain(self.variables_in_expression(a.lhs, a.row_offset)),
),
QueueItem::ProverFunction(p, row) => Box::new(
p.condition
.iter()
Expand Down
Loading

0 comments on commit 4683e58

Please sign in to comment.