Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix clippy
Browse files Browse the repository at this point in the history
  • Loading branch information
rutefig committed Aug 1, 2024
1 parent 3ec6af6 commit 1c235aa
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 32 deletions.
29 changes: 16 additions & 13 deletions src/compiler/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use crate::{
/// Equivalent expressions are found by hashing the expressions with random assignments to the
/// queriables. Using the Schwartz-Zippel lemma, we can determine if two expressions are equivalent
/// with high probability.
#[allow(dead_code)]
pub(super) fn cse<F: Field + Hash>(
mut circuit: SBPIR<F, NullTraceGenerator>,
) -> SBPIR<F, NullTraceGenerator> {
Expand All @@ -36,11 +37,11 @@ pub(super) fn cse<F: Field + Hash>(
let mut queriables = Vec::<Queriable<F>>::new();

circuit.forward_signals.iter().for_each(|signal| {
queriables.push(Queriable::Forward(signal.clone(), false));
queriables.push(Queriable::Forward(signal.clone(), true));
queriables.push(Queriable::Forward(*signal, false));
queriables.push(Queriable::Forward(*signal, true));
});
step_type.signals.iter().for_each(|signal| {
queriables.push(Queriable::Internal(signal.clone()));
queriables.push(Queriable::Internal(*signal));
});

// Generate random assignments for the queriables
Expand Down Expand Up @@ -75,18 +76,14 @@ pub(super) fn cse<F: Field + Hash>(
// Add the new signal to the step type and a constraint for it
decomp.auto_signals.iter().for_each(|(q, expr)| {
if let Queriable::Internal(signal) = q {
step_type_with_hash.add_internal(signal.clone());
step_type_with_hash.add_internal(*signal);
}
step_type_with_hash
.auto_signals
.insert(q.clone(), expr.clone());
step_type_with_hash.auto_signals.insert(*q, expr.clone());
step_type_with_hash.add_constr(format!("{:?}", q), expr.clone());
});

// Replace the common subexpression in all constraints
step_type_with_hash.decomp_constraints(|expr| {
replace_expr(expr, &common_se, &mut signal_factory, decomp.clone())
});
step_type_with_hash.decomp_constraints(|expr| replace_expr(expr, &common_se));
} else {
// No more common subexpressions found, exit the loop
break;
Expand Down Expand Up @@ -117,7 +114,7 @@ impl SubexprInfo {

/// Find the optimal subexpression to replace in a list of expressions.
fn find_optimal_subexpression<F: Field + Hash>(
exprs: &Vec<Expr<F, Queriable<F>, HashResult>>,
exprs: &[Expr<F, Queriable<F>, HashResult>],
replaced_hashes: &HashSet<u64>,
) -> Option<Expr<F, Queriable<F>, HashResult>> {
let mut count_map = HashMap::<u64, SubexprInfo>::new();
Expand Down Expand Up @@ -314,8 +311,14 @@ mod test {
.auto_signals
.values();

assert!(common_ses_found_and_replaced.clone().find(|expr| format!("{:?}", expr) == "(a * b)").is_some());
assert!(common_ses_found_and_replaced.clone().find(|expr| format!("{:?}", expr) == "(e * f * d)").is_some());
assert!(common_ses_found_and_replaced
.clone()
.find(|expr| format!("{:?}", expr) == "(a * b)")
.is_some());
assert!(common_ses_found_and_replaced
.clone()
.find(|expr| format!("{:?}", expr) == "(e * f * d)")
.is_some());
}

#[derive(Clone)]
Expand Down
27 changes: 8 additions & 19 deletions src/poly/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@ use super::{ConstrDecomp, Expr, HashResult, SignalFactory};
use std::{fmt::Debug, hash::Hash};

/// This function replaces a common subexpression in an expression with a new signal.
pub fn replace_expr<F: Field + Hash, V: Clone + Eq + Hash + Debug, SF: SignalFactory<V>>(
pub fn replace_expr<F: Field + Hash, V: Clone + Eq + Hash + Debug>(
expr: &Expr<F, V, HashResult>,
common_se: &Expr<F, V, HashResult>,
signal_factory: &mut SF,
decomp: ConstrDecomp<F, V, HashResult>,
) -> (Expr<F, V, HashResult>, ConstrDecomp<F, V, HashResult>) {
let mut decomp = decomp;
let new_expr = replace_subexpr(expr, common_se, signal_factory, &mut decomp);
let new_expr = replace_subexpr(expr, common_se);

(new_expr, ConstrDecomp::default())
}
Expand All @@ -33,26 +30,23 @@ pub fn create_common_ses_signal<
}

/// This function replaces a common subexpression in an expression with a new signal.
fn replace_subexpr<F: Field + Hash, V: Clone + Eq + Hash + Debug, SF: SignalFactory<V>>(
fn replace_subexpr<F: Field + Hash, V: Clone + Eq + Hash + Debug>(
expr: &Expr<F, V, HashResult>,
common_se: &Expr<F, V, HashResult>,
signal_factory: &mut SF,
decomp: &mut ConstrDecomp<F, V, HashResult>,
) -> Expr<F, V, HashResult> {
let common_expr_hash = common_se.meta().hash;

if expr.meta().degree < common_se.meta().degree {
// If the current expression's degree is less than the common subexpression's degree,
// it can't contain the common subexpression, so we return it as is
return expr.clone();
expr.clone()
}

// If the expression is the same as the common subexpression return the signal
if expr.meta().hash == common_expr_hash {
return common_se.clone();
else if expr.meta().hash == common_expr_hash {
common_se.clone()
} else {
// Recursively apply the function to the subexpressions
expr.apply_subexpressions(|se| replace_subexpr(se, common_se, signal_factory, decomp))
expr.apply_subexpressions(|se| replace_subexpr(se, common_se))
}
}

Expand Down Expand Up @@ -102,12 +96,7 @@ mod tests {
let (common_se, decomp) =
create_common_ses_signal(&common_expr.hash(&assignments), &mut signal_factory);

let (new_expr, _) = replace_expr(
&expr.hash(&assignments),
&common_se,
&mut signal_factory,
decomp.clone(),
);
let (new_expr, _) = replace_expr(&expr.hash(&assignments), &common_se);

assert!(decomp.auto_signals.len() == 1);
assert_eq!(format!("{:#?}", new_expr), "((-0x1) + cse-1 + (-c))");
Expand Down

0 comments on commit 1c235aa

Please sign in to comment.