Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Some renaming and comments
Browse files Browse the repository at this point in the history
  • Loading branch information
leolara committed Nov 17, 2023
1 parent b3c0504 commit 803c8b5
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 65 deletions.
38 changes: 21 additions & 17 deletions src/poly/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,35 +233,35 @@ pub trait SignalFactory<V> {
}

#[derive(Debug, Clone)]
pub struct ExprDecomp<F, V> {
root_expr: Expr<F, V>,
exprs: Vec<Expr<F, V>>,
pub struct ConstrDecomp<F, V> {
root_constr: Expr<F, V>,
constrs: Vec<Expr<F, V>>,
auto_signals: HashMap<V, Expr<F, V>>,
}

impl<F, V> From<Expr<F, V>> for ExprDecomp<F, V> {
impl<F, V> From<Expr<F, V>> for ConstrDecomp<F, V> {
fn from(value: Expr<F, V>) -> Self {
ExprDecomp {
root_expr: value,
exprs: Default::default(),
ConstrDecomp {
root_constr: value,
constrs: Default::default(),
auto_signals: Default::default(),
}
}
}

impl<F: Clone, V: Clone + Eq + PartialEq + Hash> ExprDecomp<F, V> {
fn merge(root_expr: Expr<F, V>, reductions: Vec<ExprDecomp<F, V>>) -> Self {
let mut result = ExprDecomp::from(root_expr);
impl<F: Clone, V: Clone + Eq + PartialEq + Hash> ConstrDecomp<F, V> {
fn merge(root_expr: Expr<F, V>, reductions: Vec<ConstrDecomp<F, V>>) -> Self {
let mut result = ConstrDecomp::from(root_expr);
result.expand(reductions);

result
}

fn expand(&mut self, reductions: Vec<ExprDecomp<F, V>>) {
self.exprs.extend(
fn expand(&mut self, reductions: Vec<ConstrDecomp<F, V>>) {
self.constrs.extend(
reductions
.iter()
.map(|se| se.exprs.clone())
.map(|se| se.constrs.clone())
.collect::<Vec<_>>()
.concat(),
);
Expand All @@ -277,18 +277,22 @@ impl<F: Clone, V: Clone + Eq + PartialEq + Hash> ExprDecomp<F, V> {
}

fn auto_eq(&mut self, signal: V, expr: Expr<F, V>) {
self.exprs.push(Expr::Sum(vec![
self.constrs.push(Expr::Sum(vec![
expr.clone(),
Expr::Neg(Box::new(Expr::Query(signal.clone()))),
]));

self.auto_signals.insert(signal, expr);
}

fn inherit(root_expr: Expr<F, V>, signal: V, mut from: ExprDecomp<F, V>) -> ExprDecomp<F, V> {
from.auto_eq(signal, from.root_expr.clone());
fn inherit(
root_expr: Expr<F, V>,
signal: V,
mut from: ConstrDecomp<F, V>,
) -> ConstrDecomp<F, V> {
from.auto_eq(signal, from.root_constr.clone());

from.root_expr = root_expr;
from.root_constr = root_expr;

from
}
Expand Down
113 changes: 65 additions & 48 deletions src/poly/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,32 @@ use crate::{
poly::{simplify::simplify_mul, Expr},
};

use super::{ExprDecomp, SignalFactory};
use super::{ConstrDecomp, SignalFactory};

pub fn reduce_degre<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: SignalFactory<V>>(
expr: Expr<F, V>,
constr: Expr<F, V>,
max_degree: usize,
signal_factory: &mut SF,
) -> ExprDecomp<F, V> {
reduce_degree_recursive(expr, max_degree, max_degree, signal_factory)
) -> ConstrDecomp<F, V> {
reduce_degree_recursive(constr, max_degree, max_degree, signal_factory)
}

fn reduce_degree_recursive<
F: Field,
V: Clone + Eq + PartialEq + Hash + Debug,
SF: SignalFactory<V>,
>(
expr: Expr<F, V>,
constr: Expr<F, V>,
total_max_degree: usize,
partial_max_degree: usize,
signal_factory: &mut SF,
) -> ExprDecomp<F, V> {
if expr.degree() <= partial_max_degree {
return ExprDecomp::from(expr);
) -> ConstrDecomp<F, V> {
if constr.degree() <= partial_max_degree {
return ConstrDecomp::from(constr);
}

match expr {
Expr::Const(_) => ExprDecomp::from(expr),
match constr {
Expr::Const(_) => ConstrDecomp::from(constr),
Expr::Sum(ses) => {
let ses_reduction: Vec<_> = ses
.iter()
Expand All @@ -44,27 +44,33 @@ fn reduce_degree_recursive<
})
.collect();

let root_expr = Expr::Sum(ses_reduction.iter().map(|r| r.root_expr.clone()).collect());
let root_expr = Expr::Sum(
ses_reduction
.iter()
.map(|r| r.root_constr.clone())
.collect(),
);

ExprDecomp::merge(root_expr, ses_reduction)
ConstrDecomp::merge(root_expr, ses_reduction)
}
Expr::Mul(ses) => {
reduce_degree_mul(ses, total_max_degree, partial_max_degree, signal_factory)
}
Expr::Neg(se) => {
let mut reduction =
reduce_degree_recursive(*se, total_max_degree, partial_max_degree, signal_factory);
reduction.root_expr = Expr::Neg(Box::new(reduction.root_expr));
reduction.root_constr = Expr::Neg(Box::new(reduction.root_constr));

reduction
}
// TODO: decompose in Pow expressions instead of Mul
Expr::Pow(se, exp) => reduce_degree_mul(
std::vec::from_elem(*se, exp as usize),
total_max_degree,
partial_max_degree,
signal_factory,
),
Expr::Query(_) => ExprDecomp::from(expr),
Expr::Query(_) => ConstrDecomp::from(constr),
Expr::Halo2Expr(_) => unimplemented!(),
}
}
Expand All @@ -74,19 +80,21 @@ fn reduce_degree_mul<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: Sig
total_max_degree: usize,
partial_max_degree: usize,
signal_factory: &mut SF,
) -> ExprDecomp<F, V> {
) -> ConstrDecomp<F, V> {
// base case, if partial_max_degree == 1, the root expresion can only be a variable
if partial_max_degree == 1 {
let reduction = reduce_degree_mul(ses, total_max_degree, total_max_degree, signal_factory);
let signal = signal_factory.create("virtual signal");

return ExprDecomp::inherit(Expr::Query(signal.clone()), signal, reduction);
return ConstrDecomp::inherit(Expr::Query(signal.clone()), signal, reduction);
}

let ses = simplify_mul(ses);

// to reduce the problem for recursion, at least one expression should have lower degree than
// total_max_degree
let mut first = true;

let ses_reduced: Vec<ExprDecomp<F, V>> = ses
let ses_reduced: Vec<ConstrDecomp<F, V>> = ses
.into_iter()
.map(|se| {
let partial_max_degree = if first {
Expand All @@ -102,13 +110,16 @@ fn reduce_degree_mul<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: Sig
})
.collect();

let mut for_root: Vec<ExprDecomp<F, V>> = Default::default();
let mut to_simplify: Vec<ExprDecomp<F, V>> = Default::default();
// for_root will be multipliers that will be included in the root expression
let mut for_root: Vec<ConstrDecomp<F, V>> = Default::default();
// to_simplify will be multipliers that will be recursively decomposed and subsituted by a
// virtual signal in the root expression
let mut to_simplify: Vec<ConstrDecomp<F, V>> = Default::default();

let mut current_degree = 0;
for se in ses_reduced {
if se.root_expr.degree() + current_degree < partial_max_degree {
current_degree += se.root_expr.degree();
if se.root_constr.degree() + current_degree < partial_max_degree {
current_degree += se.root_constr.degree();
for_root.push(se);
} else {
to_simplify.push(se);
Expand All @@ -119,26 +130,32 @@ fn reduce_degree_mul<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: Sig
assert!(!to_simplify.is_empty());

let rest_signal = signal_factory.create("rest expr");
let mut root_exprs: Vec<_> = for_root.iter().map(|se| se.root_expr.clone()).collect();
let mut root_exprs: Vec<_> = for_root.iter().map(|se| se.root_constr.clone()).collect();
root_exprs.push(Expr::Query(rest_signal.clone()));
let root_expr = Expr::Mul(root_exprs);

// recursion, for the part that exceeds the degree and will be substituted by a virtual signal
let simplified = reduce_degree_recursive(
Expr::Mul(to_simplify.iter().map(|se| se.root_expr.clone()).collect()),
Expr::Mul(
to_simplify
.iter()
.map(|se| se.root_constr.clone())
.collect(),
),
total_max_degree,
total_max_degree,
signal_factory,
);

let mut result = ExprDecomp::merge(
let mut result = ConstrDecomp::merge(
root_expr,
[for_root, to_simplify.clone(), vec![simplified.clone()]]
.into_iter()
.flatten()
.collect(),
);

result.auto_eq(rest_signal, simplified.root_expr);
result.auto_eq(rest_signal, simplified.root_constr);

result
}
Expand Down Expand Up @@ -180,9 +197,9 @@ mod test {
&mut TestSignalFactory::default(),
);

assert_eq!(format!("{:#?}", result.root_expr), "(a * v1)");
assert_eq!(format!("{:#?}", result.exprs[0]), "((b * c) + (-v1))");
assert_eq!(result.exprs.len(), 1);
assert_eq!(format!("{:#?}", result.root_constr), "(a * v1)");
assert_eq!(format!("{:#?}", result.constrs[0]), "((b * c) + (-v1))");
assert_eq!(result.constrs.len(), 1);
assert!(result
.auto_signals
.iter()
Expand All @@ -196,12 +213,12 @@ mod test {
&mut TestSignalFactory::default(),
);

assert_eq!(format!("{:#?}", result.root_expr), "((a + b) * v1)");
assert_eq!(format!("{:#?}", result.root_constr), "((a + b) * v1)");
assert_eq!(
format!("{:#?}", result.exprs[0]),
format!("{:#?}", result.constrs[0]),
"(((b + c) * (a + c)) + (-v1))"
);
assert_eq!(result.exprs.len(), 1);
assert_eq!(result.constrs.len(), 1);
assert!(result
.auto_signals
.iter()
Expand All @@ -219,11 +236,11 @@ mod test {

let result = reduce_degre(a * b * c * d * e, 2, &mut TestSignalFactory::default());

assert_eq!(format!("{:#?}", result.root_expr), "(a * v1)");
assert_eq!(format!("{:#?}", result.exprs[0]), "((d * e) + (-v3))");
assert_eq!(format!("{:#?}", result.exprs[1]), "((c * v3) + (-v2))");
assert_eq!(format!("{:#?}", result.exprs[2]), "((b * v2) + (-v1))");
assert_eq!(result.exprs.len(), 3);
assert_eq!(format!("{:#?}", result.root_constr), "(a * v1)");
assert_eq!(format!("{:#?}", result.constrs[0]), "((d * e) + (-v3))");
assert_eq!(format!("{:#?}", result.constrs[1]), "((c * v3) + (-v2))");
assert_eq!(format!("{:#?}", result.constrs[2]), "((b * v2) + (-v1))");
assert_eq!(result.constrs.len(), 3);
assert!(result
.auto_signals
.iter()
Expand All @@ -244,11 +261,11 @@ mod test {
&mut TestSignalFactory::default(),
);

assert_eq!(format!("{:#?}", result.root_expr), "(0x1 + (-(a * v1)))");
assert_eq!(format!("{:#?}", result.exprs[0]), "((d * e) + (-v3))");
assert_eq!(format!("{:#?}", result.exprs[1]), "((c * v3) + (-v2))");
assert_eq!(format!("{:#?}", result.exprs[2]), "((b * v2) + (-v1))");
assert_eq!(result.exprs.len(), 3);
assert_eq!(format!("{:#?}", result.root_constr), "(0x1 + (-(a * v1)))");
assert_eq!(format!("{:#?}", result.constrs[0]), "((d * e) + (-v3))");
assert_eq!(format!("{:#?}", result.constrs[1]), "((c * v3) + (-v2))");
assert_eq!(format!("{:#?}", result.constrs[2]), "((b * v2) + (-v1))");
assert_eq!(result.constrs.len(), 3);
assert!(result
.auto_signals
.iter()
Expand All @@ -270,14 +287,14 @@ mod test {
);

assert_eq!(
format!("{:#?}", result.root_expr),
format!("{:#?}", result.root_constr),
"((a * v1) + (-(b * v3)))"
);
assert_eq!(format!("{:#?}", result.exprs[0]), "((a * a) + (-v2))");
assert_eq!(format!("{:#?}", result.exprs[1]), "((a * v2) + (-v1))");
assert_eq!(format!("{:#?}", result.exprs[2]), "((d * e) + (-v4))");
assert_eq!(format!("{:#?}", result.exprs[3]), "((c * v4) + (-v3))");
assert_eq!(result.exprs.len(), 4);
assert_eq!(format!("{:#?}", result.constrs[0]), "((a * a) + (-v2))");
assert_eq!(format!("{:#?}", result.constrs[1]), "((a * v2) + (-v1))");
assert_eq!(format!("{:#?}", result.constrs[2]), "((d * e) + (-v4))");
assert_eq!(format!("{:#?}", result.constrs[3]), "((c * v4) + (-v3))");
assert_eq!(result.constrs.len(), 4);
assert!(result
.auto_signals
.iter()
Expand Down

0 comments on commit 803c8b5

Please sign in to comment.