diff --git a/src/poly/mod.rs b/src/poly/mod.rs index 3edab591..bc1f0d5a 100644 --- a/src/poly/mod.rs +++ b/src/poly/mod.rs @@ -233,35 +233,35 @@ pub trait SignalFactory { } #[derive(Debug, Clone)] -pub struct ExprDecomp { - root_expr: Expr, - exprs: Vec>, +pub struct ConstrDecomp { + root_constr: Expr, + constrs: Vec>, auto_signals: HashMap>, } -impl From> for ExprDecomp { +impl From> for ConstrDecomp { fn from(value: Expr) -> Self { - ExprDecomp { - root_expr: value, - exprs: Default::default(), + ConstrDecomp { + root_constr: value, + constrs: Default::default(), auto_signals: Default::default(), } } } -impl ExprDecomp { - fn merge(root_expr: Expr, reductions: Vec>) -> Self { - let mut result = ExprDecomp::from(root_expr); +impl ConstrDecomp { + fn merge(root_expr: Expr, reductions: Vec>) -> Self { + let mut result = ConstrDecomp::from(root_expr); result.expand(reductions); result } - fn expand(&mut self, reductions: Vec>) { - self.exprs.extend( + fn expand(&mut self, reductions: Vec>) { + self.constrs.extend( reductions .iter() - .map(|se| se.exprs.clone()) + .map(|se| se.constrs.clone()) .collect::>() .concat(), ); @@ -277,7 +277,7 @@ impl ExprDecomp { } fn auto_eq(&mut self, signal: V, expr: Expr) { - self.exprs.push(Expr::Sum(vec![ + self.constrs.push(Expr::Sum(vec![ expr.clone(), Expr::Neg(Box::new(Expr::Query(signal.clone()))), ])); @@ -285,10 +285,14 @@ impl ExprDecomp { self.auto_signals.insert(signal, expr); } - fn inherit(root_expr: Expr, signal: V, mut from: ExprDecomp) -> ExprDecomp { - from.auto_eq(signal, from.root_expr.clone()); + fn inherit( + root_expr: Expr, + signal: V, + mut from: ConstrDecomp, + ) -> ConstrDecomp { + from.auto_eq(signal, from.root_constr.clone()); - from.root_expr = root_expr; + from.root_constr = root_expr; from } diff --git a/src/poly/reduce.rs b/src/poly/reduce.rs index 7d82455f..4b66d7f1 100644 --- a/src/poly/reduce.rs +++ b/src/poly/reduce.rs @@ -5,14 +5,14 @@ use crate::{ poly::{simplify::simplify_mul, Expr}, }; -use super::{ExprDecomp, SignalFactory}; +use super::{ConstrDecomp, SignalFactory}; pub fn reduce_degre>( - expr: Expr, + constr: Expr, max_degree: usize, signal_factory: &mut SF, -) -> ExprDecomp { - reduce_degree_recursive(expr, max_degree, max_degree, signal_factory) +) -> ConstrDecomp { + reduce_degree_recursive(constr, max_degree, max_degree, signal_factory) } fn reduce_degree_recursive< @@ -20,17 +20,17 @@ fn reduce_degree_recursive< V: Clone + Eq + PartialEq + Hash + Debug, SF: SignalFactory, >( - expr: Expr, + constr: Expr, total_max_degree: usize, partial_max_degree: usize, signal_factory: &mut SF, -) -> ExprDecomp { - if expr.degree() <= partial_max_degree { - return ExprDecomp::from(expr); +) -> ConstrDecomp { + if constr.degree() <= partial_max_degree { + return ConstrDecomp::from(constr); } - match expr { - Expr::Const(_) => ExprDecomp::from(expr), + match constr { + Expr::Const(_) => ConstrDecomp::from(constr), Expr::Sum(ses) => { let ses_reduction: Vec<_> = ses .iter() @@ -44,9 +44,14 @@ fn reduce_degree_recursive< }) .collect(); - let root_expr = Expr::Sum(ses_reduction.iter().map(|r| r.root_expr.clone()).collect()); + let root_expr = Expr::Sum( + ses_reduction + .iter() + .map(|r| r.root_constr.clone()) + .collect(), + ); - ExprDecomp::merge(root_expr, ses_reduction) + ConstrDecomp::merge(root_expr, ses_reduction) } Expr::Mul(ses) => { reduce_degree_mul(ses, total_max_degree, partial_max_degree, signal_factory) @@ -54,17 +59,18 @@ fn reduce_degree_recursive< Expr::Neg(se) => { let mut reduction = reduce_degree_recursive(*se, total_max_degree, partial_max_degree, signal_factory); - reduction.root_expr = Expr::Neg(Box::new(reduction.root_expr)); + reduction.root_constr = Expr::Neg(Box::new(reduction.root_constr)); reduction } + // TODO: decompose in Pow expressions instead of Mul Expr::Pow(se, exp) => reduce_degree_mul( std::vec::from_elem(*se, exp as usize), total_max_degree, partial_max_degree, signal_factory, ), - Expr::Query(_) => ExprDecomp::from(expr), + Expr::Query(_) => ConstrDecomp::from(constr), Expr::Halo2Expr(_) => unimplemented!(), } } @@ -74,19 +80,21 @@ fn reduce_degree_mul ExprDecomp { +) -> ConstrDecomp { + // base case, if partial_max_degree == 1, the root expresion can only be a variable if partial_max_degree == 1 { let reduction = reduce_degree_mul(ses, total_max_degree, total_max_degree, signal_factory); let signal = signal_factory.create("virtual signal"); - return ExprDecomp::inherit(Expr::Query(signal.clone()), signal, reduction); + return ConstrDecomp::inherit(Expr::Query(signal.clone()), signal, reduction); } let ses = simplify_mul(ses); + // to reduce the problem for recursion, at least one expression should have lower degree than + // total_max_degree let mut first = true; - - let ses_reduced: Vec> = ses + let ses_reduced: Vec> = ses .into_iter() .map(|se| { let partial_max_degree = if first { @@ -102,13 +110,16 @@ fn reduce_degree_mul> = Default::default(); - let mut to_simplify: Vec> = Default::default(); + // for_root will be multipliers that will be included in the root expression + let mut for_root: Vec> = Default::default(); + // to_simplify will be multipliers that will be recursively decomposed and subsituted by a + // virtual signal in the root expression + let mut to_simplify: Vec> = Default::default(); let mut current_degree = 0; for se in ses_reduced { - if se.root_expr.degree() + current_degree < partial_max_degree { - current_degree += se.root_expr.degree(); + if se.root_constr.degree() + current_degree < partial_max_degree { + current_degree += se.root_constr.degree(); for_root.push(se); } else { to_simplify.push(se); @@ -119,18 +130,24 @@ fn reduce_degree_mul = for_root.iter().map(|se| se.root_expr.clone()).collect(); + let mut root_exprs: Vec<_> = for_root.iter().map(|se| se.root_constr.clone()).collect(); root_exprs.push(Expr::Query(rest_signal.clone())); let root_expr = Expr::Mul(root_exprs); + // recursion, for the part that exceeds the degree and will be substituted by a virtual signal let simplified = reduce_degree_recursive( - Expr::Mul(to_simplify.iter().map(|se| se.root_expr.clone()).collect()), + Expr::Mul( + to_simplify + .iter() + .map(|se| se.root_constr.clone()) + .collect(), + ), total_max_degree, total_max_degree, signal_factory, ); - let mut result = ExprDecomp::merge( + let mut result = ConstrDecomp::merge( root_expr, [for_root, to_simplify.clone(), vec![simplified.clone()]] .into_iter() @@ -138,7 +155,7 @@ fn reduce_degree_mul