Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
finish tests implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
rutefig committed Aug 1, 2024
1 parent 02b8742 commit 3ec6af6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
41 changes: 26 additions & 15 deletions src/compiler/cse.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// Pub cse function receives a circuit and returns a new circuit with common subexpressions replaced
// by signals. Pub cse gets the

use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng};

use std::{
Expand All @@ -20,14 +17,14 @@ use crate::{
wit_gen::NullTraceGenerator,
};


/// Common Subexpression Elimination (CSE) optimization.
/// This optimization replaces common subexpressions with new internal signals for the step type.
/// This is done by each time finding the optimal subexpression to replace and creating a new signal
/// for it and replacing it in all constraints.
/// The process is repeated until no more common subexpressions are found.
/// Equivalent expressions are found by hashing the expressions with random assignments to the queriables. Using
/// the Schwartz-Zippel lemma, we can determine if two expressions are equivalent with high probability.
/// Equivalent expressions are found by hashing the expressions with random assignments to the
/// queriables. Using the Schwartz-Zippel lemma, we can determine if two expressions are equivalent
/// with high probability.
pub(super) fn cse<F: Field + Hash>(
mut circuit: SBPIR<F, NullTraceGenerator>,
) -> SBPIR<F, NullTraceGenerator> {
Expand Down Expand Up @@ -106,6 +103,7 @@ struct SubexprInfo {
degree: usize,
}

/// Information about a subexpression to help find the optimal subexpression to replace.
impl SubexprInfo {
fn new(count: usize, degree: usize) -> Self {
Self { count, degree }
Expand All @@ -117,6 +115,7 @@ impl SubexprInfo {
}
}

/// Find the optimal subexpression to replace in a list of expressions.
fn find_optimal_subexpression<F: Field + Hash>(
exprs: &Vec<Expr<F, Queriable<F>, HashResult>>,
replaced_hashes: &HashSet<u64>,
Expand Down Expand Up @@ -149,6 +148,7 @@ fn find_optimal_subexpression<F: Field + Hash>(
}
}

/// Count the subexpressions in an expression and store them in a map.
fn count_subexpressions<F: Field + Hash>(
expr: &Expr<F, Queriable<F>, HashResult>,
count_map: &mut HashMap<u64, SubexprInfo>,
Expand Down Expand Up @@ -245,7 +245,7 @@ mod test {
let expr3 = 4.expr() + a * b + c;
let expr4 = e * f * d;
let expr5 = expr1.clone() + expr4.clone();
let exprs = vec![expr1, expr2, expr3, expr4, expr5];
let exprs = vec![expr1, expr2, expr3, expr4.clone(), expr5];

let mut rng = ChaCha20Rng::seed_from_u64(0);
let mut rand_assignments = VarAssignments::new();
Expand All @@ -259,7 +259,9 @@ mod test {
hashed_exprs.push(hashed_expr);
}

find_optimal_subexpression(&hashed_exprs, &HashSet::new());
let best_expr = find_optimal_subexpression(&hashed_exprs, &HashSet::new());

assert_eq!(format!("{:?}", best_expr.unwrap()), format!("{:?}", expr4));
}

#[test]
Expand Down Expand Up @@ -297,17 +299,23 @@ mod test {
step.add_constr("expr1".into(), expr1);
step.add_constr("expr2".into(), expr2);
step.add_constr("expr3".into(), expr3);
step.add_constr("expr4".into(), expr4);
step.add_constr("expr4".into(), expr4.clone());
step.add_constr("expr5".into(), expr5);

let mut circuit: SBPIR<Fr, NullTraceGenerator> = SBPIR::default();
circuit.add_step_type_def(step);

println!("Circuit before CSE: {:#?}", circuit);
let step_uuid = circuit.add_step_type_def(step);

let circuit = cse(circuit);

println!("Circuit after CSE: {:#?}", circuit);
let common_ses_found_and_replaced = circuit
.step_types
.get(&step_uuid)
.unwrap()
.auto_signals
.values();

assert!(common_ses_found_and_replaced.clone().find(|expr| format!("{:?}", expr) == "(a * b)").is_some());
assert!(common_ses_found_and_replaced.clone().find(|expr| format!("{:?}", expr) == "(e * f * d)").is_some());
}

#[derive(Clone)]
Expand Down Expand Up @@ -354,11 +362,14 @@ mod test {
step.add_constr("expr5".into(), expr5);

let step_with_meta = step.with_meta(|expr| TestStruct {
value: format!("{:?}", expr),
value: format!("Expr: {:?}", expr),
});

for constraint in &step_with_meta.constraints {
println!("{:?}", constraint.expr.meta().value);
assert_eq!(
constraint.expr.meta().value,
format!("Expr: {:?}", constraint.expr)
);
}
}
}
4 changes: 2 additions & 2 deletions src/poly/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ mod tests {
let (common_se, decomp) =
create_common_ses_signal(&common_expr.hash(&assignments), &mut signal_factory);

let (new_expr, decomp) = replace_expr(
let (new_expr, _) = replace_expr(
&expr.hash(&assignments),
&common_se,
&mut signal_factory,
decomp,
decomp.clone(),
);

assert!(decomp.auto_signals.len() == 1);
Expand Down

0 comments on commit 3ec6af6

Please sign in to comment.