diff --git a/15_kzg/Cargo.toml b/15_kzg/Cargo.toml index 0d218c3..7773788 100644 --- a/15_kzg/Cargo.toml +++ b/15_kzg/Cargo.toml @@ -13,5 +13,4 @@ bls12_381 = "0.8.0" rand = "0.8.5" rand_core = { version = "0.6.4", default-features = false, features = ["std"] } rayon = "1.7.0" -log = "0.4.19" -num-bigint = "0.4.3" +sha3 = "0.10.6" diff --git a/15_kzg/src/kzg.rs b/15_kzg/src/kzg.rs new file mode 100644 index 0000000..6504e6e --- /dev/null +++ b/15_kzg/src/kzg.rs @@ -0,0 +1,47 @@ +mod param; +mod prover; +mod verifier; + +use bls12_381::Scalar; +use ff::Field; +use pairing::Engine; + +pub struct KZGProof { + cm: E::G1, // commit of p(x) + eval: E::Fr, // eval for p(z) + pi: E::G1, // aka.π, commit of q(x), q = p(x)-p(z)/x-z +} + +impl KZGProof { + fn new(cm: E::G1, eval: E::Fr, pi: E::G1) -> Self { + Self { cm, eval, pi } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::kzg::param::ParamKzg; + use crate::kzg::prover::Prover; + use crate::kzg::verifier::Verifier; + use crate::poly::Polynomial; + use bls12_381::Bls12; + use ff::PrimeField; + + #[test] + fn test_kzg_protocol() { + let k = 4; + let poly = Polynomial::random(3); + + // setup + let param = ParamKzg::::setup(k); + + // prove + let prover = Prover::init(param.clone()); + let proof = prover.prover(&poly); + + // verify + let verifier = Verifier::init(param); + verifier.verify(proof); + } +} diff --git a/15_kzg/src/kzg/param.rs b/15_kzg/src/kzg/param.rs new file mode 100644 index 0000000..71f98ca --- /dev/null +++ b/15_kzg/src/kzg/param.rs @@ -0,0 +1,80 @@ +use crate::msm::small_multiexp; +use crate::poly::Polynomial; +use ff::{Field, PrimeField}; +use group::prime::PrimeCurveAffine; +use pairing::Engine; +use rand_core::OsRng; +use std::fmt::Debug; + +// The SRS +#[derive(Clone)] +pub struct ParamKzg { + pub(crate) k: usize, + pub(crate) n: usize, + pub pow_tau_g1: Vec, + pub pow_tau_g2: Vec, +} + +impl ParamKzg +where + E::Fr: PrimeField, +{ + fn new(k: usize) -> Self { + Self::setup(k) + } + + // Generate the SRS + pub fn setup(k: usize) -> Self { + let n = 1 << k; + + let tau = E::Fr::random(OsRng); + + // obtain: s, ..., s^i,..., s^n + let powers_of_tau: Vec = (0..n) + .into_iter() + .scan(E::Fr::ONE, |acc, _| { + let v = *acc; + *acc *= tau; + Some(v) + }) + .collect(); + + // obtain [s]1 + let pow_tau_g1: Vec = powers_of_tau + .iter() + .map(|tau_pow| E::G1Affine::generator() * tau_pow) + .collect(); + + // obtain [s]2 + let pow_tau_g2: Vec = powers_of_tau + .iter() + .map(|tau_pow| E::G2Affine::generator() * tau_pow) + .collect(); + + Self { + k, + n, + pow_tau_g1, + pow_tau_g2, + } + } + + // unify ti with commit_lagrange + pub fn eval_at_tau_g1(&self, poly: &Polynomial) -> E::G1 { + let mut scalars = Vec::with_capacity(poly.len()); + scalars.extend(poly.coeffs().iter()); + let bases = &self.pow_tau_g1; + let size = scalars.len(); + assert!(bases.len() >= size); + small_multiexp(&scalars, &bases[0..size]) + } + + pub fn eval_at_tau_g2(&self, poly: &Polynomial) -> E::G2 { + let mut scalars = Vec::with_capacity(poly.len()); + scalars.extend(poly.coeffs().iter()); + let bases = &self.pow_tau_g2; + let size = scalars.len(); + assert!(bases.len() >= size); + small_multiexp(&scalars, &bases[0..size]) + } +} diff --git a/15_kzg/src/kzg/prover.rs b/15_kzg/src/kzg/prover.rs new file mode 100644 index 0000000..ccbb9f8 --- /dev/null +++ b/15_kzg/src/kzg/prover.rs @@ -0,0 +1,100 @@ +use crate::kzg::param::ParamKzg; +use crate::kzg::KZGProof; +use crate::poly::Polynomial; +use crate::transcript::default::Keccak256Transcript; +use crate::transcript::Transcript; +use ff::{BitViewSized, Field, PrimeField}; +use pairing::Engine; +use std::ops::{MulAssign, SubAssign}; + +pub struct Prover { + param: ParamKzg, +} + +impl Prover { + pub fn init(param: ParamKzg) -> Self { + Self { param } + } + + pub fn prover(&self, poly: &Polynomial) -> KZGProof { + // 1. commit + let cm = self.commit(poly); + + // 2. challenge z. + let mut transcript_1 = Keccak256Transcript::::default(); + let z = transcript_1.challenge(); + // 3. eval z. + let eval = poly.evaluate(z.clone()); + + // 4. open + let pi = self.open(poly, &z); + + KZGProof::new(cm, eval, pi) + } + + // return the commit of p + fn commit(&self, poly: &Polynomial) -> E::G1 { + self.param.eval_at_tau_g1(poly) + } + + // return the commit of q, aka.pi, the proof. + fn open(&self, poly: &Polynomial, z: &E::Fr) -> E::G1 { + // q = ( p(x) - p(z) ) / x-z + let q_coeff = Self::kate_division(&poly.coeffs(), z.clone()); + let q = Polynomial::from_coeffs(q_coeff); + // the proof is evaluating the Q at tau in G1 + self.commit(&q) + } + + // Divides polynomial `a` in `X` by `X - b` with no remainder. + // q(x) = f(x)-f(z)/x-z + fn kate_division(a: &Vec, z: E::Fr) -> Vec { + let b = -z; + let a = a.into_iter(); + + let mut q = vec![E::Fr::ZERO; a.len() - 1]; + + let mut tmp: E::Fr = E::Fr::ZERO; + for (q, r) in q.iter_mut().rev().zip(a.rev()) { + let mut lead_coeff = *r; + lead_coeff.sub_assign(&tmp); + *q = lead_coeff; + tmp = lead_coeff; + tmp.mul_assign(&b); + } + q + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::kzg::param::ParamKzg; + use crate::kzg::prover::Prover; + use crate::kzg::verifier::Verifier; + use crate::poly::Polynomial; + use bls12_381::{Bls12, Scalar}; + use ff::{Field, PrimeField}; + + #[test] + fn test_div() { + // division: -2+x+x^2 = (x-1)(x+2) + let division = vec![Scalar::from_u128(2).neg(), Scalar::ONE, Scalar::ONE]; + + // dividor: 2+x + // dividor: -1+x + let coeffs = vec![Scalar::ONE.neg(), Scalar::ONE]; + let dividor = Polynomial::from_coeffs(coeffs); + + // target: + // quotient poly: 2+x + // remainder poly: 0 + let target_qoutient = vec![Scalar::from_u128(2), Scalar::ONE]; + + // q(x) = f(x)-f(z)/x-z + let z = Scalar::ONE; + let actual_qoutient = Prover::::kate_division(&division, z); + + assert_eq!(actual_qoutient, target_qoutient); + } +} diff --git a/15_kzg/src/kzg/verifier.rs b/15_kzg/src/kzg/verifier.rs new file mode 100644 index 0000000..eda1ac9 --- /dev/null +++ b/15_kzg/src/kzg/verifier.rs @@ -0,0 +1,58 @@ +// Verifies that `points` exists in `proof` + +use crate::kzg::param::ParamKzg; +use crate::kzg::KZGProof; +use crate::poly::Polynomial; +use crate::transcript::default::Keccak256Transcript; +use crate::transcript::Transcript; +use bls12_381::Scalar; +use ff::{Field, PrimeField}; +use group::prime::PrimeCurveAffine; +use group::Curve; +use pairing::Engine; +use std::fmt::Debug; +use std::ops::Neg; + +pub struct Verifier { + param: ParamKzg, +} + +impl Verifier { + pub fn init(param: ParamKzg) -> Self { + Self { param } + } + + // verify proof by pairing: + // check e(π, [x−z]_2 ) = e(cm−[p(z)]_1, g2) + // => e(g1, g2)^{q(x)*(x-z)} = e(g1, g2)^{p(x)-p(z)} + // => q(x)*(x-z) = p(x)-p(z) + // -> same as Prove::open. + pub fn verify(&self, proof: KZGProof) { + let vanish_poly = |z: E::Fr| { + let coeffs = vec![z.neg(), E::Fr::ONE]; + Polynomial::from_coeffs(coeffs) + }; + + // 1. challenge z. + let mut transcript_1 = Keccak256Transcript::::default(); + let z = transcript_1.challenge(); + + // 2. prepare poly for pairing. + // compute: x-z + let vanish_poly = vanish_poly(z); + let eval_poly = Polynomial::from_coeffs(vec![proof.eval]); + + // 3.pairing + // e(pi, [x-z]2) + let e1 = E::pairing( + &proof.pi.to_affine(), + &self.param.eval_at_tau_g2(&vanish_poly).to_affine(), + ); + // e(cm-[p(z)]1, g2) + let e2 = E::pairing( + &(proof.cm - self.param.eval_at_tau_g1(&eval_poly)).to_affine(), + &E::G2Affine::generator(), + ); + assert_eq!(e1, e2, "Verify: failed for pairing."); + } +} diff --git a/15_kzg/src/lib.rs b/15_kzg/src/lib.rs index d7e300e..0b5e1d6 100644 --- a/15_kzg/src/lib.rs +++ b/15_kzg/src/lib.rs @@ -1,163 +1,4 @@ -//! this module contains an implementation of Kate-Zaverucha-Goldberg polynomial commitments - +mod kzg; +mod msm; mod poly; - -use crate::poly::*; -use bls12_381::Scalar as Fr; -use ff::Field; -use group::prime::PrimeCurveAffine; -use group::{Curve, Group}; -use pairing::Engine; -use rand::{Rng, RngCore}; -use std::fmt::Debug; -use std::iter; -use std::ops::Neg; - -/// KZG polinomial commitments on Bls12-381. This structure contains the trusted setup. -pub struct Kzg { - pub pow_tau_g1: Vec, - pub pow_tau_g2: Vec, -} - -impl Kzg { - fn eval_at_tau_g1(&self, poly: &Poly) -> E::G1 { - poly.0 - .iter() - .enumerate() - .fold(E::G1::identity(), |acc, (n, k)| { - acc + self.pow_tau_g1[n] * k - }) - } - - fn eval_at_tau_g2(&self, poly: &Poly) -> E::G2 { - poly.0 - .iter() - .enumerate() - .fold(E::G2::identity(), |acc, (n, k)| { - acc + self.pow_tau_g2[n] * k - }) - } - - fn z_poly_of(points: &[(E::Fr, E::Fr)]) -> Poly { - points.iter().fold(Poly::one(), |acc, (z, _y)| { - &acc * &Poly::new(vec![z.neg(), E::Fr::ONE]) - }) - } - - /// Generate the trusted setup. Is expected that this function is called - /// in a safe environment what will be destroyed after its execution - /// The `n` parameter is the maximum number of points that can be proved - pub fn trusted_setup(n: usize, rng: R) -> Self { - let tau = E::Fr::random(rng); - - let powers_of_tau: Vec = (0..n) - .into_iter() - .scan(E::Fr::ONE, |acc, _| { - let v = *acc; - *acc *= tau; - Some(v) - }) - .collect(); - - let pow_tau_g1: Vec = powers_of_tau - .iter() - .map(|tau_pow| E::G1Affine::generator() * tau_pow) - .collect(); - - let pow_tau_g2: Vec = powers_of_tau - .iter() - .map(|tau_pow| E::G2Affine::generator() * tau_pow) - .collect(); - - Self { - pow_tau_g1, - pow_tau_g2, - } - } - - /// Returns the maximum degree of the polinomial commitment - pub fn max_degree(&self) -> usize { - self.pow_tau_g1.len() - 1 - } - - /// Generate a polinomial and its commitment from a `set` of points - #[allow(non_snake_case)] - pub fn poly_commitment_from_set(&self, set: &[(E::Fr, E::Fr)]) -> (Poly, E::G1) { - let poly = Poly::lagrange(set); - let commitment = self.eval_at_tau_g1(&poly); - - (poly, commitment) - } - - /// Generates a proof that `points` exists in `set` - #[allow(non_snake_case)] - pub fn prove(&self, poly: &Poly, points: &[(E::Fr, E::Fr)]) -> E::G1 { - // compute a lagrange poliomial I that have all the points to proof that are in the set - // compute the polinomial Z that has roots (y=0) in all x's of I, - // so this is I=(x-x0)(x-x1)...(x-xn) - let I = Poly::lagrange(points); - let Z = Self::z_poly_of(points); - - // now compute that Q = ( P - I(x) ) / Z(x) - // also check that the division does not have remainder - let mut poly = poly.clone(); - poly -= &I; - let (Q, remainder) = poly / Z; - assert!(remainder.is_zero()); - - // the proof is evaluating the Q at tau in G1 - self.eval_at_tau_g1(&Q) - } - - /// Verifies that `points` exists in `proof` - #[allow(non_snake_case)] - pub fn verify(&self, commitment: &E::G1, points: &[(E::Fr, E::Fr)], proof: &E::G1) -> bool { - let I = Poly::lagrange(points); - let Z = Self::z_poly_of(points); - - let e1 = E::pairing(&proof.to_affine(), &self.eval_at_tau_g2(&Z).to_affine()); - - let e2 = E::pairing( - &(*commitment - self.eval_at_tau_g1(&I)).to_affine(), - &E::G2Affine::generator(), - ); - e1 == e2 - } -} - -#[cfg(test)] -mod test { - use super::*; - use bls12_381::*; - use rand::rngs::OsRng; - - #[test] - fn test_kzg() { - // Create a trustd setup that allows maximum 4 points (degree+1) - let kzg = Kzg::::trusted_setup(5, OsRng); - - // define the set of points (the "population"), and create a polinomial - // for them, as well its polinomial commitment, see the polinomial commitment - // like the "hash" of the polinomial - let set = vec![ - (Fr::from(1), Fr::from(2)), - (Fr::from(2), Fr::from(3)), - (Fr::from(3), Fr::from(4)), - (Fr::from(4), Fr::from(57)), - ]; - let (p, c) = kzg.poly_commitment_from_set(&set); - - // generate a proof that (1,2) and (2,3) are in the set - let proof01 = kzg.prove(&p, &vec![set[0].clone(), set[1].clone()]); - - // prove that (1,2) and (2,3) are in the set - assert!(kzg.verify(&c, &vec![set[0].clone(), set[1].clone()], &proof01)); - // other proofs will fail since the proof only works for exactly (1,2) AND (2,3) - assert!(!kzg.verify(&c, &vec![set[0].clone()], &proof01)); - assert!(!kzg.verify(&c, &vec![set[0].clone(), set[2].clone()], &proof01)); - - // prove and verify that the whole set exists in the whole set - let proof0123 = kzg.prove(&p, &set); - assert!(kzg.verify(&c, &set, &proof0123)); - } -} +mod transcript; diff --git a/15_kzg/src/msm.rs b/15_kzg/src/msm.rs new file mode 100644 index 0000000..5a996d9 --- /dev/null +++ b/15_kzg/src/msm.rs @@ -0,0 +1,29 @@ +// porting from halo2 + +use ff::PrimeField; +use group::prime::PrimeCurve; +use group::Group; + +/// Performs a small multi-exponentiation operation. +/// Uses the double-and-add algorithm with doublings shared across points. +pub fn small_multiexp(coeffs: &[C::Scalar], bases: &[C]) -> C { + let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect(); + let mut acc = C::identity(); + + // for byte idx + for byte_idx in (0..32).rev() { + // for bit idx + for bit_idx in (0..8).rev() { + acc = acc.double(); + // for each coeff + for coeff_idx in 0..coeffs.len() { + let byte = coeffs[coeff_idx].as_ref()[byte_idx]; + if ((byte >> bit_idx) & 1) != 0 { + acc += bases[coeff_idx]; + } + } + } + } + + acc +} diff --git a/15_kzg/src/poly.rs b/15_kzg/src/poly.rs index ac680bf..02204d4 100644 --- a/15_kzg/src/poly.rs +++ b/15_kzg/src/poly.rs @@ -1,364 +1,338 @@ -use bls12_381::Scalar as Fr; -use ff::{Field, PrimeField}; - -/// A polynomial field scalar -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct Poly(pub Vec); - -// NOTE! only used for Bn256::Fr. -impl Poly { - /// Creates a Poly from u64 coeffs - pub fn from(coeffs: &[u64]) -> Self { - Poly::new(coeffs.iter().map(|n| Fr::from(*n)).collect::>()) - } +use bls12_381::Scalar; +use ff::{BatchInvert, Field}; +use rand_core::OsRng; +use rayon::{current_num_threads, scope}; + +// p(x) = = a_0 + a_1 * X + ... + a_n * X^(n-1) +// +// coeffs: [a_0, a_1, ..., a_n] +// basis: X^[n-1] +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct Polynomial { + pub(crate) coeffs: Vec, } -impl Poly { - /// Creates a Poly from its `coefficients` with order of poly `coeff[0]*x^0, ..., coeff[m-1]*x^m` - /// for safety, input value is normalized (trailing zeroes are removed) - pub fn new(coeffs: Vec) -> Self { - let mut poly = Poly(coeffs); - poly.normalize(); - poly - } - - /// Returns p(x)=0 - pub fn zero() -> Self { - Poly(vec![F::ZERO]) - } - - /// Returns p(x)=1 - pub fn one() -> Self { - Poly(vec![F::ONE]) - } - - /// Creates a poly that satisfy a set of `p` points, by using [lagrange](https://alinush.github.io/2022/07/28/lagrange-interpolation.html) - /// - /// ϕ(X)=∑yi⋅Li(X), where Li(X)=∏(X−xj)/(xi−xj) - /// - /// ## Examples - /// ```text - /// // f(x)=x is a polinomial that fits in (1,1), (2,2) points - /// assert_eq!( - /// Poly::lagrange(&vec![ - /// (Scalar::from(1), Scalar::from(1)), - /// (Scalar::from(2), Scalar::from(2)) - /// ]), - /// Poly::from(&[0, 1]) // f(x) = x - /// ); - /// ``` - pub fn lagrange(p: &[(F, F)]) -> Self { - let k = p.len(); - let mut l = Poly::zero(); - for j in 0..k { - let mut l_j = Poly::one(); - for i in 0..k { - if i != j { - let c = (p[j].0 - p[i].0).invert().unwrap(); - l_j = &l_j * &Poly::new(vec![-(c * p[i].0), c]); - } - } - l += &(&l_j * &p[j].1); - } - l +impl Polynomial { + pub fn random(k: usize) -> Self { + let n = 1 << k; + let coeffs = (0..n).map(|_| F::random(OsRng)).collect::>(); + Self::from_coeffs(coeffs) } - /// Evals the polynomial at the A point - /// # Examples - /// ```text - /// // check that (x^2+2x+1)(2) = 9 - /// let poly = Poly::from(&[1, 2, 1]) - /// assert_eq!( - /// poly.eval(&Scalar::from(2)), - /// Scalar::from(9) - /// ); - /// ``` - pub fn eval(&self, x: &F) -> F { - let mut x_pow = F::ONE; - let mut y = self.0[0]; - for (i, _) in self.0.iter().enumerate().skip(1) { - x_pow *= x; - y += &(x_pow * self.0[i]); - } - y + pub fn from_coeffs(coeffs: Vec) -> Self { + Self { coeffs } } - /// Evals the polynomial supplying the `x_pows` [x^0, x^1, x^2, ..., x^m] - pub fn eval_with_pows(&self, x_pow: &[F]) -> F { - let mut y = self.0[0]; - for (i, _) in self.0.iter().enumerate() { - y += &(x_pow[i] * self.0[i]); + // used by div. + pub fn zero() -> Self { + Self { + coeffs: vec![F::ZERO], } - y } - /// Returns the degree of the polynomial - /// - /// poly.size = poly.degree + 1 + // The degree of the polynomial pub fn degree(&self) -> usize { - self.0.len() - 1 + assert!(self.coeffs.len() > 0); + self.coeffs.len() - 1 } - - /// Returns the coeffs size of the polynomial - pub fn size(&self) -> usize { - self.0.len() + // The len of the polynomial coeffs + pub fn len(&self) -> usize { + self.coeffs.len() } - /// Normalizes the coefficients, removing ending zeroes - /// # Examples - /// ```text - /// use a0kzg::Poly; - /// let mut p1 = Poly::from(&[1, 0, 0, 0]); - /// p1.normalize(); - /// assert_eq!(p1, Poly::from(&[1])); - /// ``` - pub fn normalize(&mut self) { - if self.0.len() > 1 && self.0[self.0.len() - 1] == F::ZERO { - let zero = F::ZERO; - let first_non_zero = self.0.iter().rev().position(|p| p != &zero); - if let Some(first_non_zero) = first_non_zero { - self.0.resize(self.0.len() - first_non_zero, F::ZERO); - } else { - self.0.resize(1, F::ZERO); - } - } + pub fn coeffs(&self) -> Vec { + self.coeffs.clone() } - /// Returns if p(x)=0 - pub fn is_zero(&self) -> bool { - *self == Self::zero() - } + // p(x)=∑y_j⋅L_j(X), where + // y_j: [a_0, a_1, ..., a_n]. + // basis: L_j(X)=∏(X−x_k)/(x_j−x_k) + // + // domain: x, most case is that{0, 1, . . . , n − 1} + // evals: [a_0, a_1, ..., a_n] + // + // we can use encode points as (domain, eval) to polynomials + // the poly + pub fn lagrange_interpolate(domains: Vec, evals: Vec) -> Self { + assert_eq!(domains.len(), evals.len()); + + if evals.len() == 1 { + // Constant polynomial + Self { + coeffs: vec![evals[0]], + } + } else { + let poly_size = domains.len(); + let lag_basis_poly_size = poly_size - 1; + + // 1. divisors = vec(x_j - x_k). prepare for L_j(X)=∏(X−x_k)/(x_j−x_k) + let mut divisors = Vec::with_capacity(poly_size); + for (j, x_j) in domains.iter().enumerate() { + // divisor_j + let mut divisor = Vec::with_capacity(lag_basis_poly_size); + // obtain domain for x_k + for x_k in domains + .iter() + .enumerate() + .filter(|&(k, _)| k != j) + .map(|(_, x)| x) + { + divisor.push(*x_j - x_k); + } + divisors.push(divisor); + } + // Inverse (x_j - x_k)^(-1) for each j != k to compute L_j(X)=∏(X−x_k)/(x_j−x_k) + divisors + .iter_mut() + .flat_map(|v| v.iter_mut()) + .batch_invert(); + + // 2. Calculate L_j(X) : L_j(X)=∏(X−x_k) divisors_j + let mut L_j_vec: Vec> = Vec::with_capacity(poly_size); + + for (j, divisor_j) in divisors.into_iter().enumerate() { + let mut L_j: Vec = Vec::with_capacity(poly_size); + L_j.push(F::ONE); + + // (X−x_k) * divisors_j + let mut product = Vec::with_capacity(lag_basis_poly_size); + + // obtain domain for x_k + for (x_k, divisor) in domains + .iter() + .enumerate() + .filter(|&(k, _)| k != j) + .map(|(_, x)| x) + .zip(divisor_j.into_iter()) + { + product.resize(L_j.len() + 1, F::ZERO); + + // loop (poly_size + 1) round + // calculate L_j(X)=∏(X−x_k) divisors_j with coefficient form. + for ((a, b), product) in L_j + .iter() + .chain(std::iter::once(&F::ZERO)) + .zip(std::iter::once(&F::ZERO).chain(L_j.iter())) + .zip(product.iter_mut()) + { + *product = *a * (-divisor * x_k) + *b * divisor; + } + std::mem::swap(&mut L_j, &mut product); + } - /// Set the `i`-th coefficient with new value - /// # Examples - /// ```text - /// let mut poly = Poly::zero(); - /// poly.set(2, Scalar::from(7)); - /// assert_eq!(poly, Poly::from(&[0, 0, 7])); - /// ``` - pub fn set(&mut self, index: usize, p: F) { - let target_size = index + 1; - if self.size() < target_size { - self.0.resize(target_size, F::ZERO); - } - self.0[index] = p; - self.normalize(); - } + assert_eq!(L_j.len(), poly_size); + assert_eq!(product.len(), poly_size - 1); - /// Returns the `i`-th coefficient - /// # Examples - /// ```text - /// let mut poly = Poly::zero(); - /// poly.set(2, Scalar::from(7)); - /// assert_eq!(poly.get(2), Some(&Scalar::from(7))); - /// assert_eq!(poly.get(3), None); - /// ``` - pub fn get(&mut self, index: usize) -> Option<&F> { - self.0.get(index) - } -} + L_j_vec.push(L_j); + } -impl std::ops::AddAssign<&Poly> for Poly { - fn add_assign(&mut self, rhs: &Poly) { - for n in 0..std::cmp::max(self.0.len(), rhs.0.len()) { - if n >= self.0.len() { - self.0.push(rhs.0[n]); - } else if n < self.0.len() && n < rhs.0.len() { - self.0[n] += rhs.0[n]; + // p(x)=∑y_j⋅L_j(X) in coefficients + let mut final_poly = vec![F::ZERO; poly_size]; + // 3. p(x)=∑y_j⋅L_j(X) + for (L_j, y_j) in L_j_vec.iter().zip(evals) { + for (final_coeff, L_j_coeff) in final_poly.iter_mut().zip(L_j.into_iter()) { + *final_coeff += L_j_coeff.mul(y_j); + } } + Self { coeffs: final_poly } } - self.normalize(); } -} -impl std::ops::AddAssign<&F> for Poly { - fn add_assign(&mut self, rhs: &F) { - self.0[0] += rhs; - } -} + // This evaluates a polynomial (in coefficient form) at `x`. + pub fn evaluate(&self, x: F) -> F { + let coeffs = self.coeffs.clone(); + let poly_size = self.coeffs.len(); -impl std::ops::SubAssign<&Poly> for Poly { - fn sub_assign(&mut self, rhs: &Poly) { - for n in 0..std::cmp::max(self.0.len(), rhs.0.len()) { - if n >= self.0.len() { - self.0.push(rhs.0[n]); - } else if n < self.0.len() && n < rhs.0.len() { - self.0[n] -= rhs.0[n]; - } + // p(x) = = a_0 + a_1 * X + ... + a_n * X^(n-1), revert it and fold sum it + fn eval(poly: &[F], point: F) -> F { + poly.iter() + .rev() + .fold(F::ZERO, |acc, coeff| acc * point + coeff) + } + + let num_threads = current_num_threads(); + if poly_size * 2 < num_threads { + eval(&coeffs, x) + } else { + let chunk_size = (poly_size + num_threads - 1) / num_threads; + let mut parts = vec![F::ZERO; num_threads]; + scope(|scope| { + for (chunk_idx, (out, c)) in parts + .chunks_mut(1) + .zip(coeffs.chunks(chunk_size)) + .enumerate() + { + scope.spawn(move |_| { + let start = chunk_idx * chunk_size; + out[0] = eval(c, x) * x.pow_vartime(&[start as u64, 0, 0, 0]); + }); + } + }); + parts.iter().fold(F::ZERO, |acc, coeff| acc + coeff) } - self.normalize(); } } -impl std::ops::Mul<&Poly> for &Poly { - type Output = Poly; - fn mul(self, rhs: &Poly) -> Self::Output { - let mut mul: Vec = std::iter::repeat(F::ZERO) - .take(self.0.len() + rhs.0.len() - 1) - .collect(); - for n in 0..self.0.len() { - for m in 0..rhs.0.len() { - mul[n + m] += self.0[n] * rhs.0[m]; +impl std::ops::Mul<&Polynomial> for &Polynomial { + type Output = Polynomial; + fn mul(self, rhs: &Polynomial) -> Self::Output { + let mut coeffs: Vec = vec![F::ZERO; self.coeffs.len() + rhs.coeffs.len() - 1]; + for n in 0..self.coeffs.len() { + for m in 0..rhs.coeffs.len() { + coeffs[n + m] += self.coeffs[n] * rhs.coeffs[m]; } } - Poly(mul) + Self::Output { coeffs } } } -impl std::ops::Mul<&F> for &Poly { - type Output = Poly; +impl std::ops::Mul<&F> for &Polynomial { + type Output = Polynomial; fn mul(self, rhs: &F) -> Self::Output { - if rhs == &F::ZERO { - Poly::zero() + let coeffs = if rhs == &F::ZERO { + vec![F::ZERO] } else { - Poly(self.0.iter().map(|v| *v * *rhs).collect::>()) - } + self.coeffs.iter().map(|c| c.mul(rhs)).collect::>() + }; + Self::Output { coeffs } } } -impl std::ops::Div for Poly { - type Output = (Poly, Poly); - - fn div(self, rhs: Poly) -> Self::Output { - let (mut q, mut r) = (Poly::zero(), self); - while !r.is_zero() && r.degree() >= rhs.degree() { - let lead_r = r.0[r.0.len() - 1]; - let lead_d = rhs.0[rhs.0.len() - 1]; - let mut t = Poly::zero(); - t.set(r.0.len() - rhs.0.len(), lead_r * lead_d.invert().unwrap()); - q += &t; - r -= &(&rhs * &t); - } - (q, r) - } -} -impl std::fmt::Display for Poly { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut first: bool = true; - for i in (0..=self.degree()).rev() { - let bi_n = - num_bigint::BigUint::from_bytes_le(&self.0[i].to_repr().as_ref()).to_str_radix(10); - let bi_inv = num_bigint::BigUint::from_bytes_le(self.0[i].neg().to_repr().as_ref()) - .to_str_radix(10); - - if bi_n == "0" { - continue; - } - - if bi_inv.len() < 20 && bi_n.len() > 20 { - if bi_inv == "1" && i != 0 { - write!(f, "-")?; +impl std::ops::Add<&Polynomial> for &Polynomial { + type Output = Polynomial; + + fn add(self, rhs: &Polynomial) -> Self::Output { + let max_len = std::cmp::max(self.coeffs.len(), rhs.coeffs.len()); + let coeffs = (0..max_len) + .into_iter() + .map(|n| { + if n >= self.coeffs.len() { + rhs.coeffs[n] + } else if n >= rhs.coeffs.len() { + self.coeffs[n] } else { - write!(f, "-{}", bi_inv)?; - } - } else { - if !first { - write!(f, "+")?; - } - if i == 0 || bi_n != "1" { - write!(f, "{}", bi_n)?; + // n < self.0.len() && n < rhs.0.len() + self.coeffs[n] + rhs.coeffs[n] } - } - if i >= 1 { - write!(f, "x")?; - } - if i >= 2 { - write!(f, "^{}", i)?; - } - first = false; - } - Ok(()) + }) + .collect::>(); + Self::Output { coeffs } } } +// impl std::ops::Div for &Polynomial { +// type Output = (Polynomial, Polynomial); +// +// fn div(self, rhs: &Polynomial) -> Self::Output { +// // init the (quotient, remainder) +// let (mut q, mut r) = (Polynomial::zero(), self); +// +// // r is not zero poly, and division.degree > divisor.degree. +// while *r != Polynomial::zero() && r.degree() >= rhs.degree() { +// let r_coeff = r.coeffs(); +// let rhs_coeff = rhs.coeffs(); +// +// let lead_r = r_coeff[r_coeff.len() - 1]; +// let lead_d = rhs_coeff[rhs_coeff.len() - 1]; +// let mut t = Polynomial::zero(); +// t.set( +// r_coeff.len() - rhs_coeff.len(), +// lead_r * lead_d.invert().unwrap(), +// ); +// q += &t; +// r -= &(&rhs * &t); +// } +// (q, r) +// } +// } #[cfg(test)] mod test { use super::*; - use bls12_381::Scalar as Fr; - use bls12_381::*; + use ff::PrimeField; + use std::ops::{Add, Div, Mul}; #[test] - fn test_poly_add() { - let mut p246 = Poly::from(&[1, 2, 3]); - p246 += &Poly::from(&[1, 2, 3]); - assert_eq!(p246, Poly::from(&[2, 4, 6])); - - let mut p24645 = Poly::from(&[1, 2, 3]); - p24645 += &Poly::from(&[1, 2, 3, 4, 5]); - assert_eq!(p24645, Poly::from(&[2, 4, 6, 4, 5])); - - let mut p24646 = Poly::from(&[1, 2, 3, 4, 6]); - p24646 += &Poly::from(&[1, 2, 3]); - assert_eq!(p24646, Poly::from(&[2, 4, 6, 4, 6])); - } - - #[test] - fn test_poly_sub() { - let mut p0 = Poly::from(&[1, 2, 3]); - p0 -= &Poly::from(&[1, 2, 3]); - assert_eq!(p0, Poly::from(&[0])); - - let mut p003 = Poly::from(&[1, 2, 3]); - p003 -= &Poly::from(&[1, 2]); - assert_eq!(p003, Poly::from(&[0, 0, 3])); - } + fn test_mul_poly() { + // p = 1 - x + let p = Polynomial { + coeffs: vec![Scalar::one(), Scalar::one().neg()], + }; + // q = 1 + x + let q = Polynomial { + coeffs: vec![Scalar::one(), Scalar::one()], + }; - #[test] - fn test_poly_mul() { assert_eq!( - &Poly::from(&[5, 0, 10, 6]) * &Poly::from(&[1, 2, 4]), - Poly::from(&[5, 10, 30, 26, 52, 24]) + p.mul(&q).coeffs, + vec![Scalar::one(), Scalar::zero(), Scalar::one().neg()] ); - } - #[test] - fn test_div() { - fn do_test(n: Poly, d: Poly) { - let (q, r) = n.clone() / d.clone(); - let mut n2 = &q * &d; - n2 += &r; - assert_eq!(n, n2); - } + // add + assert_eq!(p.add(&q).coeffs, vec![Scalar::from_u128(2), Scalar::zero()]); - do_test(Poly::::from(&[1]), Poly::::from(&[1, 1])); - do_test(Poly::::from(&[1, 1]), Poly::::from(&[1, 1])); - do_test(Poly::::from(&[1, 2, 1]), Poly::::from(&[1, 1])); - do_test( - Poly::::from(&[1, 2, 1, 2, 5, 8, 1, 9]), - Poly::::from(&[1, 1, 5, 4]), - ); - } - - #[test] - fn test_print() { - assert_eq!("x^2+2x+1", format!("{}", Poly::from(&[1, 2, 1]))); - assert_eq!("x^2+1", format!("{}", Poly::from(&[1, 0, 1]))); - assert_eq!("x^2", format!("{}", Poly::from(&[0, 0, 1]))); - assert_eq!("2x^2", format!("{}", Poly::from(&[0, 0, 2]))); - assert_eq!("-4", format!("{}", Poly::new(vec![-Fr::from(4)]))); + // poly.mul(scalar) assert_eq!( - "-4x", - format!("{}", Poly::new(vec![Fr::zero(), -Fr::from(4)])) - ); - assert_eq!( - "-x-2", - format!("{}", Poly::new(vec![Fr::from(2).neg(), Fr::from(1).neg()])) - ); - assert_eq!( - "x-2", - format!("{}", Poly::new(vec![-Fr::from(2), Fr::from(1)])) + p.mul(&Scalar::from_u128(5)).coeffs, + vec![Scalar::from_u128(5), Scalar::from_u128(5).neg()] ); } #[test] - fn test_lagrange_multi() { - let points = vec![ - (Fr::from(12342), Fr::from(22342)), - (Fr::from(2234), Fr::from(22222)), - (Fr::from(3982394), Fr::from(111114)), - (Fr::from(483838), Fr::from(444444)), + fn lagrange_interpolate() { + // aim: p = 1 + 2x + x^2 + + let domain = vec![ + Scalar::from_u128(1), + Scalar::from_u128(2), + Scalar::from_u128(3), + Scalar::from_u128(4), + Scalar::from_u128(5), + Scalar::from_u128(6), + Scalar::from_u128(7), + Scalar::from_u128(8), + Scalar::from_u128(9), + ]; + let evals = vec![ + Scalar::from_u128(4), + Scalar::from_u128(9), + Scalar::from_u128(10), + Scalar::from_u128(19), + Scalar::from_u128(24), + Scalar::from_u128(31), + Scalar::from_u128(40), + Scalar::from_u128(51), + Scalar::from_u128(64), ]; - let l = Poly::lagrange(&points); - points.iter().for_each(|p| assert_eq!(l.eval(&p.0), p.1)); + + let poly = Polynomial::lagrange_interpolate(domain.clone(), evals.clone()); + + for (x, y) in domain.iter().zip(evals) { + assert_eq!(poly.evaluate(*x), y); + } + println!("pass"); } + + // #[test] + // fn test_div() { + // // division: 2+3x+x^2 = (x+1)(x+2) + // let coeffs = vec![Scalar::from_u128(2), Scalar::ONE, Scalar::ONE]; + // let division = Polynomial::from_coeffs(coeffs); + // + // // dividor: 2+x + // let coeffs = vec![Scalar::from_u128(2), Scalar::ONE]; + // let dividor = Polynomial::from_coeffs(coeffs); + // + // // target: + // // quotient poly: 1+x + // // remainder poly: 0 + // let coeffs = vec![Scalar::from_u128(2), Scalar::ONE]; + // let target_qoutient = Polynomial::from_coeffs(coeffs); + // let target_remainder = Polynomial::zero(); + // + // // division / dividor = quotient + remainder + // let (actual_qoutient, actual_remainder) = division.div(dividor); + // + // assert_eq!(actual_qoutient, target_qoutient); + // assert_eq!(actual_remainder, target_remainder); + // } } diff --git a/15_kzg/src/transcript.rs b/15_kzg/src/transcript.rs new file mode 100644 index 0000000..8aae362 --- /dev/null +++ b/15_kzg/src/transcript.rs @@ -0,0 +1,32 @@ +#![allow(clippy::map_flatten)] +#![allow(clippy::ptr_arg)] +use bls12_381::Scalar; +use ff::{Field, PrimeField}; + +use crate::poly::Polynomial; +pub mod default; + +pub trait Transcript { + fn append(&mut self, new_data: &[u8]); + + fn challenge(&mut self) -> F; +} + +#[cfg(test)] +mod test { + use super::*; + use crate::transcript::default::Keccak256Transcript; + use bls12_381::Scalar; + use ff::Field; + + #[test] + fn test_challenge() { + let mut transcript_1 = Keccak256Transcript::::default(); + let challenge_1 = transcript_1.challenge(); + + let mut transcript_2 = Keccak256Transcript::::default(); + let challenge_2 = transcript_2.challenge(); + + assert_eq!(challenge_2, challenge_1); + } +} diff --git a/15_kzg/src/transcript/default.rs b/15_kzg/src/transcript/default.rs new file mode 100644 index 0000000..08771f2 --- /dev/null +++ b/15_kzg/src/transcript/default.rs @@ -0,0 +1,38 @@ +use crate::transcript::Transcript; +use bls12_381::Scalar; +use ff::{Field, PrimeField}; +use sha3::{Digest, Keccak256}; +use std::marker::PhantomData; +use std::net::UdpSocket; + +pub struct Keccak256Transcript { + hasher: Keccak256, + _marker: PhantomData, +} + +impl Transcript for Keccak256Transcript { + fn append(&mut self, new_data: &[u8]) { + self.hasher.update(&mut new_data.to_owned()); + } + + // auto append and gen challenge + fn challenge(&mut self) -> F { + self.append(&[1]); + + let mut result_hash = [0_u8; 32]; + result_hash.copy_from_slice(&self.hasher.finalize_reset()); + result_hash.reverse(); + self.hasher.update(result_hash); + let sum = result_hash.to_vec().iter().map(|&b| b as u128).sum(); + F::from_u128(sum) + } +} + +impl Default for Keccak256Transcript { + fn default() -> Self { + Self { + hasher: Keccak256::new(), + _marker: Default::default(), + } + } +} diff --git a/4_sumcheck/src/poly/univar_poly.rs b/4_sumcheck/src/poly/univar_poly.rs index fae091e..5f8c446 100644 --- a/4_sumcheck/src/poly/univar_poly.rs +++ b/4_sumcheck/src/poly/univar_poly.rs @@ -12,6 +12,20 @@ pub struct Polynomial { } impl Polynomial { + pub fn from_coeffs(coeffs: Vec) -> Self { + Self { coeffs } + } + + // The degree of the polynomial + pub fn degree(&self) -> usize { + assert!(self.coeffs.len() > 0); + self.coeffs.len() - 1 + } + + pub fn coeffs(&self) -> Vec { + self.coeffs.clone() + } + // p(x)=∑y_j⋅L_j(X), where // y_j: [a_0, a_1, ..., a_n]. // basis: L_j(X)=∏(X−x_k)/(x_j−x_k) diff --git a/7_Merkle_tree_commtment/src/merkle_tree.rs b/7_Merkle_tree_commtment/src/merkle_tree.rs index b05feac..6f02fb1 100644 --- a/7_Merkle_tree_commtment/src/merkle_tree.rs +++ b/7_Merkle_tree_commtment/src/merkle_tree.rs @@ -5,6 +5,7 @@ pub mod proof; use crate::merkle_tree::hasher::{calculate_hash, calculate_parent_hash}; use crate::merkle_tree::node::TreeNode; use crate::merkle_tree::proof::Proof; +use crate::utils::convert_to_binary; use ark_std::log2; use std::cmp::Ordering; @@ -12,8 +13,8 @@ use std::cmp::Ordering; // and where every internal node holds the hash of the concatenation of the hashes of its children nodes. // Note: For convinence, we suppose Merkle tree is a ![complete binary tree](https://www.geeksforgeeks.org/types-of-binary-tree/?ref=lbp) // Degree: 2 -// Leaf nodes: if tree height is h, so the number of leaf nodes will be `2^h` -// Total nodes: A tree of height h has total nodes = 2^(h+1)–1 +// Leaf nodes: if tree height is h, so the number of leaf nodes will be `2^(h-1)` +// Total nodes: A tree of height h has total nodes = 2^h–1 // Height of tree: If tree has N nodes, the hight `h=log(N+1)–1=Θ(ln(n))`. From root to leaf: [1,h]. #[derive(Clone, Debug)] pub struct MerkleTree { @@ -22,6 +23,7 @@ pub struct MerkleTree { } impl MerkleTree { + // init and commit // Constructs a Merkle Tree from a vector of data. // Root = hash_util(left.hash + right.hash) pub fn init(values: Vec) -> Self { @@ -30,9 +32,10 @@ impl MerkleTree { "Can't initial MerkleTree from empty vector" ); let leaves_num = values.len(); - let height: usize = log2(leaves_num) as usize; - assert_eq!(1 << height, leaves_num, "It's not a perfect tree"); + let height: usize = 1 + log2(leaves_num) as usize; + assert_eq!(1 << (height - 1), leaves_num, "It's not a perfect tree"); + // lowest level let leaves_nodes = values .iter() .map(|v| TreeNode::new_leaf(*v)) @@ -40,7 +43,7 @@ impl MerkleTree { // construct tree by leaves. let mut cur = leaves_nodes; - for i in 0..height { + for i in 0..(height - 1) { let cur_len = cur.len(); let parant = (0..(cur_len / 2)) .map(|j| { @@ -59,34 +62,12 @@ impl MerkleTree { } assert_eq!(cur.len(), 1); - // while cur.len() > 1 { - // let mut parent = Vec::new(); - // while !cur.is_empty() { - // let left = cur.remove(0); - // let right = cur.remove(0); - // - // let sum = left.get_hash() + right.get_hash(); - // let parent_hash = calculate_hash(&sum); - // - // let node = TreeNode::Node { - // hash: parent_hash, - // left: Box::new(left), - // right: Box::new(right), - // }; - // - // parent.push(node); - // } - // - // height += 1; - // - // cur = parent; - // } - let root = cur.remove(0); MerkleTree { root, height } } + // commit and open. pub fn commit(&self, x: &char) -> Proof { let mut values = Vec::with_capacity(self.height - 1); let root_hash = self.root.get_hash(); @@ -136,12 +117,12 @@ impl MerkleTree { // Leaf nodes: if tree height is h, so the number of leaf nodes will be `2^h` pub fn leaves_num(&self) -> usize { - 2 ^ self.height + 2 ^ (self.height - 1) } // Total nodes: A tree of height h has total nodes = 2^(h+1)–1 pub fn nodes_num(&self) -> usize { - 2 ^ (self.height + 1) - 1 + 2 ^ self.height - 1 } } diff --git a/7_Merkle_tree_commtment/src/merkle_tree/proof.rs b/7_Merkle_tree_commtment/src/merkle_tree/proof.rs index 702080b..64fe642 100644 --- a/7_Merkle_tree_commtment/src/merkle_tree/proof.rs +++ b/7_Merkle_tree_commtment/src/merkle_tree/proof.rs @@ -5,6 +5,6 @@ // And totally needs h hash values. #[derive(Clone, Eq, PartialEq, Debug)] pub struct Proof { - pub children: Vec, // the children from left to root. - pub root: u64, // root hash + pub children: Vec, // the children from left to root. aka evals + pub root: u64, // root hash. aka cm } diff --git a/7_Merkle_tree_commtment/src/prover.rs b/7_Merkle_tree_commtment/src/prover.rs index 278cf9c..e64320e 100644 --- a/7_Merkle_tree_commtment/src/prover.rs +++ b/7_Merkle_tree_commtment/src/prover.rs @@ -12,7 +12,7 @@ impl Prover { pub fn random_values(k: usize) -> Self { let values = random_chars(k); let merkle_tree = MerkleTree::init(values.clone()); - assert_eq!(merkle_tree.height(), k, "Unexpected Merkle tree height"); + assert_eq!(merkle_tree.height() - 1, k, "Unexpected Merkle tree height"); Self { values, merkle_tree, diff --git a/7_Merkle_tree_commtment/src/utils.rs b/7_Merkle_tree_commtment/src/utils.rs index fb2ecdd..3d0c474 100644 --- a/7_Merkle_tree_commtment/src/utils.rs +++ b/7_Merkle_tree_commtment/src/utils.rs @@ -1,6 +1,17 @@ use rand::distributions::{Alphanumeric, DistString}; use rand_core::OsRng; +// convert a num into its binary form +// eg: 8 -> 1000, will output [1, 0, 0, 0] +pub fn convert_to_binary(bit_len: &usize, num: usize) -> Vec { + let a: usize = 1 << bit_len; + assert!(a >= num); + (0..*bit_len) + .map(|n| (num >> n) & 1) + .rev() + .collect::>() +} + pub fn random_chars(k: usize) -> Vec { let n = 1 << k; let random_code = Alphanumeric.sample_string(&mut OsRng, n); diff --git a/7_low_degree_test/Cargo.toml b/7_low_degree_test/Cargo.toml new file mode 100644 index 0000000..ea04995 --- /dev/null +++ b/7_low_degree_test/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "low_degree_test" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +sumcheck = {path = "../4_sumcheck"} +ff = "0.13.0" +bls12_381 = "0.8.0" +rand = "0.8.5" +rand_core = { version = "0.6.4", default-features = false, features = ["std"] } +rayon = "1.7.0" +sha3 = "0.10.6" +ark-std = "0.4.0" \ No newline at end of file diff --git a/7_low_degree_test/src/ldt.rs b/7_low_degree_test/src/ldt.rs new file mode 100644 index 0000000..a7a6425 --- /dev/null +++ b/7_low_degree_test/src/ldt.rs @@ -0,0 +1,57 @@ +pub mod prover; +pub mod verifier; + +use self::prover::Prover; +use self::verifier::Verifier; +use crate::merkle_tree::proof::MerkleProof; +use crate::poly::*; +use bls12_381::Scalar; +use ff::Field; +use rand_core::OsRng; +use std::env::consts::OS; +use std::iter::Scan; + +#[derive(Default)] +pub struct LDTProof { + pub commits: Vec, // commit of fi + pub evals: Vec<(Scalar, Scalar)>, // The open values on challenge z for fi: (f0(z), f0(−z)), f1(z^2), f1(−z^2) + pub last_const: (Scalar, Scalar), // (p_L, p_R) +} + +// Both P and V have oracle access to function f. +// V wants to test if f is polynomial with deg(f) ≤ d. +pub struct LDT { + prover: Prover, + verifier: Verifier, +} + +impl LDT { + pub fn new(degree: usize) -> Self { + let poly = random_poly(degree); + let z = Scalar::random(OsRng); + let challenge: Scalar = (*poly.coeffs().get(0).unwrap()).clone(); + let prover = Prover::init(poly, z.clone(), challenge.clone()); + + let verifier = Verifier::init(degree, z, challenge); + + Self { prover, verifier } + } + + pub fn run_protocol(&self) { + let proofs = self.prover.prove(); + + self.verifier.verify(proofs); + } +} + +#[cfg(test)] +mod test { + use crate::ldt::LDT; + + #[test] + fn test() { + // degree = 1< Self { + Self { poly, z, merkle_c } + } + + pub fn prove(&self) -> LDTProof { + let mut transcript = Keccak256Transcript::default(); + let mut proof = LDTProof::default(); + + // iter for d rounds. + let mut d = log2(self.poly.degree()); + + // P starts from f(x), and for i = 0 sets f0(x) = f(x). + let p_0 = self.poly.clone(); + let mut p_i = p_0; + + // Use the index of coeffs as the challenge, so challenge in [1,2,4,2^i,d), by the index is [0,..,2^i-1,..,d-1]. + let mut merkle_c_i = self.merkle_c; + let mut z_i = self.z; // z^1 = z^(2^0) + while p_i.degree() > 0 { + Self::split_and_fold( + &mut transcript, + &mut proof, + &mut p_i, + z_i.clone(), + merkle_c_i, + ); + // prepare for next round + z_i = z_i.mul(&z_i); // z^(2^i), Important !!! + merkle_c_i.double(); // double. + } + + proof + } + + pub fn split_and_fold( + transcript: &mut Keccak256Transcript, + proof: &mut LDTProof, + p_i: &mut Polynomial, + z_i: Scalar, + merkle_c_i: Scalar, + ) { + assert!(p_i.degree() != 0, "poly.degree=0, can't split_and_fold"); + // 1. split + let (p_L, p_R) = split_poly(&p_i); + // last iteration + if p_L.degree() == 0 && p_R.degree() == 0 { + proof.last_const = (*p_L.coeffs().get(0).unwrap(), *p_R.coeffs().get(0).unwrap()); + *p_i = p_L.clone(); + return; + } + + // 2. fold + // gen challenge: alpha + let alpha_i = transcript.challenge(); + // compute new poly fi+1, which is the random linear combination of p_L,p_R, + // f_i_1 = f_L + c*f_R + let p_i_plus_1 = p_L.add(&p_L.mul(&alpha_i)); + + // 3. commit phase + // merkle tree commit the poly fi+1 + let merkle_tree = MerkleTree::commit(p_i_plus_1.coeffs().clone()); + // 4. query phase + let cm_i = merkle_tree.open(&merkle_c_i); + + // 5. evaluate + let f_z = p_i_plus_1.evaluate(z_i.clone()); + let f_neg_z = p_i_plus_1.evaluate(z_i.neg()); + + // cache in script + proof.commits.push(cm_i); + proof.evals.push((f_z, f_neg_z)); + + *p_i = p_i_plus_1; + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::poly::random_poly; + use ff::Field; + use rand_core::OsRng; + + #[test] + fn test_prove() { + let degree = 2; + let poly = random_poly(degree); + println!("poly{:?}", poly); + println!("poly.degree: {:?} ", poly.degree()); + + let z = Scalar::random(OsRng); + let challenge: Scalar = (*poly.coeffs().get(0).unwrap()).clone(); + let prover = Prover::init(poly, z, challenge.clone()); + let proof = prover.prove(); + // proof.commit.len:1 + // proof.evals.len:1 + // log_degree:2 + println!("\n\n output-proof"); + println!("proof.commit.len:{:?}", proof.commits.len()); + println!("proof.evals.len:{:?}", proof.evals.len()); + println!("proof.last_const:{:?}", proof.last_const); + println!("log_degree:{:?}", log2(3)); + } +} diff --git a/7_low_degree_test/src/ldt/verifier.rs b/7_low_degree_test/src/ldt/verifier.rs new file mode 100644 index 0000000..23bc653 --- /dev/null +++ b/7_low_degree_test/src/ldt/verifier.rs @@ -0,0 +1,75 @@ +use crate::ldt::LDTProof; +use crate::merkle_tree::proof::MerkleProof; +use crate::merkle_tree::MerkleTree; +use crate::transcript::default::Keccak256Transcript; +use crate::transcript::Transcript; +use ark_std::log2; +use bls12_381::Scalar; +use ff::PrimeField; + +pub struct Verifier { + pub target_deg: usize, // target degree + z: Scalar, // The origin value for evaluate. + merkle_c: Scalar, // the commit challenge. +} + +impl Verifier { + pub fn init(target_deg: usize, z: Scalar, merkle_c: Scalar) -> Self { + Self { + target_deg, + z, + merkle_c, + } + } + + pub fn verify(&self, proof: LDTProof) { + let mut transcript = Keccak256Transcript::default(); + let d = log2(self.target_deg) as usize; + + assert_eq!(proof.commits.len(), d - 1); + assert_eq!(proof.evals.len(), d - 1); + let commits = &proof.commits; + let evals = &proof.evals; + let mut z_i = self.z; // z^1 = z^(2^0) + + let mut merkle_c_i = self.merkle_c; + + let two_inv = Scalar::from_u128(2).invert().unwrap(); + for i in 0..(d - 1) { + println!(""); + println!("round: i: {:?}", i); + println!("evals.len: {:?}", evals.len()); + // 1. obtain fi_L(z^2) ,fi_R (z^2) with: fi(z) = fi_L(z^2) + z fi_R (z^2) + let (f_i_z, f_i_neg_z): (Scalar, Scalar) = *evals.get(i).unwrap(); + let f_i_L = (f_i_z.add(&f_i_neg_z)).mul(&two_inv); + // calc fiR + let double_z_inv = z_i.double().invert().unwrap(); + let f_i_R = (f_i_z.sub(&f_i_neg_z)).mul(&double_z_inv); + + if d == 1 || (d - 2) == i { + // 2. last round check + assert_eq!( + proof.last_const, + (f_i_L, f_i_R), + "Verifier: Last round check failed." + ); + } else { + // 2. check fi+1(z^2) = fi_L(z^2) + αi*fi_R(z^2) + let alpha = transcript.challenge(); + let f_i_plus_1 = f_i_L + alpha * f_i_R; + let (target_f_i_plus_1, _): (Scalar, Scalar) = *evals.get(i + 1).unwrap(); + assert_eq!( + f_i_plus_1, target_f_i_plus_1, + "Verifier: round-{i} check failed." + ); + + // 3. verify the cm todo + MerkleTree::verify(&merkle_c_i, commits.get(i).unwrap()); + + // prepare for next round + z_i = z_i.mul(&z_i); // z^(2^i), Important !!! + merkle_c_i.double(); + } + } + } +} diff --git a/7_low_degree_test/src/lib.rs b/7_low_degree_test/src/lib.rs new file mode 100644 index 0000000..fd8b7d8 --- /dev/null +++ b/7_low_degree_test/src/lib.rs @@ -0,0 +1,8 @@ +//! This is the implement of the FRI-LDT. See more on [Fast reed-solomon interactive oracle proofs of proximity](https://eccc.weizmann.ac.il/report/2017/134) +//! and [A summary on the fri low degree test](https://eprint.iacr.org/2022/1216) + +pub mod ldt; +mod merkle_tree; +mod poly; +mod transcript; +mod utils; diff --git a/7_low_degree_test/src/merkle_tree.rs b/7_low_degree_test/src/merkle_tree.rs new file mode 100644 index 0000000..b3147ce --- /dev/null +++ b/7_low_degree_test/src/merkle_tree.rs @@ -0,0 +1,371 @@ +pub mod hasher; +pub mod node; +pub mod proof; + +use crate::merkle_tree::hasher::{Keccak256Hash, ScalarHash}; +use crate::merkle_tree::node::TreeNode; +use crate::merkle_tree::proof::MerkleProof; +use crate::utils::convert_to_binary; +use ark_std::log2; +use bls12_381::Scalar; +use std::cmp::Ordering; +use std::fs::metadata; + +// A Merkle tree is a binary tree, with values of type `T` at the leafs, +// and where every internal node holds the hash of the concatenation of the hashes of its children nodes. +// Note: For convinence, we suppose Merkle tree is a ![complete binary tree](https://www.geeksforgeeks.org/types-of-binary-tree/?ref=lbp) +// Degree: 2 +// Leaf nodes: if tree height is h, so the number of leaf nodes will be `2^(h-1)` +// Total nodes: A tree of height h has total nodes = 2^h–1 +// Height of tree: If tree has N nodes, the hight `h=log(N+1)–1=Θ(ln(n))`. From root to leaf: [1,h]. +#[derive(Clone, Debug)] +pub struct MerkleTree { + root: TreeNode, // The root of the inner binary tree + height: usize, // The height of the tree +} + +impl MerkleTree { + // init and commit + // Constructs a Merkle Tree from a vector of data. + // Root = hash_util(left.hash + right.hash) + pub fn commit(values: Vec) -> Self { + assert!( + !values.is_empty(), + "Can't initial MerkleTree from empty vector" + ); + let leaves_num = values.len(); + let height: usize = 1 + log2(leaves_num) as usize; + assert_eq!(1 << (height - 1), leaves_num, "It's not a perfect tree"); + + // lowest level + let leaves_nodes = values + .iter() + .map(|v| TreeNode::new_leaf(*v)) + .collect::>(); + + // construct tree by leaves. + let mut cur = leaves_nodes; + for i in 0..(height - 1) { + let cur_len = cur.len(); + let parant = (0..(cur_len / 2)) + .map(|j| { + let left = cur.get(2 * j).unwrap(); + let right = cur.get(2 * j + 1).unwrap(); + let parent_hash = Keccak256Hash::hash(&left.get_hash().add(&right.get_hash())); + + TreeNode::Node { + hash: parent_hash, + left: Box::new(left.clone()), + right: Box::new(right.clone()), + } + }) + .collect::>(); + cur = parant; + } + assert_eq!(cur.len(), 1); + + let root = cur.remove(0); + + MerkleTree { root, height } + } + + // equal the commit, by open it by index of values. + pub fn open_by_index(&self, index: usize) -> MerkleProof { + // index belong [0, leaves_num). + assert!(index >= 0 && index < self.leaves_num(), "Wrong leaf index"); + + let path_len = self.height - 1; + // get leaf-root path, + // Suppose the left child is 0, the right child is 1, so the path can be indexed as binary form with (height-1) bits. + // eg: tree height is 3, which has total 2^2 leaves, the leave can ben indexed as (00, 01, 10, 11). + // a. turn the index into binary form with (height-1) bits. + let mut path = convert_to_binary(&path_len, index); + path.reverse(); // make path from root to left. + + // b. according the path, we can found out the MerkleProof of the indexed leaf, which just need to collect the bro-node. + // We'll collect the bro-node by the path. Collect the left child is 1, the right child is 0. + + let mut values = Vec::with_capacity(self.height); + let root_hash = self.root.get_hash(); + + let mut cur_node = &self.root; + + println!("{:?}", path); + // for now the hash values are collected from root to leaf. + for p in path { + // let p = path.get(path_len - i).unwrap(); + + match cur_node { + TreeNode::Leaf { hash, value } => panic!("Never reach leaf"), + TreeNode::Node { hash, left, right } => { + // collect the right as bro-node. + if p == 0 { + values.push(right.get_hash()); + cur_node = left.as_ref(); + } else { + println!("value:{:?}", left.get_hash()); + values.push(left.get_hash()); + cur_node = right.as_ref(); + } + } + } + } + + // reverse the hash values to make sure it's from leaf to root + values.reverse(); + + MerkleProof { + root: root_hash, + children: values, + } + } + + // open. + // The challenge maybe not in values, so return empty children. + pub fn open(&self, challenge: &Scalar) -> MerkleProof { + let mut values = Vec::with_capacity(self.height - 1); + let root_hash = self.root.get_hash(); + Self::dfs(&self.root, &challenge, &mut values); + if values.is_empty() { + // log! todo + } + MerkleProof { + root: root_hash, + children: values, + } + } + + fn dfs(root: &TreeNode, target: &Scalar, res: &mut Vec) -> bool { + match root { + TreeNode::Leaf { hash, value } => { + if value == target { + true + } else { + false + } + } + TreeNode::Node { hash, left, right } => { + let l = Self::dfs(left, target, res); + // if left meet target. + if l { + res.push(right.get_hash()); + return true; + } + + // if right meet target. + let r = Self::dfs(right, target, res); + if r { + res.push(left.get_hash()); + } + r + } + } + } + + // Returns the root hash of Merkle tree + pub fn root_hash(&self) -> Scalar { + self.root.get_hash() + } + + // Returns the height of Merkle tree + pub fn height(&self) -> usize { + self.height + } + + // Leaf nodes: if tree height is h, so the number of leaf nodes will be `2^h` + pub fn leaves_num(&self) -> usize { + 1 << (self.height - 1) + } + + // Total nodes: A tree of height h has total nodes = 2^(h+1)–1 + pub fn nodes_num(&self) -> usize { + 2 ^ self.height - 1 + } + + pub fn verify(challenge: &Scalar, proof: &MerkleProof) { + let target = proof.root; + let actual = if proof.children.is_empty() { + // The challenge maybe not in values, just verify the root. + proof.root + } else { + let leaf_hash = Keccak256Hash::hash(&challenge); + proof + .children + .iter() + .fold(leaf_hash, |acc, eval| Keccak256Hash::hash(&acc.add(&eval))) + }; + assert_eq!(target, actual, "Verifier: verify failed!") + } + + // equal the commit, by open it by index of values. + pub fn verify_by_index(&self, index: usize, proof: &MerkleProof) { + // index belong [0, leaves_num). + assert!(index >= 0 && index < self.leaves_num(), "Wrong leaf index"); + + let path_len = self.height - 1; + // 1. get leaf-root path, + // Suppose the left child is 0, the right child is 1, so the path can be indexed as binary form with (height-1) bits. + // eg: tree height is 3, which has total 2^2 leaves, the leave can ben indexed as (00, 01, 10, 11). + // a. turn the index into binary form with (height-1) bits. + let mut path = convert_to_binary(&path_len, index); + // make path from root to left. + path.reverse(); + // to make sure iter can reach the leaf layer. + path.push(2); + + // b. found out the target left. + let mut cur_node = &self.root; + let mut challenge = Scalar::zero(); + for p in path { + match cur_node { + TreeNode::Leaf { hash, value } => challenge = value.clone(), + TreeNode::Node { hash, left, right } => { + assert!(p != 2); + // collect the right as bro-node. + if p == 0 { + cur_node = left.as_ref(); + } else { + cur_node = right.as_ref(); + } + } + } + } + println!("target: {:?}", challenge); + Self::verify(&challenge, proof); + } +} + +#[cfg(test)] +mod test { + use crate::merkle_tree::proof::MerkleProof; + use crate::merkle_tree::MerkleTree; + use crate::poly::random_poly; + use crate::utils::{random_chars, random_scalars}; + use bls12_381::Scalar; + use ff::PrimeField; + use rand_core::{OsRng, RngCore}; + use std::fmt::Debug; + + #[test] + fn test_init_merkle_tree() { + let poly = random_poly(3); + println!("chars:{:?}", poly); + let merkle_tree = MerkleTree::commit(poly.coeffs()); + println!("merkle tree: {:?}", merkle_tree); + } + + #[test] + fn test_commit_and_verify() { + let coeffs = vec![ + Scalar::one(), + Scalar::from_u128(12), + Scalar::one(), + Scalar::from_u128(13), + ]; + let merkle_tree = MerkleTree::commit(coeffs); + // println!("merkle tree: {:?}", merkle_tree); + + // MerkleTree { + // root: Node { + // hash: 0x4053ef94c1db0c3a6159b84891f03ee40b5aaca60091f6e438b7b653cf1b6f20, + // left: Node { + // hash: 0x5d3b8160daf88b74a74b4a5b91ce4eaea2f64628d6c8f4717330d7734eb0f2f0, + // left: Leaf { + // hash: 0x38a2f65eb883578ccc8a27acd26c6646d22fbbaa09e533726b84bd7d9ff94c87, + // value: 0x0000000000000000000000000000000000000000000000000000000000000001 + // }, + // right: Leaf { + // hash: 0x33feef36be1c5c0384ecaba81a839c2126444a9dec203df90fa6b8ec2fdeaa87, + // value: 0x000000000000000000000000000000000000000000000000000000000000000c + // } + // }, + // right: Node { + // hash: 0x56108a065ccd17f0706ef2fa4aa8b80620d7490c9cab818b25b48b39c58594fa, + // left: Leaf { + // hash: 0x38a2f65eb883578ccc8a27acd26c6646d22fbbaa09e533726b84bd7d9ff94c87, + // value: 0x0000000000000000000000000000000000000000000000000000000000000001 + // }, + // right: Leaf { + // hash: 0x0e2c9965653910c8765b9b7f6eb348643c6da2e58d76a165cd14dfe960e1d418, + // value: 0x000000000000000000000000000000000000000000000000000000000000000d + // } + // } + // }, + // height: 2 + // } + let challenge = Scalar::one(); + // MerkleProof { + // children: [ + // 0x33feef36be1c5c0384ecaba81a839c2126444a9dec203df90fa6b8ec2fdeaa87, + // 0x56108a065ccd17f0706ef2fa4aa8b80620d7490c9cab818b25b48b39c58594fa + // ], + // root: 0x4053ef94c1db0c3a6159b84891f03ee40b5aaca60091f6e438b7b653cf1b6f20 + // } + let proof = merkle_tree.open(&challenge); + println!("{:?}", proof); + // correct + + MerkleTree::verify(&challenge, &proof); + } + + #[test] + fn test_commit_and_verify_by_index() { + let coeffs = vec![ + Scalar::one(), + Scalar::from_u128(12), + Scalar::zero(), + Scalar::from_u128(13), + ]; + let merkle_tree = MerkleTree::commit(coeffs); + println!("merkle tree: {:?}", merkle_tree); + + // merkle tree: MerkleTree { + // root: Node { + // hash: 0x5e7042679f5529053b7fff8b5578fd6d332070daf6595fba220475075e3f6ca6, + // left: Node { + // hash: 0x1c40f88c081d9412450a3b26f15202a9cb65b87f28a1a7247c3825d22de9fd16, + // left: Leaf { + // hash: 0x38a2f65eb883578ccc8a27acd26c6646d22fbbaa09e533726b84bd7d9ff94c87, + // value: 0x0000000000000000000000000000000000000000000000000000000000000001 + // }, + // right: Leaf { + // hash: 0x33feef36be1c5c0384ecaba81a839c2126444a9dec203df90fa6b8ec2fdeaa87, + // value: 0x000000000000000000000000000000000000000000000000000000000000000c + // } + // }, + // right: Node { + // hash: 0x144cd89c19f04739ec8b8f39cfeb65b16c7a5fbcf1e3043baf2a8a0ef9a2de0e, + // left: Leaf { + // hash: 0x64fdc03ab367550778acd60eef4b079fde19310f5db8ddd6700b2bf9ce4a9211, + // value: 0x0000000000000000000000000000000000000000000000000000000000000000 + // }, + // right: Leaf { + // hash: 0x0e2c9965653910c8765b9b7f6eb348643c6da2e58d76a165cd14dfe960e1d418, + // value: 0x000000000000000000000000000000000000000000000000000000000000000d + // } + // } + // }, + // height: 3 + // } + // let challenge = (OsRng.next_u32() % 4) as usize; + let challenge = 3; + println!("{challenge}"); + // [1, 1] + // value: 0x1c40f88c081d9412450a3b26f15202a9cb65b87f28a1a7247c3825d22de9fd16 + // value: 0x64fdc03ab367550778acd60eef4b079fde19310f5db8ddd6700b2bf9ce4a9211 + // MerkleProof { + // children: [ + // 0x64fdc03ab367550778acd60eef4b079fde19310f5db8ddd6700b2bf9ce4a9211, + // 0x1c40f88c081d9412450a3b26f15202a9cb65b87f28a1a7247c3825d22de9fd16 + // ], + // root: 0x5e7042679f5529053b7fff8b5578fd6d332070daf6595fba220475075e3f6ca6 + // } + let proof = merkle_tree.open_by_index(challenge); + println!("{:?}", proof); + // correct + merkle_tree.verify_by_index(challenge, &proof); + + // MerkleTree::verify(&challenge, &proof); + } +} +// diff --git a/7_low_degree_test/src/merkle_tree/hasher.rs b/7_low_degree_test/src/merkle_tree/hasher.rs new file mode 100644 index 0000000..8e80a80 --- /dev/null +++ b/7_low_degree_test/src/merkle_tree/hasher.rs @@ -0,0 +1,103 @@ +use bls12_381::Scalar; +use ff::PrimeField; +use sha3::{Digest, Keccak256}; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; +use std::io::Read; +use std::marker::PhantomData; +use std::ops::Div; + +// abstraction to set the hash function used +pub trait ScalarHash: Clone { + fn hash(inputs: &F) -> F; + fn hashes(inputs: &[F]) -> F; +} + +#[derive(Clone, Copy, Debug)] +pub struct Keccak256Hash { + _marker: PhantomData, +} + +impl ScalarHash for Keccak256Hash { + // same as calculate_hash, this is for Scalar + fn hash(input: &Scalar) -> Scalar { + // hash + let mut h = Keccak256::new(); + h.update(input.to_repr().as_ref()); + + // let r = h.finalize().as_slice(); + let slice: [u8; 32] = h.finalize().as_slice().try_into().unwrap(); + // get_scalar + let bytes = [slice.clone(), slice] + .concat() + .as_slice() + .try_into() + .unwrap(); + Scalar::from_bytes_wide(&bytes) + } + + // same as calculate_parent_hash, this is for Scalar + fn hashes(inputs: &[Scalar]) -> Scalar { + // hash + let mut h = Keccak256::new(); + for x in inputs { + h.update(x.to_repr().as_ref()); + } + + // let r = h.finalize().as_slice(); + let slice: [u8; 32] = h.finalize().as_slice().try_into().unwrap(); + // get_scalar + let bytes = [slice.clone(), slice] + .concat() + .as_slice() + .try_into() + .unwrap(); + Scalar::from_bytes_wide(&bytes) + } +} + +/// calculate the hash of the data +pub fn calculate_hash(t: &T) -> u64 { + let mut s = DefaultHasher::new(); + t.hash(&mut s); + s.finish() +} +pub fn calculate_parent_hash(left: u64, right: u64) -> u64 { + let mut sum: u128 = (left.div(2) + right / 2) as u128; + let mut s = DefaultHasher::new(); + sum.hash(&mut s); + s.finish() +} + +#[cfg(test)] +mod test { + use crate::merkle_tree::hasher::{calculate_parent_hash, Keccak256Hash, ScalarHash}; + use bls12_381::Scalar; + use ff::{Field, PrimeField}; + use rand_core::{OsRng, RngCore}; + + #[test] + fn test_calculate_parent_hash() { + let rng = &mut OsRng; + let left = rng.next_u64(); + let right = rng.next_u64(); + + let parent = calculate_parent_hash(left, right); + println!("{:?}", parent); + } + + #[test] + fn test_calculate_parent_hash_with_scalar() { + let left = Scalar::random(&mut OsRng); + let right = Scalar::random(&mut OsRng); + + let parent = Keccak256Hash::hash(&left.add(&right)); + println!("{:?}", parent); + + let left = Scalar::from_u128(10); + let right = Scalar::from_u128(12); + + let parent = Keccak256Hash::hashes(&[left, right]); + println!("{:?}", parent); + } +} diff --git a/7_low_degree_test/src/merkle_tree/node.rs b/7_low_degree_test/src/merkle_tree/node.rs new file mode 100644 index 0000000..70ea48f --- /dev/null +++ b/7_low_degree_test/src/merkle_tree/node.rs @@ -0,0 +1,72 @@ +use crate::merkle_tree::hasher::{Keccak256Hash, ScalarHash}; +use crate::merkle_tree::MerkleTree; +use bls12_381::Scalar; +use std::hash::Hash; + +/// Node of a Binary Tree. +#[derive(Clone, Debug, Eq)] +pub enum TreeNode { + Leaf { + hash: Scalar, // Hash of the node + value: Scalar, // Value of the leaf node + }, + Node { + hash: Scalar, // Hash of the node + left: Box, // Left child of the node + right: Box, // Right chiild of the node + }, +} + +impl TreeNode { + /// Create a new Node + pub fn new(hash: Scalar, value: Scalar) -> Self { + Self::Leaf { hash, value } + } + + // Create a new leaf + pub fn new_leaf(value: Scalar) -> TreeNode { + let hash = Keccak256Hash::hash(&value); + Self::new(hash, value) + } + + // Returns a hash from the Node. + pub fn get_hash(&self) -> Scalar { + match self { + &Self::Leaf { hash, .. } => hash, + &Self::Node { hash, .. } => hash, + } + } +} + +impl PartialEq for TreeNode { + fn eq(&self, other: &Self) -> bool { + match self { + TreeNode::Node { hash, left, right } => { + let (hash1, left1, right1) = (hash, left, right); + match other { + TreeNode::Node { hash, left, right } => { + if hash1 != hash || left1 != left || right1 != right || right1 != right { + false + } else { + true + } + } + _ => false, + } + } + TreeNode::Leaf { hash, value } => { + let (hash1, value1) = (hash, value); + match other { + TreeNode::Leaf { hash, value } => { + if hash1 != hash || value1 != value { + false + } else { + true + } + } + _ => false, + } + } + } + } +} diff --git a/7_low_degree_test/src/merkle_tree/proof.rs b/7_low_degree_test/src/merkle_tree/proof.rs new file mode 100644 index 0000000..0dcb47d --- /dev/null +++ b/7_low_degree_test/src/merkle_tree/proof.rs @@ -0,0 +1,12 @@ +use bls12_381::Scalar; + +// Proof is a tree, only contain the hash values from target leaf to root with related brather-nodes. +// Meanwhile, half of the tree can be calculated by the known leaf value.. +// So according the Figure 7.1(from zkbook), it's quite easy to find that just need to return the hasher from +// brather-nodes(each layer has only one!), the left infos will be calculated by verifier. +// And totally needs h hash values. +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct MerkleProof { + pub children: Vec, // the children from left to root. aka evals + pub root: Scalar, // root hash. aka cm +} diff --git a/7_low_degree_test/src/poly.rs b/7_low_degree_test/src/poly.rs new file mode 100644 index 0000000..dfe27e1 --- /dev/null +++ b/7_low_degree_test/src/poly.rs @@ -0,0 +1,64 @@ +use bls12_381::Scalar; +use ff::Field; +use rand_core::OsRng; +pub use sumcheck::poly::univar_poly::*; + +// fi(x) = fi^L (x2) + x fi^R (x2) +pub fn split_poly(p: &Polynomial) -> (Polynomial, Polynomial) { + assert!(p.degree() != 0, "poly.degree=0, can't split_and_fold"); + // let d = p.degree() + 1; + let coeffs = p.coeffs(); + let odd: Vec = coeffs.iter().step_by(2).cloned().collect(); + let even: Vec = coeffs.iter().skip(1).step_by(2).cloned().collect(); + // return the fi_L and fi_R + (Polynomial::from_coeffs(odd), Polynomial::from_coeffs(even)) +} + +// random a poly with a degree +pub fn random_poly(degree: usize) -> Polynomial { + assert!(degree >= 0); + let coeffs = (0..=degree) + .into_iter() + .map(|_| Scalar::random(OsRng)) + .collect::>(); + let poly = Polynomial::from_coeffs(coeffs); + assert_eq!(poly.degree(), degree); + poly +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_split() { + let deg = 5; + let poly = random_poly(deg); + + let (pL, pR) = split_poly(&poly); + + // check that f(z) == fL(x^2) + x * fR(x^2), for a rand z + let z = Scalar::random(OsRng); + assert_eq!( + poly.evaluate(z.clone()), + pL.evaluate(z.square()) + z * pR.evaluate(z.square()) + ); + } + + #[test] + fn test_split_more() { + // let deg = 5; + for deg in 1..5 { + let poly = random_poly(deg); + + let (pL, pR) = split_poly(&poly); + + // check that f(z) == fL(x^2) + x * fR(x^2), for a rand z + let z = Scalar::random(OsRng); + assert_eq!( + poly.evaluate(z.clone()), + pL.evaluate(z.square()) + z * pR.evaluate(z.square()) + ); + } + } +} diff --git a/7_low_degree_test/src/transcript.rs b/7_low_degree_test/src/transcript.rs new file mode 100644 index 0000000..6a15d1f --- /dev/null +++ b/7_low_degree_test/src/transcript.rs @@ -0,0 +1,56 @@ +#![allow(clippy::map_flatten)] +#![allow(clippy::ptr_arg)] +use bls12_381::Scalar; + +use crate::poly::Polynomial; +pub mod default; + +pub trait Transcript { + fn append(&mut self, new_data: &[u8]); + + fn challenge(&mut self) -> Scalar; +} + +pub(crate) fn poly_to_bytes(poly: &Polynomial) -> Vec { + coeffs_to_bytes(&poly.coeffs()) +} + +fn coeffs_to_bytes(coeffs: &Vec) -> Vec { + coeffs + .iter() + .map(|c| c.to_bytes()) + .flatten() + .collect::>() +} + +#[cfg(test)] +mod test { + use super::*; + use crate::transcript::coeffs_to_bytes; + use crate::transcript::default::Keccak256Transcript; + use bls12_381::Scalar; + use ff::Field; + use rand_core::OsRng; + + #[test] + fn test_coeff_to_transcript() { + let mut rng = OsRng; + + let coeffs = (0..4).map(|_| Scalar::random(rng)).collect::>(); + + // from scalar vector + let mut transcript_1 = Keccak256Transcript::default(); + for x in coeffs.clone() { + transcript_1.append(&x.to_bytes()); + } + let challenge_1 = transcript_1.challenge(); + + // from coeffs, as mock of poly + let mut transcript_2 = Keccak256Transcript::default(); + let bytes = coeffs_to_bytes(&coeffs); + transcript_2.append(&bytes); + let challenge_2 = transcript_2.challenge(); + + assert_eq!(challenge_2, challenge_1); + } +} diff --git a/7_low_degree_test/src/transcript/default.rs b/7_low_degree_test/src/transcript/default.rs new file mode 100644 index 0000000..32c5cf1 --- /dev/null +++ b/7_low_degree_test/src/transcript/default.rs @@ -0,0 +1,35 @@ +use crate::transcript::Transcript; +use bls12_381::Scalar; +use ff::PrimeField; +use sha3::{Digest, Keccak256}; +use std::net::UdpSocket; + +pub struct Keccak256Transcript { + hasher: Keccak256, +} + +impl Transcript for Keccak256Transcript { + fn append(&mut self, new_data: &[u8]) { + self.hasher.update(&mut new_data.to_owned()); + } + + // auto append and gen challenge + fn challenge(&mut self) -> Scalar { + self.append(&[1]); + + let mut result_hash = [0_u8; 32]; + result_hash.copy_from_slice(&self.hasher.finalize_reset()); + result_hash.reverse(); + self.hasher.update(result_hash); + let sum = result_hash.to_vec().iter().map(|&b| b as u128).sum(); + Scalar::from_u128(sum) + } +} + +impl Default for Keccak256Transcript { + fn default() -> Self { + Self { + hasher: Keccak256::new(), + } + } +} diff --git a/7_low_degree_test/src/utils.rs b/7_low_degree_test/src/utils.rs new file mode 100644 index 0000000..33f0708 --- /dev/null +++ b/7_low_degree_test/src/utils.rs @@ -0,0 +1,45 @@ +use bls12_381::Scalar; +use ff::Field; +use rand::distributions::{Alphanumeric, DistString}; +use rand_core::OsRng; + +// convert a num into its binary form +// eg: 8 -> 1000, will output [1, 0, 0, 0] +pub fn convert_to_binary(bit_len: &usize, num: usize) -> Vec { + let a: usize = 1 << bit_len; + assert!(a >= num); + (0..*bit_len) + .map(|n| (num >> n) & 1) + .rev() + .collect::>() +} + +pub fn random_chars(k: usize) -> Vec { + let n = 1 << k; + let random_code = Alphanumeric.sample_string(&mut OsRng, n); + random_code.chars().collect::>() +} +pub fn random_scalars(k: usize) -> Vec { + let n = 1 << k; + (0..n).map(|_| Scalar::random(OsRng)).collect::>() +} + +#[cfg(test)] +mod test { + use crate::utils::random_chars; + use rand::distributions::{Alphanumeric, DistString}; + use rand_core::OsRng; + + #[test] + fn test_random_char() { + let k = 5; + let n = 1 << k; + let random_code = Alphanumeric.sample_string(&mut OsRng, n); + println!("{:?}", random_code); + let a = random_code.chars().collect::>(); + println!("{:?}", a); + + let b = random_chars(k); + println!("{:?}", b); + } +} diff --git a/Cargo.toml b/Cargo.toml index b514159..968f6f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "4_GKR", "5_Fiat_Shamir", "5_ni_sumcheck", + "7_low_degree_test", "7_Merkle_tree_commtment", "15_kzg", ]