Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
implemented cse loop
Browse files Browse the repository at this point in the history
  • Loading branch information
rutefig committed Jul 31, 2024
1 parent a319b7b commit 77db782
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 50 deletions.
80 changes: 54 additions & 26 deletions src/compiler/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};

use std::{collections::HashMap, hash::Hash, marker::PhantomData};
use std::{collections::{HashMap, HashSet}, hash::Hash, marker::PhantomData};

use crate::{
field::Field,
poly::{self, cse::replace_expr, Expr, HashResult, VarAssignments},
poly::{self, cse::{create_common_ses_signal, replace_expr}, Expr, HashResult, VarAssignments},
sbpir::{query::Queriable, InternalSignal, SBPIR},
wit_gen::NullTraceGenerator,
};
Expand All @@ -17,17 +17,16 @@ pub(super) fn cse<F: Field + Hash>(
) -> SBPIR<F, NullTraceGenerator> {
for (_, step_type) in circuit.step_types.iter_mut() {
let mut signal_factory = SignalFactory::default();
let mut replaced_hashes = HashSet::new();

loop {
let mut queriables = Vec::<Queriable<F>>::new();

circuit.forward_signals.iter().for_each(|signal| {
println!("Forward signal: {:?}", signal);
queriables.push(Queriable::Forward(signal.clone(), false));
queriables.push(Queriable::Forward(signal.clone(), true));
});
step_type.signals.iter().for_each(|signal| {
println!("Signal: {:?}", signal);
queriables.push(Queriable::Internal(signal.clone()));
});

Expand All @@ -53,11 +52,25 @@ pub(super) fn cse<F: Field + Hash>(
}

// Find the optimal subexpression to replace
if let Some(common_expr) = find_optimal_subexpression(&exprs) {
println!("Step type before CSE: {:#?}", step_type_with_hash);
if let Some(common_expr) = find_optimal_subexpression(&exprs, &replaced_hashes) {
println!("Common expression found: {:?}", common_expr);
// Add the hash of the replaced expression to the set
replaced_hashes.insert(common_expr.meta().hash);
// Create a new signal for the common subexpression
let (common_se, decomp) = create_common_ses_signal(&common_expr, &mut signal_factory);

decomp.auto_signals.iter().for_each(|(q, expr)| {
if let Queriable::Internal(signal) = q {
step_type_with_hash.add_internal(signal.clone());
}
step_type_with_hash.auto_signals.insert(q.clone(), expr.clone());
step_type_with_hash.add_constr(format!("{:?}", q), expr.clone());
});

// Replace the common subexpression in all constraints
step_type_with_hash.decomp_constraints(|expr| {
replace_expr(expr, &common_expr, &mut signal_factory)
replace_expr(expr, &common_se, &mut signal_factory, decomp.clone())
});
} else {
// No more common subexpressions found, exit the loop
Expand Down Expand Up @@ -88,29 +101,35 @@ impl SubexprInfo {

fn find_optimal_subexpression<F: Field + Hash>(
exprs: &Vec<Expr<F, Queriable<F>, HashResult>>,
replaced_hashes: &HashSet<u64>,
) -> Option<Expr<F, Queriable<F>, HashResult>> {
// Extract all the subexpressions that appear more than once and sort them by degree
// and number of times they appear
let mut count_map = HashMap::<u64, SubexprInfo>::new();
let mut hash_to_expr = HashMap::<u64, Expr<F, Queriable<F>, HashResult>>::new();

// Extract all subexpressions and count them
for expr in exprs.iter() {
count_subexpressions(expr, &mut count_map);
count_subexpressions(expr, &mut count_map, &mut hash_to_expr, replaced_hashes);
}

// Find the best common subexpression to replace - the one with the highest degree and
// the highest number of appearances
// Find the best common subexpression to replace
let common_ses = count_map
.into_iter()
.filter(|&(_, info)| info.count > 1 && info.degree > 1)
.filter(|&(hash, info)| info.count > 1 && info.degree > 1 && !replaced_hashes.contains(&hash))
.collect::<HashMap<_, _>>();

println!("Common subexpressions: {:#?}", common_ses);

let best_subexpr = common_ses
.iter()
.max_by_key(|&(_, info)| (info.degree, info.count))
.map(|(&hash, info)| (hash, info.count, info.degree));

println!("Best subexpression: {:#?}", best_subexpr);

if let Some((hash, _count, _degree)) = best_subexpr {
let best_subexpr = exprs.iter().find(|expr| expr.meta().hash == hash);
best_subexpr.cloned()
let best_subexpr = hash_to_expr.get(&hash).cloned();
println!("Best subexpression found: {:#?}", best_subexpr);
best_subexpr
} else {
None
}
Expand All @@ -119,30 +138,39 @@ fn find_optimal_subexpression<F: Field + Hash>(
fn count_subexpressions<F: Field + Hash>(
expr: &Expr<F, Queriable<F>, HashResult>,
count_map: &mut HashMap<u64, SubexprInfo>,
hash_to_expr: &mut HashMap<u64, Expr<F, Queriable<F>, HashResult>>,
replaced_hashes: &HashSet<u64>,
) {
let degree = expr.degree();
let hash_result = expr.meta().hash;

// Only count and store if not already replaced
if !replaced_hashes.contains(&hash_result) {
// Store the expression with its hash
hash_to_expr.insert(hash_result, expr.clone());

count_map
.entry(hash_result)
.and_modify(|info| info.update(degree))
.or_insert(SubexprInfo::new(1, degree));
}

// Recurse into subexpressions
match expr {
Expr::Const(_, _) | Expr::Query(_, _) => {}
Expr::Sum(exprs, _) | Expr::Mul(exprs, _) => {
for subexpr in exprs {
count_subexpressions(subexpr, count_map);
count_subexpressions(subexpr, count_map, hash_to_expr, replaced_hashes);
}
}
Expr::Neg(subexpr, _) | Expr::MI(subexpr, _) => {
count_subexpressions(subexpr, count_map);
count_subexpressions(subexpr, count_map, hash_to_expr, replaced_hashes);
}
Expr::Pow(subexpr, _, _) => {
count_subexpressions(subexpr, count_map);
count_subexpressions(subexpr, count_map, hash_to_expr, replaced_hashes);
}
_ => {}
}

let hash_result = expr.meta().hash;
count_map
.entry(hash_result)
.and_modify(|info| info.update(degree))
.or_insert(SubexprInfo::new(1, degree));
}

// Basic signal factory.
Expand Down Expand Up @@ -175,6 +203,8 @@ impl<F> poly::SignalFactory<Queriable<F>> for SignalFactory<F> {

#[cfg(test)]
mod test {
use std::collections::HashSet;

use halo2_proofs::halo2curves::bn256::Fr;
use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};

Expand Down Expand Up @@ -225,7 +255,7 @@ mod test {
hashed_exprs.push(hashed_expr);
}

find_optimal_subexpression(&hashed_exprs);
find_optimal_subexpression(&hashed_exprs, &HashSet::new());
}

#[test]
Expand Down Expand Up @@ -266,8 +296,6 @@ mod test {
step.add_constr("expr4".into(), expr4);
step.add_constr("expr5".into(), expr5);

println!("Step before CSE: {:#?}", step);

let mut circuit: SBPIR<Fr, NullTraceGenerator> = SBPIR::default();
circuit.add_step_type_def(step);

Expand Down
62 changes: 38 additions & 24 deletions src/poly/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,23 @@ pub fn replace_expr<F: Field + Hash, V: Clone + Eq + Hash + Debug, SF: SignalFac
expr: &Expr<F, V, HashResult>,
common_se: &Expr<F, V, HashResult>,
signal_factory: &mut SF,
decomp: ConstrDecomp<F, V, HashResult>
) -> (Expr<F, V, HashResult>, ConstrDecomp<F, V, HashResult>) {
let mut decomp = decomp;
let new_expr = replace_subexpr(expr, common_se, signal_factory, &mut decomp);

(new_expr, ConstrDecomp::default())
}

pub fn create_common_ses_signal<F: Field, V: Clone + PartialEq + Eq + Hash + Debug, SF: SignalFactory<V>>(
common_se: &Expr<F, V, HashResult>,
signal_factory: &mut SF,
) -> (Expr<F, V, HashResult>, ConstrDecomp<F, V, HashResult>) {
let mut decomp = ConstrDecomp::default();

let new_expr = replace_subexpr(expr, common_se, signal_factory, &mut decomp);
(new_expr, decomp)
let signal = signal_factory.create("cse");
decomp.auto_eq(signal.clone(), common_se.clone());
(Expr::Query(signal, common_se.meta().clone()), decomp)
}

fn replace_subexpr<F: Field + Hash, V: Clone + Eq + Hash + Debug, SF: SignalFactory<V>>(
Expand All @@ -21,28 +33,20 @@ fn replace_subexpr<F: Field + Hash, V: Clone + Eq + Hash + Debug, SF: SignalFact
decomp: &mut ConstrDecomp<F, V, HashResult>,
) -> Expr<F, V, HashResult> {
let common_expr_hash = common_se.meta().hash;
println!("Common expr hash: {:#?}", common_expr_hash);


if expr.meta().degree < common_se.meta().degree {
// If the current expression's degree is less than the common subexpression's degree,
// it can't contain the common subexpression, so we return it as is
return expr.clone();
}

// If the expression is the same as the common subexpression, create a new signal and return it
if expr.meta().hash == common_expr_hash {
// Find the signal or create a new signal for the expression
let signal = decomp.find_auto_signal_by_hash(common_expr_hash);
println!("signal: {:#?}", signal);
println!("decomp auto signals: {:#?}", decomp.auto_signals);

if let Some((s, _)) = signal {
Expr::Query(s.clone(), common_se.meta().clone())
} else {
let signal = signal_factory.create("cse");
decomp.auto_eq(signal.clone(), common_se.clone());
Expr::Query(signal, common_se.meta().clone())
}
} else if expr.meta().degree < common_se.meta().degree {
// If the current expression's degree is less than the common subexpression's degree,
// it can't contain the common subexpression, so we return it as is
expr.clone()
return common_se.clone();
} else {
// Only recurse if we haven't found a match and the expression could potentially contain the common subexpression
// Only recurse if we haven't found a match and the expression could potentially contain the
// common subexpression
expr.apply_subexpressions(|se| replace_subexpr(se, common_se, signal_factory, decomp))
}
}
Expand All @@ -53,7 +57,10 @@ mod tests {

use halo2_proofs::halo2curves::bn256::Fr;

use crate::{poly::{cse::replace_expr, SignalFactory, ToExpr, VarAssignments}, sbpir::{query::Queriable, InternalSignal}};
use crate::{
poly::{cse::{create_common_ses_signal, replace_expr}, SignalFactory, ToExpr, VarAssignments},
sbpir::{query::Queriable, InternalSignal},
};

#[derive(Default)]
struct TestSignalFactory {
Expand All @@ -68,7 +75,6 @@ mod tests {
}
}


#[test]
fn test_replace_expr() {
let a = Queriable::Internal(InternalSignal::new("a"));
Expand All @@ -82,9 +88,17 @@ mod tests {

let mut signal_factory = TestSignalFactory::default();

let assignments: VarAssignments<Fr, Queriable<Fr>> = vars.iter().cloned().map(|q| (q, Fr::from(2))).collect();
let assignments: VarAssignments<Fr, Queriable<Fr>> =
vars.iter().cloned().map(|q| (q, Fr::from(2))).collect();

let (common_se, decomp) = create_common_ses_signal(&common_expr.hash(&assignments), &mut signal_factory);

let (new_expr, decomp) = replace_expr(&expr.hash(&assignments), &common_expr.hash(&assignments), &mut signal_factory);
let (new_expr, decomp) = replace_expr(
&expr.hash(&assignments),
&common_se,
&mut signal_factory,
decomp,
);

assert!(decomp.auto_signals.len() == 1);
assert_eq!(format!("{:#?}", new_expr), "((-0x1) + cse-1 + (-c))");
Expand Down

0 comments on commit 77db782

Please sign in to comment.