diff --git a/src/compiler/cse.rs b/src/compiler/cse.rs index f7f0d9a9..6701f22d 100644 --- a/src/compiler/cse.rs +++ b/src/compiler/cse.rs @@ -1,6 +1,3 @@ -// Pub cse function receives a circuit and returns a new circuit with common subexpressions replaced -// by signals. Pub cse gets the - use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; use std::{ @@ -20,14 +17,14 @@ use crate::{ wit_gen::NullTraceGenerator, }; - /// Common Subexpression Elimination (CSE) optimization. /// This optimization replaces common subexpressions with new internal signals for the step type. /// This is done by each time finding the optimal subexpression to replace and creating a new signal /// for it and replacing it in all constraints. /// The process is repeated until no more common subexpressions are found. -/// Equivalent expressions are found by hashing the expressions with random assignments to the queriables. Using -/// the Schwartz-Zippel lemma, we can determine if two expressions are equivalent with high probability. +/// Equivalent expressions are found by hashing the expressions with random assignments to the +/// queriables. Using the Schwartz-Zippel lemma, we can determine if two expressions are equivalent +/// with high probability. pub(super) fn cse( mut circuit: SBPIR, ) -> SBPIR { @@ -106,6 +103,7 @@ struct SubexprInfo { degree: usize, } +/// Information about a subexpression to help find the optimal subexpression to replace. impl SubexprInfo { fn new(count: usize, degree: usize) -> Self { Self { count, degree } @@ -117,6 +115,7 @@ impl SubexprInfo { } } +/// Find the optimal subexpression to replace in a list of expressions. fn find_optimal_subexpression( exprs: &Vec, HashResult>>, replaced_hashes: &HashSet, @@ -149,6 +148,7 @@ fn find_optimal_subexpression( } } +/// Count the subexpressions in an expression and store them in a map. fn count_subexpressions( expr: &Expr, HashResult>, count_map: &mut HashMap, @@ -245,7 +245,7 @@ mod test { let expr3 = 4.expr() + a * b + c; let expr4 = e * f * d; let expr5 = expr1.clone() + expr4.clone(); - let exprs = vec![expr1, expr2, expr3, expr4, expr5]; + let exprs = vec![expr1, expr2, expr3, expr4.clone(), expr5]; let mut rng = ChaCha20Rng::seed_from_u64(0); let mut rand_assignments = VarAssignments::new(); @@ -259,7 +259,9 @@ mod test { hashed_exprs.push(hashed_expr); } - find_optimal_subexpression(&hashed_exprs, &HashSet::new()); + let best_expr = find_optimal_subexpression(&hashed_exprs, &HashSet::new()); + + assert_eq!(format!("{:?}", best_expr.unwrap()), format!("{:?}", expr4)); } #[test] @@ -297,17 +299,23 @@ mod test { step.add_constr("expr1".into(), expr1); step.add_constr("expr2".into(), expr2); step.add_constr("expr3".into(), expr3); - step.add_constr("expr4".into(), expr4); + step.add_constr("expr4".into(), expr4.clone()); step.add_constr("expr5".into(), expr5); let mut circuit: SBPIR = SBPIR::default(); - circuit.add_step_type_def(step); - - println!("Circuit before CSE: {:#?}", circuit); + let step_uuid = circuit.add_step_type_def(step); let circuit = cse(circuit); - println!("Circuit after CSE: {:#?}", circuit); + let common_ses_found_and_replaced = circuit + .step_types + .get(&step_uuid) + .unwrap() + .auto_signals + .values(); + + assert!(common_ses_found_and_replaced.clone().find(|expr| format!("{:?}", expr) == "(a * b)").is_some()); + assert!(common_ses_found_and_replaced.clone().find(|expr| format!("{:?}", expr) == "(e * f * d)").is_some()); } #[derive(Clone)] @@ -354,11 +362,14 @@ mod test { step.add_constr("expr5".into(), expr5); let step_with_meta = step.with_meta(|expr| TestStruct { - value: format!("{:?}", expr), + value: format!("Expr: {:?}", expr), }); for constraint in &step_with_meta.constraints { - println!("{:?}", constraint.expr.meta().value); + assert_eq!( + constraint.expr.meta().value, + format!("Expr: {:?}", constraint.expr) + ); } } } diff --git a/src/poly/cse.rs b/src/poly/cse.rs index 41f9908c..b1f95ab2 100644 --- a/src/poly/cse.rs +++ b/src/poly/cse.rs @@ -102,11 +102,11 @@ mod tests { let (common_se, decomp) = create_common_ses_signal(&common_expr.hash(&assignments), &mut signal_factory); - let (new_expr, decomp) = replace_expr( + let (new_expr, _) = replace_expr( &expr.hash(&assignments), &common_se, &mut signal_factory, - decomp, + decomp.clone(), ); assert!(decomp.auto_signals.len() == 1);