diff --git a/src/compiler/cse.rs b/src/compiler/cse.rs index 0fd557ae..3cca6824 100644 --- a/src/compiler/cse.rs +++ b/src/compiler/cse.rs @@ -3,11 +3,19 @@ use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; -use std::{collections::{HashMap, HashSet}, hash::Hash, marker::PhantomData}; +use std::{ + collections::{HashMap, HashSet}, + hash::Hash, + marker::PhantomData, +}; use crate::{ field::Field, - poly::{self, cse::{create_common_ses_signal, replace_expr}, Expr, HashResult, VarAssignments}, + poly::{ + self, + cse::{create_common_ses_signal, replace_expr}, + Expr, HashResult, VarAssignments, + }, sbpir::{query::Queriable, InternalSignal, SBPIR}, wit_gen::NullTraceGenerator, }; @@ -52,19 +60,20 @@ pub(super) fn cse( } // Find the optimal subexpression to replace - println!("Step type before CSE: {:#?}", step_type_with_hash); if let Some(common_expr) = find_optimal_subexpression(&exprs, &replaced_hashes) { - println!("Common expression found: {:?}", common_expr); - // Add the hash of the replaced expression to the set - replaced_hashes.insert(common_expr.meta().hash); + // Add the hash of the replaced expression to the set + replaced_hashes.insert(common_expr.meta().hash); // Create a new signal for the common subexpression - let (common_se, decomp) = create_common_ses_signal(&common_expr, &mut signal_factory); + let (common_se, decomp) = + create_common_ses_signal(&common_expr, &mut signal_factory); decomp.auto_signals.iter().for_each(|(q, expr)| { if let Queriable::Internal(signal) = q { step_type_with_hash.add_internal(signal.clone()); } - step_type_with_hash.auto_signals.insert(q.clone(), expr.clone()); + step_type_with_hash + .auto_signals + .insert(q.clone(), expr.clone()); step_type_with_hash.add_constr(format!("{:?}", q), expr.clone()); }); @@ -114,22 +123,18 @@ fn find_optimal_subexpression( // Find the best common subexpression to replace let common_ses = count_map .into_iter() - .filter(|&(hash, info)| info.count > 1 && info.degree > 1 && !replaced_hashes.contains(&hash)) + .filter(|&(hash, info)| { + info.count > 1 && info.degree > 1 && !replaced_hashes.contains(&hash) + }) .collect::>(); - println!("Common subexpressions: {:#?}", common_ses); - let best_subexpr = common_ses .iter() .max_by_key(|&(_, info)| (info.degree, info.count)) .map(|(&hash, info)| (hash, info.count, info.degree)); - println!("Best subexpression: {:#?}", best_subexpr); - if let Some((hash, _count, _degree)) = best_subexpr { - let best_subexpr = hash_to_expr.get(&hash).cloned(); - println!("Best subexpression found: {:#?}", best_subexpr); - best_subexpr + hash_to_expr.get(&hash).cloned() } else { None } diff --git a/src/poly/cse.rs b/src/poly/cse.rs index 0d8b6052..3daf4489 100644 --- a/src/poly/cse.rs +++ b/src/poly/cse.rs @@ -7,7 +7,7 @@ pub fn replace_expr, common_se: &Expr, signal_factory: &mut SF, - decomp: ConstrDecomp + decomp: ConstrDecomp, ) -> (Expr, ConstrDecomp) { let mut decomp = decomp; let new_expr = replace_subexpr(expr, common_se, signal_factory, &mut decomp); @@ -15,7 +15,11 @@ pub fn replace_expr>( +pub fn create_common_ses_signal< + F: Field, + V: Clone + PartialEq + Eq + Hash + Debug, + SF: SignalFactory, +>( common_se: &Expr, signal_factory: &mut SF, ) -> (Expr, ConstrDecomp) { @@ -58,7 +62,10 @@ mod tests { use halo2_proofs::halo2curves::bn256::Fr; use crate::{ - poly::{cse::{create_common_ses_signal, replace_expr}, SignalFactory, ToExpr, VarAssignments}, + poly::{ + cse::{create_common_ses_signal, replace_expr}, + SignalFactory, ToExpr, VarAssignments, + }, sbpir::{query::Queriable, InternalSignal}, }; @@ -91,7 +98,8 @@ mod tests { let assignments: VarAssignments> = vars.iter().cloned().map(|q| (q, Fr::from(2))).collect(); - let (common_se, decomp) = create_common_ses_signal(&common_expr.hash(&assignments), &mut signal_factory); + let (common_se, decomp) = + create_common_ses_signal(&common_expr.hash(&assignments), &mut signal_factory); let (new_expr, decomp) = replace_expr( &expr.hash(&assignments), diff --git a/src/sbpir/mod.rs b/src/sbpir/mod.rs index 4a7ed59a..c9f63e1e 100644 --- a/src/sbpir/mod.rs +++ b/src/sbpir/mod.rs @@ -574,7 +574,10 @@ pub struct TransitionConstraint { } impl TransitionConstraint { - pub fn with_meta(&self, apply_meta: ApplyMetaFn) -> TransitionConstraint + pub fn with_meta( + &self, + apply_meta: ApplyMetaFn, + ) -> TransitionConstraint where ApplyMetaFn: Fn(&Expr, ()>) -> N + Clone, {