From 77db782ef9a6a677c3aaeff53788a943cd7574e6 Mon Sep 17 00:00:00 2001 From: Rute Figueiredo Date: Wed, 31 Jul 2024 18:59:32 +0100 Subject: [PATCH] implemented cse loop --- src/compiler/cse.rs | 80 ++++++++++++++++++++++++++++++--------------- src/poly/cse.rs | 62 +++++++++++++++++++++-------------- 2 files changed, 92 insertions(+), 50 deletions(-) diff --git a/src/compiler/cse.rs b/src/compiler/cse.rs index d49f07df..0fd557ae 100644 --- a/src/compiler/cse.rs +++ b/src/compiler/cse.rs @@ -3,11 +3,11 @@ use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; -use std::{collections::HashMap, hash::Hash, marker::PhantomData}; +use std::{collections::{HashMap, HashSet}, hash::Hash, marker::PhantomData}; use crate::{ field::Field, - poly::{self, cse::replace_expr, Expr, HashResult, VarAssignments}, + poly::{self, cse::{create_common_ses_signal, replace_expr}, Expr, HashResult, VarAssignments}, sbpir::{query::Queriable, InternalSignal, SBPIR}, wit_gen::NullTraceGenerator, }; @@ -17,17 +17,16 @@ pub(super) fn cse( ) -> SBPIR { for (_, step_type) in circuit.step_types.iter_mut() { let mut signal_factory = SignalFactory::default(); + let mut replaced_hashes = HashSet::new(); loop { let mut queriables = Vec::>::new(); circuit.forward_signals.iter().for_each(|signal| { - println!("Forward signal: {:?}", signal); queriables.push(Queriable::Forward(signal.clone(), false)); queriables.push(Queriable::Forward(signal.clone(), true)); }); step_type.signals.iter().for_each(|signal| { - println!("Signal: {:?}", signal); queriables.push(Queriable::Internal(signal.clone())); }); @@ -53,11 +52,25 @@ pub(super) fn cse( } // Find the optimal subexpression to replace - if let Some(common_expr) = find_optimal_subexpression(&exprs) { + println!("Step type before CSE: {:#?}", step_type_with_hash); + if let Some(common_expr) = find_optimal_subexpression(&exprs, &replaced_hashes) { println!("Common expression found: {:?}", common_expr); + // Add the hash of the replaced expression to the set + replaced_hashes.insert(common_expr.meta().hash); + // Create a new signal for the common subexpression + let (common_se, decomp) = create_common_ses_signal(&common_expr, &mut signal_factory); + + decomp.auto_signals.iter().for_each(|(q, expr)| { + if let Queriable::Internal(signal) = q { + step_type_with_hash.add_internal(signal.clone()); + } + step_type_with_hash.auto_signals.insert(q.clone(), expr.clone()); + step_type_with_hash.add_constr(format!("{:?}", q), expr.clone()); + }); + // Replace the common subexpression in all constraints step_type_with_hash.decomp_constraints(|expr| { - replace_expr(expr, &common_expr, &mut signal_factory) + replace_expr(expr, &common_se, &mut signal_factory, decomp.clone()) }); } else { // No more common subexpressions found, exit the loop @@ -88,29 +101,35 @@ impl SubexprInfo { fn find_optimal_subexpression( exprs: &Vec, HashResult>>, + replaced_hashes: &HashSet, ) -> Option, HashResult>> { - // Extract all the subexpressions that appear more than once and sort them by degree - // and number of times they appear let mut count_map = HashMap::::new(); + let mut hash_to_expr = HashMap::, HashResult>>::new(); + + // Extract all subexpressions and count them for expr in exprs.iter() { - count_subexpressions(expr, &mut count_map); + count_subexpressions(expr, &mut count_map, &mut hash_to_expr, replaced_hashes); } - // Find the best common subexpression to replace - the one with the highest degree and - // the highest number of appearances + // Find the best common subexpression to replace let common_ses = count_map .into_iter() - .filter(|&(_, info)| info.count > 1 && info.degree > 1) + .filter(|&(hash, info)| info.count > 1 && info.degree > 1 && !replaced_hashes.contains(&hash)) .collect::>(); + println!("Common subexpressions: {:#?}", common_ses); + let best_subexpr = common_ses .iter() .max_by_key(|&(_, info)| (info.degree, info.count)) .map(|(&hash, info)| (hash, info.count, info.degree)); + println!("Best subexpression: {:#?}", best_subexpr); + if let Some((hash, _count, _degree)) = best_subexpr { - let best_subexpr = exprs.iter().find(|expr| expr.meta().hash == hash); - best_subexpr.cloned() + let best_subexpr = hash_to_expr.get(&hash).cloned(); + println!("Best subexpression found: {:#?}", best_subexpr); + best_subexpr } else { None } @@ -119,30 +138,39 @@ fn find_optimal_subexpression( fn count_subexpressions( expr: &Expr, HashResult>, count_map: &mut HashMap, + hash_to_expr: &mut HashMap, HashResult>>, + replaced_hashes: &HashSet, ) { let degree = expr.degree(); + let hash_result = expr.meta().hash; + + // Only count and store if not already replaced + if !replaced_hashes.contains(&hash_result) { + // Store the expression with its hash + hash_to_expr.insert(hash_result, expr.clone()); + count_map + .entry(hash_result) + .and_modify(|info| info.update(degree)) + .or_insert(SubexprInfo::new(1, degree)); + } + + // Recurse into subexpressions match expr { Expr::Const(_, _) | Expr::Query(_, _) => {} Expr::Sum(exprs, _) | Expr::Mul(exprs, _) => { for subexpr in exprs { - count_subexpressions(subexpr, count_map); + count_subexpressions(subexpr, count_map, hash_to_expr, replaced_hashes); } } Expr::Neg(subexpr, _) | Expr::MI(subexpr, _) => { - count_subexpressions(subexpr, count_map); + count_subexpressions(subexpr, count_map, hash_to_expr, replaced_hashes); } Expr::Pow(subexpr, _, _) => { - count_subexpressions(subexpr, count_map); + count_subexpressions(subexpr, count_map, hash_to_expr, replaced_hashes); } _ => {} } - - let hash_result = expr.meta().hash; - count_map - .entry(hash_result) - .and_modify(|info| info.update(degree)) - .or_insert(SubexprInfo::new(1, degree)); } // Basic signal factory. @@ -175,6 +203,8 @@ impl poly::SignalFactory> for SignalFactory { #[cfg(test)] mod test { + use std::collections::HashSet; + use halo2_proofs::halo2curves::bn256::Fr; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; @@ -225,7 +255,7 @@ mod test { hashed_exprs.push(hashed_expr); } - find_optimal_subexpression(&hashed_exprs); + find_optimal_subexpression(&hashed_exprs, &HashSet::new()); } #[test] @@ -266,8 +296,6 @@ mod test { step.add_constr("expr4".into(), expr4); step.add_constr("expr5".into(), expr5); - println!("Step before CSE: {:#?}", step); - let mut circuit: SBPIR = SBPIR::default(); circuit.add_step_type_def(step); diff --git a/src/poly/cse.rs b/src/poly/cse.rs index 817e6997..0d8b6052 100644 --- a/src/poly/cse.rs +++ b/src/poly/cse.rs @@ -7,11 +7,23 @@ pub fn replace_expr, common_se: &Expr, signal_factory: &mut SF, + decomp: ConstrDecomp +) -> (Expr, ConstrDecomp) { + let mut decomp = decomp; + let new_expr = replace_subexpr(expr, common_se, signal_factory, &mut decomp); + + (new_expr, ConstrDecomp::default()) +} + +pub fn create_common_ses_signal>( + common_se: &Expr, + signal_factory: &mut SF, ) -> (Expr, ConstrDecomp) { let mut decomp = ConstrDecomp::default(); - let new_expr = replace_subexpr(expr, common_se, signal_factory, &mut decomp); - (new_expr, decomp) + let signal = signal_factory.create("cse"); + decomp.auto_eq(signal.clone(), common_se.clone()); + (Expr::Query(signal, common_se.meta().clone()), decomp) } fn replace_subexpr>( @@ -21,28 +33,20 @@ fn replace_subexpr, ) -> Expr { let common_expr_hash = common_se.meta().hash; - println!("Common expr hash: {:#?}", common_expr_hash); - + + if expr.meta().degree < common_se.meta().degree { + // If the current expression's degree is less than the common subexpression's degree, + // it can't contain the common subexpression, so we return it as is + return expr.clone(); + } + // If the expression is the same as the common subexpression, create a new signal and return it if expr.meta().hash == common_expr_hash { // Find the signal or create a new signal for the expression - let signal = decomp.find_auto_signal_by_hash(common_expr_hash); - println!("signal: {:#?}", signal); - println!("decomp auto signals: {:#?}", decomp.auto_signals); - - if let Some((s, _)) = signal { - Expr::Query(s.clone(), common_se.meta().clone()) - } else { - let signal = signal_factory.create("cse"); - decomp.auto_eq(signal.clone(), common_se.clone()); - Expr::Query(signal, common_se.meta().clone()) - } - } else if expr.meta().degree < common_se.meta().degree { - // If the current expression's degree is less than the common subexpression's degree, - // it can't contain the common subexpression, so we return it as is - expr.clone() + return common_se.clone(); } else { - // Only recurse if we haven't found a match and the expression could potentially contain the common subexpression + // Only recurse if we haven't found a match and the expression could potentially contain the + // common subexpression expr.apply_subexpressions(|se| replace_subexpr(se, common_se, signal_factory, decomp)) } } @@ -53,7 +57,10 @@ mod tests { use halo2_proofs::halo2curves::bn256::Fr; - use crate::{poly::{cse::replace_expr, SignalFactory, ToExpr, VarAssignments}, sbpir::{query::Queriable, InternalSignal}}; + use crate::{ + poly::{cse::{create_common_ses_signal, replace_expr}, SignalFactory, ToExpr, VarAssignments}, + sbpir::{query::Queriable, InternalSignal}, + }; #[derive(Default)] struct TestSignalFactory { @@ -68,7 +75,6 @@ mod tests { } } - #[test] fn test_replace_expr() { let a = Queriable::Internal(InternalSignal::new("a")); @@ -82,9 +88,17 @@ mod tests { let mut signal_factory = TestSignalFactory::default(); - let assignments: VarAssignments> = vars.iter().cloned().map(|q| (q, Fr::from(2))).collect(); + let assignments: VarAssignments> = + vars.iter().cloned().map(|q| (q, Fr::from(2))).collect(); + + let (common_se, decomp) = create_common_ses_signal(&common_expr.hash(&assignments), &mut signal_factory); - let (new_expr, decomp) = replace_expr(&expr.hash(&assignments), &common_expr.hash(&assignments), &mut signal_factory); + let (new_expr, decomp) = replace_expr( + &expr.hash(&assignments), + &common_se, + &mut signal_factory, + decomp, + ); assert!(decomp.auto_signals.len() == 1); assert_eq!(format!("{:#?}", new_expr), "((-0x1) + cse-1 + (-c))");