diff --git a/executor/src/witgen/bus_accumulator/mod.rs b/executor/src/witgen/bus_accumulator/mod.rs index c04359a826..db3edbc3ee 100644 --- a/executor/src/witgen/bus_accumulator/mod.rs +++ b/executor/src/witgen/bus_accumulator/mod.rs @@ -1,5 +1,5 @@ use std::{ - collections::{BTreeMap, BTreeSet}, + collections::{BTreeMap, BTreeSet, HashSet}, iter::once, }; @@ -15,6 +15,8 @@ use powdr_executor_utils::{ use powdr_number::{DegreeType, FieldElement, KnownField}; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use super::polynomial_references; + mod extension_field; mod fp2; mod fp4; @@ -120,20 +122,71 @@ impl<'a, T: FieldElement, Ext: ExtensionField + Sync> BusAccumulatorGenerator } pub fn generate(&self) -> Vec<(String, Vec)> { + let second_stage_cols = self + .pil + .committed_polys_in_source_order() + .filter(|(symbol, _)| symbol.stage == Some(1)) + .flat_map(|(symbol, _)| symbol.array_elements()) + .map(|(name, poly_id)| (poly_id, name)) + .collect::>(); + let intermediate_definitions = self.pil.intermediate_definitions(); + let accumulators = self .bus_interactions .par_iter() .flat_map(|bus_interaction| { + // Collect the PolyIDs of the accumulator columns. + let acc_cols = bus_interaction + .accumulator_columns + .iter() + .map(|reference| { + assert!(!reference.next); + reference.poly_id + }) + .collect::>(); + // Find the folded columns if they exist, by finding all second-stage columns that + // are referenced together with the accumulator columns. + let folded_cols = self + .pil + .identities + .iter() + .flat_map(|identity| match identity { + Identity::Polynomial(_) => { + let references = + polynomial_references(identity, &intermediate_definitions) + .into_iter() + .filter(|col| second_stage_cols.contains_key(col)) + .collect::>(); + if acc_cols.iter().any(|col| references.contains(col)) { + references + .into_iter() + .filter(|col| !acc_cols.contains(col)) + .collect() + } else { + vec![] + } + } + _ => vec![], + }) + .collect::>(); + assert!(folded_cols.is_empty() || folded_cols.len() == acc_cols.len()); + let (folded, acc) = self.interaction_columns(bus_interaction); - folded.into_iter().chain(acc).collect::>() + + assert!(folded_cols.is_empty() || folded_cols.len() == folded.len()); + let acc = acc_cols.into_iter().zip_eq(acc); + let folded = folded_cols.into_iter().zip(folded); + + acc.chain(folded).collect::>() }) .collect::>(); - self.pil - .committed_polys_in_source_order() - .filter(|(symbol, _)| symbol.stage == Some(1)) - .flat_map(|(symbol, _)| symbol.array_elements().map(|(name, _)| name)) - .zip_eq(accumulators) + accumulators + .into_iter() + .map(|(poly_id, acc)| { + let name = second_stage_cols.get(&poly_id).unwrap().clone(); + (name, acc) + }) .collect() }