From 814028506df2384370b8b8762018e04b211bf7ef Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Thu, 23 Jan 2025 06:06:44 -0300 Subject: [PATCH] Process identities in order (if possible) (#2323) Needed for the automatic pre-compile experiments @leonardoalt is running. Changes the order in which **runtime** witgen for block machines processes identities. ### Motivation The problem this solves is that processing an identity sometimes has side effects (specifically, links to memory), so the order in which identities are executed matters. In practice, this is usually not an issue, because: - In the main machine, rows are processed in order anyway, and rows correspond to time steps of the computation. - A machine like [`Keccakf16Memory`](https://github.com/powdr-labs/powdr/blob/main/std/machines/hash/keccakf16_memory.asm) only reads and writes each address once, and the value to write depends on the read value, so in order to do the write, the read needs to happen first. This is guaranteed by the current algorithm. But in the general case, there is no guarantee, because witgen doesn't know about the notion of time. For example, if the value written is a constant, it can be executed by witgen as soon as the address is known, even if there is also a read at an earlier time step. The workaround that this PR implements is that identities are processed *in the order they appear in PIL*, as much as possible. In particular, an identity will only be processed if no previous identity (or the outer query or any prover function) led to progress. With this, the user can get witgen to execute memory operations in the correct order. This issue came up in some of Leo's pre-compile experiments. 
### Implementation The previous algorithm processed a given row like this: ```python progress = True while progress: progress = False for identity in identities: progress |= process(identity) progress |= process(prover_queries) progress |= process(outer_query) ``` With this PR, we do this: ```python while round(): pass def round(): if process(outer_query): return True if process(prover_queries): return True for identity in identities: if process(identity): return True return False ``` What this gets us is that an identity is only executed *if no previous identity has made progress*. The order of identities is the order in which they appear in the PIL (I think!). This allows us to control the order in which stateful machine calls, like memory accesses, are executed. I guess this can make the solving of the first block significantly slower (afterwards, the sequence cache should remove any identities that did not lead to progress). As an upper bound, witgen for `test_data/std/keccakf16_memory_test.asm` takes 347s now (instead of 309s on `main`) :) --- executor/src/witgen/jit/processor.rs | 2 +- executor/src/witgen/sequence_iterator.rs | 50 +++++++++++++----------- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/executor/src/witgen/jit/processor.rs b/executor/src/witgen/jit/processor.rs index a31477af1d..3727f27fa1 100644 --- a/executor/src/witgen/jit/processor.rs +++ b/executor/src/witgen/jit/processor.rs @@ -338,7 +338,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> Processor<'a, T, FixedEv } } -/// Computes a map from each variable to the identitie-row-offset pairs it occurs in. +/// Computes a map from each variable to the identity-row-offset pairs it occurs in. 
fn compute_occurrences_map<'a, T: FieldElement>( fixed_data: &'a FixedData<'a, T>, identities: &[(&'a Identity, i32)], diff --git a/executor/src/witgen/sequence_iterator.rs b/executor/src/witgen/sequence_iterator.rs index cd562f9da1..af5f0e882f 100644 --- a/executor/src/witgen/sequence_iterator.rs +++ b/executor/src/witgen/sequence_iterator.rs @@ -26,16 +26,11 @@ pub struct DefaultSequenceIterator { /// [process identity 1, ..., process identity , process queries, process outer query (if on outer_query_row)] /// Can be -1 to indicate that the round has just started. cur_action_index: i32, - /// The number of rounds for the current row delta. - /// If this number gets too large, we will assume that we're in an infinite loop and exit. - current_round_count: usize, /// The steps on which we made progress. progress_steps: Vec, } -const MAX_ROUNDS_PER_ROW_DELTA: usize = 100; - impl DefaultSequenceIterator { pub fn new(block_size: usize, identities_count: usize, outer_query_row: Option) -> Self { let max_row = block_size as i64 - 1; @@ -50,7 +45,6 @@ impl DefaultSequenceIterator { progress_in_current_round: false, cur_row_delta_index: 0, cur_action_index: -1, - current_round_count: 0, progress_steps: vec![], } } @@ -59,7 +53,9 @@ impl DefaultSequenceIterator { /// If we're not at the last identity in the current row, just moves to the next. /// Otherwise, starts with identity 0 and moves to the next row if no progress was made. fn update_state(&mut self) { - while !self.is_done() && !self.has_more_actions() { + if !self.is_done() && (!self.has_more_actions() || self.progress_in_current_round) { + // Starting a new round if we made any progress ensures that identities are + // processed in source order if possible. 
self.start_next_round(); } @@ -86,18 +82,9 @@ impl DefaultSequenceIterator { } fn start_next_round(&mut self) { - if self.current_round_count > MAX_ROUNDS_PER_ROW_DELTA { - panic!("In witness generation for block machine, we have been stuck in the same row for {MAX_ROUNDS_PER_ROW_DELTA} rounds. \ - This is a bug in the witness generation algorithm."); - } - if !self.progress_in_current_round { // Move to next row delta self.cur_row_delta_index += 1; - self.current_round_count = 0; - } else { - // Stay and current row delta - self.current_round_count += 1; } // Reset action index and progress flag self.cur_action_index = -1; @@ -126,16 +113,33 @@ impl DefaultSequenceIterator { Some(self.current_step()) } + /// Computes the current step from the current action index and row delta. + /// The actions are: + /// - The outer query (if on the outer query row) + /// - Processing the prover queries + /// - Processing the internal identities, in the order they are given + /// (which should typically correspond to source order). 
fn current_step(&self) -> SequenceStep { assert!(self.cur_action_index != -1); + + let row_delta = self.row_deltas[self.cur_row_delta_index]; + let is_on_row_with_outer_query = self.outer_query_row == Some(row_delta); + + let cur_action_index = if is_on_row_with_outer_query { + self.cur_action_index as usize + } else { + // Skip the outer query action + self.cur_action_index as usize + 1 + }; + SequenceStep { row_delta: self.row_deltas[self.cur_row_delta_index], - action: match self.cur_action_index.cmp(&(self.identities_count as i32)) { - std::cmp::Ordering::Less => { - Action::InternalIdentity(self.cur_action_index as usize) - } - std::cmp::Ordering::Equal => Action::ProverQueries, - std::cmp::Ordering::Greater => Action::OuterQuery, + action: if cur_action_index == 0 { + Action::OuterQuery + } else if cur_action_index == 1 { + Action::ProverQueries + } else { + Action::InternalIdentity(cur_action_index - 2) }, } } @@ -143,9 +147,9 @@ impl DefaultSequenceIterator { #[derive(Clone, Copy, Debug)] pub enum Action { - InternalIdentity(usize), OuterQuery, ProverQueries, + InternalIdentity(usize), } #[derive(PartialOrd, Ord, PartialEq, Eq, Debug)]