Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
leolara committed Nov 17, 2023
1 parent 341f5b0 commit b3c0504
Showing 1 changed file with 108 additions and 7 deletions.
115 changes: 108 additions & 7 deletions src/poly/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,19 @@ use crate::{

use super::{ExprDecomp, SignalFactory};

pub fn reduce_degree<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: SignalFactory<V>>(
pub fn reduce_degre<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: SignalFactory<V>>(
expr: Expr<F, V>,
max_degree: usize,
signal_factory: &mut SF,
) -> ExprDecomp<F, V> {
reduce_degree_recursive(expr, max_degree, max_degree, signal_factory)
}

fn reduce_degree_recursive<
F: Field,
V: Clone + Eq + PartialEq + Hash + Debug,
SF: SignalFactory<V>,
>(
expr: Expr<F, V>,
total_max_degree: usize,
partial_max_degree: usize,
Expand All @@ -23,7 +35,7 @@ pub fn reduce_degree<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: Sig
let ses_reduction: Vec<_> = ses
.iter()
.map(|se| {
reduce_degree(
reduce_degree_recursive(
se.clone(),
total_max_degree,
partial_max_degree,
Expand All @@ -41,7 +53,7 @@ pub fn reduce_degree<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: Sig
}
Expr::Neg(se) => {
let mut reduction =
reduce_degree(*se, total_max_degree, partial_max_degree, signal_factory);
reduce_degree_recursive(*se, total_max_degree, partial_max_degree, signal_factory);
reduction.root_expr = Expr::Neg(Box::new(reduction.root_expr));

reduction
Expand Down Expand Up @@ -82,7 +94,8 @@ fn reduce_degree_mul<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: Sig
} else {
total_max_degree
};
let reduction = reduce_degree(se, total_max_degree, partial_max_degree, signal_factory);
let reduction =
reduce_degree_recursive(se, total_max_degree, partial_max_degree, signal_factory);
first = false;

reduction
Expand Down Expand Up @@ -110,7 +123,7 @@ fn reduce_degree_mul<F: Field, V: Clone + Eq + PartialEq + Hash + Debug, SF: Sig
root_exprs.push(Expr::Query(rest_signal.clone()));
let root_expr = Expr::Mul(root_exprs);

let simplified = reduce_degree(
let simplified = reduce_degree_recursive(
Expr::Mul(to_simplify.iter().map(|se| se.root_expr.clone()).collect()),
total_max_degree,
total_max_degree,
Expand All @@ -136,10 +149,10 @@ mod test {

use crate::{
ast::{query::Queriable, InternalSignal},
poly::ToExpr,
poly::{Expr::*, ToExpr},
};

use super::{reduce_degree_mul, SignalFactory};
use super::{reduce_degre, reduce_degree_mul, SignalFactory};

#[derive(Default)]
struct TestSignalFactory {
Expand Down Expand Up @@ -195,4 +208,92 @@ mod test {
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: ((b + c) * (a + c))"));
assert_eq!(result.auto_signals.len(), 1);
}

#[test]
fn test_reduce_degree() {
let a: Queriable<Fr> = Queriable::Internal(InternalSignal::new("a"));
let b: Queriable<Fr> = Queriable::Internal(InternalSignal::new("b"));
let c: Queriable<Fr> = Queriable::Internal(InternalSignal::new("c"));
let d: Queriable<Fr> = Queriable::Internal(InternalSignal::new("d"));
let e: Queriable<Fr> = Queriable::Internal(InternalSignal::new("e"));

let result = reduce_degre(a * b * c * d * e, 2, &mut TestSignalFactory::default());

assert_eq!(format!("{:#?}", result.root_expr), "(a * v1)");
assert_eq!(format!("{:#?}", result.exprs[0]), "((d * e) + (-v3))");
assert_eq!(format!("{:#?}", result.exprs[1]), "((c * v3) + (-v2))");
assert_eq!(format!("{:#?}", result.exprs[2]), "((b * v2) + (-v1))");
assert_eq!(result.exprs.len(), 3);
assert!(result
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v3)"));
assert!(result
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (b * v2)"));
assert!(result
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
assert_eq!(result.auto_signals.len(), 3);

let result = reduce_degre(
1.expr() - (a * b * c * d * e),
2,
&mut TestSignalFactory::default(),
);

assert_eq!(format!("{:#?}", result.root_expr), "(0x1 + (-(a * v1)))");
assert_eq!(format!("{:#?}", result.exprs[0]), "((d * e) + (-v3))");
assert_eq!(format!("{:#?}", result.exprs[1]), "((c * v3) + (-v2))");
assert_eq!(format!("{:#?}", result.exprs[2]), "((b * v2) + (-v1))");
assert_eq!(result.exprs.len(), 3);
assert!(result
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (c * v3)"));
assert!(result
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (b * v2)"));
assert!(result
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (d * e)"));
assert_eq!(result.auto_signals.len(), 3);

let result = reduce_degre(
Pow(Box::new(a.expr()), 4) - (b * c * d * e),
2,
&mut TestSignalFactory::default(),
);

assert_eq!(
format!("{:#?}", result.root_expr),
"((a * v1) + (-(b * v3)))"
);
assert_eq!(format!("{:#?}", result.exprs[0]), "((a * a) + (-v2))");
assert_eq!(format!("{:#?}", result.exprs[1]), "((a * v2) + (-v1))");
assert_eq!(format!("{:#?}", result.exprs[2]), "((d * e) + (-v4))");
assert_eq!(format!("{:#?}", result.exprs[3]), "((c * v4) + (-v3))");
assert_eq!(result.exprs.len(), 4);
assert!(result
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v2: (a * a)"));
assert!(result
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v1: (a * v2)"));
assert!(result
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v4: (d * e)"));
assert!(result
.auto_signals
.iter()
.any(|(s, expr)| format!("{:#?}: {:#?}", s, expr) == "v3: (c * v4)"));
assert_eq!(result.auto_signals.len(), 4);
}
}

0 comments on commit b3c0504

Please sign in to comment.