Skip to content

Commit

Permalink
Implement GKR protocol. (#9)
Browse files Browse the repository at this point in the history
Dones:
* Impl GKR-sumcheck
* Impl GKR protocol

TODO:
- [ ]  GKR test failed.
  • Loading branch information
SuccinctPaul authored Aug 13, 2023
1 parent 8b7726a commit 8fbbc96
Show file tree
Hide file tree
Showing 17 changed files with 1,296 additions and 140 deletions.
168 changes: 84 additions & 84 deletions 2_univariate_lagrange_interpolation/src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,90 +116,90 @@ impl Polynomial {
}

// todo parallel
pub fn lagrange_interpolate_parallel(domains: Vec<Scalar>, evals: Vec<Scalar>) -> Self {
assert_eq!(domains.len(), evals.len());

if evals.len() == 1 {
// Constant polynomial
Self {
coeffs: vec![evals[0]],
}
} else {
let poly_size = domains.len();
let lag_basis_poly_size = poly_size - 1;

// 1. divisors = vec(x_j - x_k). prepare for L_j(X)=∏(X−x_k)/(x_j−x_k)
let mut divisors = Vec::with_capacity(poly_size);
for (j, x_j) in domains.iter().enumerate() {
// divisor_j
let mut divisor = Vec::with_capacity(lag_basis_poly_size);
// obtain domain for x_k
for x_k in domains
.iter()
.enumerate()
.filter(|&(k, _)| k != j)
.map(|(_, x)| x)
{
divisor.push(*x_j - x_k);
}
divisors.push(divisor);
}
// Inverse (x_j - x_k)^(-1) for each j != k to compute L_j(X)=∏(X−x_k)/(x_j−x_k)
divisors
.iter_mut()
.flat_map(|v| v.iter_mut())
.batch_invert();

// 2. Calculate L_j(X) : L_j(X)=∏(X−x_k) divisors_j
let mut L_j_vec: Vec<Vec<Scalar>> = Vec::with_capacity(poly_size);

for (j, divisor_j) in divisors.into_iter().enumerate() {
let mut L_j: Vec<Scalar> = Vec::with_capacity(poly_size);
L_j.push(Scalar::one());

// (X−x_k) divisors_j
let mut product = Vec::with_capacity(lag_basis_poly_size);

// obtain domain for x_k
for (x_k, divisor) in domains
.iter()
.enumerate()
.filter(|&(k, _)| k != j)
.map(|(_, x)| x)
.zip(divisor_j.into_iter())
{
product.resize(L_j.len() + 1, Scalar::zero());

// loop (poly_size + 1) round
// calculate L_j(X)=∏(X−x_k) with coefficient form.
for ((a, b), product) in L_j
.iter()
.chain(std::iter::once(&Scalar::zero()))
.zip(std::iter::once(&Scalar::zero()).chain(L_j.iter()))
.zip(product.iter_mut())
{
*product = *a * (-divisor * x_k) + *b * divisor;
}
std::mem::swap(&mut L_j, &mut product);
}

assert_eq!(L_j.len(), poly_size);
assert_eq!(product.len(), poly_size - 1);

L_j_vec.push(L_j);
}

// p(x)=∑y_j⋅L_j(X) in coefficients
let mut final_poly = vec![Scalar::zero(); poly_size];
// 3. p(x)=∑y_j⋅L_j(X)
for (L_j, y_j) in L_j_vec.iter().zip(evals) {
for (final_coeff, L_j_coeff) in final_poly.iter_mut().zip(L_j.into_iter()) {
*final_coeff += L_j_coeff * y_j;
}
}
Self { coeffs: final_poly }
}
}
// pub fn lagrange_interpolate_parallel(domains: Vec<Scalar>, evals: Vec<Scalar>) -> Self {
// assert_eq!(domains.len(), evals.len());
//
// if evals.len() == 1 {
// // Constant polynomial
// Self {
// coeffs: vec![evals[0]],
// }
// } else {
// let poly_size = domains.len();
// let lag_basis_poly_size = poly_size - 1;
//
// // 1. divisors = vec(x_j - x_k). prepare for L_j(X)=∏(X−x_k)/(x_j−x_k)
// let mut divisors = Vec::with_capacity(poly_size);
// for (j, x_j) in domains.iter().enumerate() {
// // divisor_j
// let mut divisor = Vec::with_capacity(lag_basis_poly_size);
// // obtain domain for x_k
// for x_k in domains
// .iter()
// .enumerate()
// .filter(|&(k, _)| k != j)
// .map(|(_, x)| x)
// {
// divisor.push(*x_j - x_k);
// }
// divisors.push(divisor);
// }
// // Inverse (x_j - x_k)^(-1) for each j != k to compute L_j(X)=∏(X−x_k)/(x_j−x_k)
// divisors
// .iter_mut()
// .flat_map(|v| v.iter_mut())
// .batch_invert();
//
// // 2. Calculate L_j(X) : L_j(X)=∏(X−x_k) divisors_j
// let mut L_j_vec: Vec<Vec<Scalar>> = Vec::with_capacity(poly_size);
//
// for (j, divisor_j) in divisors.into_iter().enumerate() {
// let mut L_j: Vec<Scalar> = Vec::with_capacity(poly_size);
// L_j.push(Scalar::one());
//
// // (X−x_k) divisors_j
// let mut product = Vec::with_capacity(lag_basis_poly_size);
//
// // obtain domain for x_k
// for (x_k, divisor) in domains
// .iter()
// .enumerate()
// .filter(|&(k, _)| k != j)
// .map(|(_, x)| x)
// .zip(divisor_j.into_iter())
// {
// product.resize(L_j.len() + 1, Scalar::zero());
//
// // loop (poly_size + 1) round
// // calculate L_j(X)=∏(X−x_k) with coefficient form.
// for ((a, b), product) in L_j
// .iter()
// .chain(std::iter::once(&Scalar::zero()))
// .zip(std::iter::once(&Scalar::zero()).chain(L_j.iter()))
// .zip(product.iter_mut())
// {
// *product = *a * (-divisor * x_k) + *b * divisor;
// }
// std::mem::swap(&mut L_j, &mut product);
// }
//
// assert_eq!(L_j.len(), poly_size);
// assert_eq!(product.len(), poly_size - 1);
//
// L_j_vec.push(L_j);
// }
//
// // p(x)=∑y_j⋅L_j(X) in coefficients
// let mut final_poly = vec![Scalar::zero(); poly_size];
// // 3. p(x)=∑y_j⋅L_j(X)
// for (L_j, y_j) in L_j_vec.iter().zip(evals) {
// for (final_coeff, L_j_coeff) in final_poly.iter_mut().zip(L_j.into_iter()) {
// *final_coeff += L_j_coeff * y_j;
// }
// }
// Self { coeffs: final_poly }
// }
// }

// This evaluates a polynomial (in coefficient form) at `x`.
pub fn evaluate(&self, x: Scalar) -> Scalar {
Expand Down
29 changes: 29 additions & 0 deletions 4_GKR/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// For each layer:
// Let Si denote the number of gates at layer i of the circuit C. Number the gates at layer i from 0 to Si − 1.
// Assume Si is a power of 2 and let $Si = 2^{k_i} $.
//
// Witness
// Wi : {0, 1}ki → F denote the function t􏰌hat takes as input a binary gate label,
// and outputs the corresponding gate’s value at layer i
// \widetilde{W_i}: multilinear extension(MLE) of Wi
// NOTE: Wi depend on input x to C.
//
//
// Constraints
// wiring predicate: that encodes which pairs of wires from layeri+1 are connected to a given gate at layeri in C.
// Let in_{1,i},in_{2,i}:{0,1}ki →{0,1}ki+1 denote the functions that take as input the label a of a gate at layer i of C,
// and respectively output the label of the first and second in-neighbor of gate a.
// eg: if gate a at layer i computes the sum of gates b and c at layer i + 1, then in_{1,i}(a) = b and in_{2,i}(a) = c.
//
// Define two functions, addi and multi , mapping {0, 1}^{ki +2ki+1} to {0, 1}, which together constitute the wiring predicate of layer i of C.
// These functions take as input three gate labels (a,b,c), and return 1 if and only if (b,c) = (in1,i(a),in2,i(a))
// and gate a is an addition (respectively, multiplication) gate.
//
// Let \widetilde{add_i} and \widetilde{mult_i} denote the multilinear extensions of addi and multi.
//
// NOTE: wiring predicate(addi, multi) depend only on the circuit C and not on the input x to C

pub mod layered_circuit;

// pub mod r1cs
// pub mod plonk
Loading

0 comments on commit 8fbbc96

Please sign in to comment.