From 63a2833319f5fbea634c86c61891585abdea73d7 Mon Sep 17 00:00:00 2001 From: han0110 Date: Mon, 25 Sep 2023 10:38:04 +0000 Subject: [PATCH 1/3] feat: implement GKR for fractional sumchecks described in https://eprint.iacr.org/2023/1284.pdf --- plonkish_backend/src/piop.rs | 1 + plonkish_backend/src/piop/gkr.rs | 3 + .../src/piop/gkr/fractional_sum.rs | 279 ++++++++++++++++++ .../src/piop/sum_check/classic/coeff.rs | 139 ++++----- 4 files changed, 356 insertions(+), 66 deletions(-) create mode 100644 plonkish_backend/src/piop/gkr.rs create mode 100644 plonkish_backend/src/piop/gkr/fractional_sum.rs diff --git a/plonkish_backend/src/piop.rs b/plonkish_backend/src/piop.rs index 6ce0f90e..5739e324 100644 --- a/plonkish_backend/src/piop.rs +++ b/plonkish_backend/src/piop.rs @@ -1 +1,2 @@ +pub mod gkr; pub mod sum_check; diff --git a/plonkish_backend/src/piop/gkr.rs b/plonkish_backend/src/piop/gkr.rs new file mode 100644 index 00000000..3b5f011f --- /dev/null +++ b/plonkish_backend/src/piop/gkr.rs @@ -0,0 +1,3 @@ +mod fractional_sum; + +pub use fractional_sum::{prove_fractional_sum, verify_fractional_sum}; diff --git a/plonkish_backend/src/piop/gkr/fractional_sum.rs b/plonkish_backend/src/piop/gkr/fractional_sum.rs new file mode 100644 index 00000000..6f4f9f47 --- /dev/null +++ b/plonkish_backend/src/piop/gkr/fractional_sum.rs @@ -0,0 +1,279 @@ +//! Implementation of GKR for fractional sumchecks in [PH23]. +//! Notations are same as in section 3. +//! +//! [PH23]: https://eprint.iacr.org/2023/1284.pdf + +use crate::{ + piop::sum_check::{ + classic::{ClassicSumCheck, EvaluationsProver}, + eq_xy_eval, SumCheck, VirtualPolynomial, + }, + poly::{multilinear::MultilinearPolynomial, Polynomial}, + util::{ + arithmetic::{div_ceil, PrimeField}, + chain, + expression::{Expression, Query, Rotation}, + izip, + parallel::{num_threads, parallelize_iter}, + transcript::{FieldTranscriptRead, FieldTranscriptWrite}, + Itertools, + }, + Error, +}; +use std::{array, iter}; + +struct Layer { + p_l: MultilinearPolynomial, + p_r: MultilinearPolynomial, + q_l: MultilinearPolynomial, + q_r: MultilinearPolynomial, +} + +impl From<[Vec; 4]> for Layer { + fn from(values: [Vec; 4]) -> Self { + let [p_l, p_r, q_l, q_r] = values.map(MultilinearPolynomial::new); + Self { p_l, p_r, q_l, q_r } + } +} + +impl Layer { + fn initial(p: &[F], q: &[F]) -> Self { + let mid = p.len() >> 1; + [&p[..mid], &p[mid..], &q[..mid], &q[mid..]] + .map(ToOwned::to_owned) + .into() + } + + fn num_vars(&self) -> usize { + self.p_l.num_vars() + } + + fn polys(&self) -> [&MultilinearPolynomial; 4] { + [&self.p_l, &self.p_r, &self.q_l, &self.q_r] + } + + fn poly_chunks(&self, chunk_size: usize) -> impl Iterator { + let [p_l, p_r, q_l, q_r] = self.polys().map(|poly| poly.evals().chunks(chunk_size)); + izip!(p_l, p_r, q_l, q_r) + } +} + +pub fn prove_fractional_sum( + claimed_p: Option, + claimed_q: Option, + p: &[F], + q: &[F], + transcript: &mut impl FieldTranscriptWrite, +) -> Result<(F, F, F, F, Vec), Error> { + assert_eq!(p.len(), q.len()); + assert!(p.len().is_power_of_two()); + + let num_threads = num_threads(); + + let initial_layer = Layer::initial(p, q); + let layers = iter::successors(Some(initial_layer), |layer| { + let len = 1 << layer.num_vars(); + let chunk_size = div_ceil(len, num_threads).next_power_of_two(); + (len > 1).then(|| { + let mut outputs: [_; 4] = array::from_fn(|_| vec![F::ZERO; len >> 1]); + let (p, q) = outputs.split_at_mut(2); + parallelize_iter( + izip!( + chain![p].flat_map(|p| p.chunks_mut(chunk_size)), + chain![q].flat_map(|q| q.chunks_mut(chunk_size)), + layer.poly_chunks(chunk_size), + ), + |(p, q, (p_l, p_r, q_l, q_r))| { + izip!(p, q, p_l, p_r, q_l, q_r).for_each(|(p, q, p_l, p_r, q_l, q_r)| { + *p = *p_l * q_r + *p_r * q_l; + *q = *q_l * q_r; + }) + }, + ); + outputs.into() + }) + }) + .collect_vec(); + + let [claimed_p, claimed_q]: [_; 2] = { + let [p_l, p_r, q_l, q_r] = layers.last().unwrap().polys().map(|poly| poly[0]); + let (p, q) = (p_l * q_r + p_r * q_l, q_l * q_r); + + [(claimed_p, p), (claimed_q, q)] + .into_iter() + .map(|(claimed, computed)| match claimed { + Some(claimed) => { + if cfg!(feature = "sanity-check") { + assert_eq!(claimed, computed) + } + transcript.common_field_element(&computed).map(|_| claimed) + } + None => transcript.write_field_element(&computed).map(|_| computed), + }) + .try_collect::<_, Vec<_>, _>()? + .try_into() + .unwrap() + }; + + let expression = { + let [p_l, p_r, q_l, q_r] = + &array::from_fn(|idx| Expression::Polynomial(Query::new(idx, Rotation::cur()))); + let eq_xy = &Expression::eq_xy(0); + let gamma = &Expression::Challenge(0); + (p_l * q_r + p_r * q_l + gamma * q_l * q_r) * eq_xy + }; + + let (p, q, challenges) = + layers + .iter() + .rev() + .fold(Ok((claimed_p, claimed_q, Vec::new())), |result, layer| { + let (claimed_p, claimed_q, y) = result?; + let num_vars = layer.num_vars(); + + let (mut challenges, evals) = if num_vars == 0 { + (vec![], layer.polys().map(|poly| poly[0])) + } else { + let gamma = transcript.squeeze_challenge(); + + let claim = claimed_p + gamma * claimed_q; + let (challenges, evals) = ClassicSumCheck::>::prove( + &(), + num_vars, + VirtualPolynomial::new(&expression, layer.polys(), &[gamma], &[y]), + claim, + transcript, + )?; + + (challenges, evals.try_into().unwrap()) + }; + + transcript.write_field_elements(&evals)?; + + let mu = transcript.squeeze_challenge(); + + let [p_l, p_r, q_l, q_r] = evals; + let p = p_l + mu * (p_r - p_l); + let q = q_l + mu * (q_r - q_l); + challenges.push(mu); + + Ok((p, q, challenges)) + })?; + + if cfg!(feature = "sanity-check") { + let [p_l, p_r, q_l, q_r] = layers[0].polys().map(|poly| poly.evals().to_vec()); + let p_poly = MultilinearPolynomial::new([p_l, p_r].concat()); + let q_poly = MultilinearPolynomial::new([q_l, q_r].concat()); + assert_eq!(p_poly.evaluate(&challenges), p); + assert_eq!(q_poly.evaluate(&challenges), q); + } + + Ok((claimed_p, claimed_q, p, q, challenges)) +} + +pub fn verify_fractional_sum( + num_vars: usize, + claimed_p: Option, + claimed_q: Option, + transcript: &mut impl FieldTranscriptRead, +) -> Result<(F, F, F, F, Vec), Error> { + let [claimed_p, claimed_q]: [_; 2] = { + [claimed_p, claimed_q] + .into_iter() + .map(|claimed| match claimed { + Some(claimed) => transcript.common_field_element(&claimed).map(|_| claimed), + None => transcript.read_field_element(), + }) + .try_collect::<_, Vec<_>, _>()? + .try_into() + .unwrap() + }; + + let (p, q, challenges) = (0..num_vars).fold( + Ok((claimed_p, claimed_q, Vec::new())), + |result, num_vars| { + let (claimed_p, claimed_q, y) = result?; + + let (mut challenges, evals) = if num_vars == 0 { + let evals: [_; 4] = transcript.read_field_elements(4)?.try_into().unwrap(); + let [p_l, p_r, q_l, q_r] = evals; + + if claimed_p != p_l * q_r + p_r * q_l || claimed_q != q_l * q_r { + return Err(err_unmatched_sum_check_output()); + } + + (Vec::new(), evals) + } else { + let gamma = transcript.squeeze_challenge(); + + let claim = claimed_p + gamma * claimed_q; + let (eval, challenges) = ClassicSumCheck::>::verify( + &(), + num_vars, + 3, + claim, + transcript, + )?; + + let evals: [_; 4] = transcript.read_field_elements(4)?.try_into().unwrap(); + let [p_l, p_r, q_l, q_r] = evals; + + if eval != (p_l * q_r + p_r * q_l + gamma * q_l * q_r) * eq_xy_eval(&challenges, &y) + { + return Err(err_unmatched_sum_check_output()); + } + + (challenges, evals) + }; + + let mu = transcript.squeeze_challenge(); + + let [p_l, p_r, q_l, q_r] = evals; + let p = p_l + mu * (p_r - p_l); + let q = q_l + mu * (q_r - q_l); + challenges.push(mu); + + Ok((p, q, challenges)) + }, + )?; + + Ok((claimed_p, claimed_q, p, q, challenges)) +} + +fn err_unmatched_sum_check_output() -> Error { + Error::InvalidSumcheck("Unmatched between sum_check output and query evaluation".to_string()) +} + +#[cfg(test)] +mod test { + use crate::{ + piop::gkr::fractional_sum::{prove_fractional_sum, verify_fractional_sum}, + util::{ + test::{rand_vec, seeded_std_rng}, + transcript::{InMemoryTranscript, Keccak256Transcript}, + }, + }; + use halo2_curves::bn256::Fr; + + #[test] + fn fractional_sum() { + for num_vars in 1..16 { + let mut rng = seeded_std_rng(); + + let p = rand_vec(1 << num_vars, &mut rng); + let q = rand_vec(1 << num_vars, &mut rng); + + let proof = { + let mut transcript = Keccak256Transcript::new(()); + prove_fractional_sum::(None, None, &p, &q, &mut transcript).unwrap(); + transcript.into_proof() + }; + + let result = { + let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice()); + verify_fractional_sum::(num_vars, None, None, &mut transcript) + }; + assert_eq!(result.map(|_| ()), Ok(())); + } + } +} diff --git a/plonkish_backend/src/piop/sum_check/classic/coeff.rs b/plonkish_backend/src/piop/sum_check/classic/coeff.rs index 0635d2ea..78c39c76 100644 --- a/plonkish_backend/src/piop/sum_check/classic/coeff.rs +++ b/plonkish_backend/src/piop/sum_check/classic/coeff.rs @@ -1,17 +1,17 @@ use crate::{ piop::sum_check::classic::{ClassicSumCheckProver, ClassicSumCheckRoundMessage, ProverState}, - poly::multilinear::zip_self, + poly::multilinear::{zip_self, MultilinearPolynomial}, util::{ arithmetic::{div_ceil, horner, PrimeField}, expression::{CommonPolynomial, Expression, Rotation}, - impl_index, + impl_index, izip_eq, parallel::{num_threads, parallelize_iter}, transcript::{FieldTranscriptRead, FieldTranscriptWrite}, Itertools, }, Error, }; -use std::{fmt::Debug, iter, ops::AddAssign}; +use std::{array, fmt::Debug, iter, ops::AddAssign}; #[derive(Debug)] pub struct Coefficients(Vec); @@ -63,7 +63,10 @@ impl<'rhs, F: PrimeField> AddAssign<(&'rhs F, &'rhs Coefficients)> for Coeffi impl_index!(Coefficients, 0); #[derive(Clone, Debug)] -pub struct CoefficientsProver(F, Vec<(F, Vec>)>); +pub struct CoefficientsProver { + constant: F, + products: Vec<(F, Vec>)>, +} impl ClassicSumCheckProver for CoefficientsProver where @@ -72,7 +75,7 @@ where type RoundMessage = Coefficients; fn new(state: &ProverState) -> Self { - let (constant, flattened) = state.expression.evaluate( + let (constant, products) = state.expression.evaluate( &|constant| (constant, vec![]), &|poly| { ( @@ -127,21 +130,21 @@ where (constant * &rhs, products) }, ); - Self(constant, flattened) + Self { constant, products } } fn prove_round(&self, state: &ProverState) -> Self::RoundMessage { let mut coeffs = Coefficients(vec![F::ZERO; state.expression.degree() + 1]); - coeffs += &(F::from(state.size() as u64) * &self.0); - if self.1.iter().all(|(_, products)| products.len() == 2) { - for (scalar, products) in self.1.iter() { - let [lhs, rhs] = [0, 1].map(|idx| &products[idx]); - coeffs += (scalar, &self.karatsuba::(state, lhs, rhs)); + coeffs += &(F::from(state.size() as u64) * &self.constant); + + for (scalar, products) in self.products.iter() { + match products.len() { + 2 => coeffs += (scalar, &self.karatsuba::(state, products)), + _ => unimplemented!(), } - coeffs[1] = state.sum - coeffs[0].double() - coeffs[2]; - } else { - unimplemented!() } + + coeffs[1] = state.sum - coeffs.sum(); coeffs } } @@ -150,60 +153,64 @@ impl CoefficientsProver { fn karatsuba( &self, state: &ProverState, - lhs: &Expression, - rhs: &Expression, + items: &[Expression], ) -> Coefficients { - let mut coeffs = [F::ZERO; 3]; - match (lhs, rhs) { - ( - Expression::CommonPolynomial(CommonPolynomial::EqXY(idx)), - Expression::Polynomial(query), + debug_assert_eq!(items.len(), 2); + + let [lhs, rhs] = array::from_fn(|idx| poly(state, &items[idx])); + let evaluate_serial = |coeffs: &mut [F; 3], start: usize, n: usize| { + izip_eq!( + zip_self!(lhs.iter(), 2, start), + zip_self!(rhs.iter(), 2, start) ) - | ( - Expression::Polynomial(query), - Expression::CommonPolynomial(CommonPolynomial::EqXY(idx)), - ) if query.rotation() == Rotation::cur() => { - let lhs = &state.eq_xys[*idx]; - let rhs = &state.polys[query.poly()][state.num_vars]; - - let evaluate_serial = |coeffs: &mut [F; 3], start: usize, n: usize| { - zip_self!(lhs.iter(), 2, start) - .zip(zip_self!(rhs.iter(), 2, start)) - .take(n) - .for_each(|((lhs_0, lhs_1), (rhs_0, rhs_1))| { - let coeff_0 = *lhs_0 * rhs_0; - let coeff_2 = (*lhs_1 - lhs_0) * &(*rhs_1 - rhs_0); - coeffs[0] += &coeff_0; - coeffs[2] += &coeff_2; - if !LAZY { - coeffs[1] += &(*lhs_1 * rhs_1 - &coeff_0 - &coeff_2); - } - }); - }; - - let num_threads = num_threads(); - if state.size() < num_threads { - evaluate_serial(&mut coeffs, 0, state.size()); - } else { - let chunk_size = div_ceil(state.size(), num_threads); - let mut partials = vec![[F::ZERO; 3]; num_threads]; - parallelize_iter( - partials.iter_mut().zip((0..).step_by(chunk_size << 1)), - |(partial, start)| { - evaluate_serial(partial, start, chunk_size); - }, - ); - partials.iter().for_each(|partial| { - coeffs[0] += partial[0]; - coeffs[2] += partial[2]; - if !LAZY { - coeffs[1] += partial[1]; - } - }) - }; - } - _ => unimplemented!(), - } + .take(n) + .for_each(|((lhs_0, lhs_1), (rhs_0, rhs_1))| { + let eval_0 = *lhs_0 * rhs_0; + let eval_2 = (*lhs_1 - lhs_0) * &(*rhs_1 - rhs_0); + coeffs[0] += &eval_0; + coeffs[2] += &eval_2; + if !LAZY { + coeffs[1] += &(*lhs_1 * rhs_1 - &eval_0 - &eval_2); + } + }); + }; + + let mut coeffs = [F::ZERO; 3]; + + let num_threads = num_threads(); + if state.size() < 16 { + evaluate_serial(&mut coeffs, 0, state.size()); + } else { + let chunk_size = div_ceil(state.size(), num_threads); + let mut partials = vec![[F::ZERO; 3]; num_threads]; + parallelize_iter( + partials.iter_mut().zip((0..).step_by(chunk_size << 1)), + |(partial, start)| { + evaluate_serial(partial, start, chunk_size); + }, + ); + partials.iter().for_each(|partial| { + coeffs[0] += partial[0]; + coeffs[2] += partial[2]; + if !LAZY { + coeffs[1] += partial[1]; + } + }) + }; + Coefficients(coeffs.to_vec()) } } + +fn poly<'a, F: PrimeField>( + state: &'a ProverState, + expr: &Expression, +) -> &'a MultilinearPolynomial { + match expr { + Expression::CommonPolynomial(CommonPolynomial::EqXY(idx)) => &state.eq_xys[*idx], + Expression::Polynomial(query) if query.rotation() == Rotation::cur() => { + &state.polys[query.poly()][state.num_vars] + } + _ => unimplemented!(), + } +} From b8416df81c16c7d0413d343e13ad32181ffef282 Mon Sep 17 00:00:00 2001 From: han0110 Date: Wed, 27 Sep 2023 09:23:07 +0000 Subject: [PATCH 2/3] feat: make `fractional_sum` support batching --- .../src/piop/gkr/fractional_sum.rs | 349 +++++++++++------- 1 file changed, 214 insertions(+), 135 deletions(-) diff --git a/plonkish_backend/src/piop/gkr/fractional_sum.rs b/plonkish_backend/src/piop/gkr/fractional_sum.rs index 6f4f9f47..47c5ce26 100644 --- a/plonkish_backend/src/piop/gkr/fractional_sum.rs +++ b/plonkish_backend/src/piop/gkr/fractional_sum.rs @@ -6,11 +6,11 @@ use crate::{ piop::sum_check::{ classic::{ClassicSumCheck, EvaluationsProver}, - eq_xy_eval, SumCheck, VirtualPolynomial, + evaluate, SumCheck as _, VirtualPolynomial, }, poly::{multilinear::MultilinearPolynomial, Polynomial}, util::{ - arithmetic::{div_ceil, PrimeField}, + arithmetic::{div_ceil, inner_product, powers, PrimeField}, chain, expression::{Expression, Query, Rotation}, izip, @@ -20,7 +20,9 @@ use crate::{ }, Error, }; -use std::{array, iter}; +use std::{array, collections::HashMap, iter}; + +type SumCheck = ClassicSumCheck>; struct Layer { p_l: MultilinearPolynomial, @@ -37,8 +39,8 @@ impl From<[Vec; 4]> for Layer { } impl Layer { - fn initial(p: &[F], q: &[F]) -> Self { - let mid = p.len() >> 1; + fn bottom((p, q): (&&MultilinearPolynomial, &&MultilinearPolynomial)) -> Self { + let mid = p.evals().len() >> 1; [&p[..mid], &p[mid..], &q[..mid], &q[mid..]] .map(ToOwned::to_owned) .into() @@ -56,188 +58,249 @@ impl Layer { let [p_l, p_r, q_l, q_r] = self.polys().map(|poly| poly.evals().chunks(chunk_size)); izip!(p_l, p_r, q_l, q_r) } + + fn up(&self) -> Self { + assert!(self.num_vars() != 0); + + let len = 1 << self.num_vars(); + let chunk_size = div_ceil(len, num_threads()).next_power_of_two(); + + let mut outputs: [_; 4] = array::from_fn(|_| vec![F::ZERO; len >> 1]); + let (p, q) = outputs.split_at_mut(2); + parallelize_iter( + izip!( + chain![p].flat_map(|p| p.chunks_mut(chunk_size)), + chain![q].flat_map(|q| q.chunks_mut(chunk_size)), + self.poly_chunks(chunk_size), + ), + |(p, q, (p_l, p_r, q_l, q_r))| { + izip!(p, q, p_l, p_r, q_l, q_r).for_each(|(p, q, p_l, p_r, q_l, q_r)| { + *p = *p_l * q_r + *p_r * q_l; + *q = *q_l * q_r; + }) + }, + ); + + outputs.into() + } } -pub fn prove_fractional_sum( - claimed_p: Option, - claimed_q: Option, - p: &[F], - q: &[F], +#[allow(clippy::type_complexity)] +pub fn prove_fractional_sum<'a, F: PrimeField>( + claimed_p_0s: impl IntoIterator>, + claimed_q_0s: impl IntoIterator>, + ps: impl IntoIterator>, + qs: impl IntoIterator>, transcript: &mut impl FieldTranscriptWrite, -) -> Result<(F, F, F, F, Vec), Error> { - assert_eq!(p.len(), q.len()); - assert!(p.len().is_power_of_two()); - - let num_threads = num_threads(); - - let initial_layer = Layer::initial(p, q); - let layers = iter::successors(Some(initial_layer), |layer| { - let len = 1 << layer.num_vars(); - let chunk_size = div_ceil(len, num_threads).next_power_of_two(); - (len > 1).then(|| { - let mut outputs: [_; 4] = array::from_fn(|_| vec![F::ZERO; len >> 1]); - let (p, q) = outputs.split_at_mut(2); - parallelize_iter( - izip!( - chain![p].flat_map(|p| p.chunks_mut(chunk_size)), - chain![q].flat_map(|q| q.chunks_mut(chunk_size)), - layer.poly_chunks(chunk_size), - ), - |(p, q, (p_l, p_r, q_l, q_r))| { - izip!(p, q, p_l, p_r, q_l, q_r).for_each(|(p, q, p_l, p_r, q_l, q_r)| { - *p = *p_l * q_r + *p_r * q_l; - *q = *q_l * q_r; - }) - }, - ); - outputs.into() - }) +) -> Result<(Vec, Vec, Vec), Error> { + let claimed_p_0s = claimed_p_0s.into_iter().collect_vec(); + let claimed_q_0s = claimed_q_0s.into_iter().collect_vec(); + let ps = ps.into_iter().collect_vec(); + let qs = qs.into_iter().collect_vec(); + let num_batching = claimed_p_0s.len(); + + assert!(num_batching != 0); + assert_eq!(num_batching, claimed_q_0s.len()); + assert_eq!(num_batching, ps.len()); + assert_eq!(num_batching, qs.len()); + for poly in chain![&ps, &qs] { + assert_eq!(poly.num_vars(), ps[0].num_vars()); + } + + let bottom_layers = izip!(&ps, &qs).map(Layer::bottom).collect_vec(); + let layers = iter::successors(bottom_layers.into(), |layers| { + (layers[0].num_vars() > 0).then(|| layers.iter().map(Layer::up).collect()) }) .collect_vec(); - let [claimed_p, claimed_q]: [_; 2] = { - let [p_l, p_r, q_l, q_r] = layers.last().unwrap().polys().map(|poly| poly[0]); - let (p, q) = (p_l * q_r + p_r * q_l, q_l * q_r); - - [(claimed_p, p), (claimed_q, q)] - .into_iter() - .map(|(claimed, computed)| match claimed { - Some(claimed) => { - if cfg!(feature = "sanity-check") { - assert_eq!(claimed, computed) - } - transcript.common_field_element(&computed).map(|_| claimed) - } - None => transcript.write_field_element(&computed).map(|_| computed), + let [claimed_p_0s, claimed_q_0s]: [_; 2] = { + let (p_0s, q_0s) = chain![layers.last().unwrap()] + .map(|layer| { + let [p_l, p_r, q_l, q_r] = layer.polys().map(|poly| poly[0]); + (p_l * q_r + p_r * q_l, q_l * q_r) }) - .try_collect::<_, Vec<_>, _>()? - .try_into() - .unwrap() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let mut hash_to_transcript = |claimed: Vec<_>, computed: Vec<_>| { + izip!(claimed, computed) + .map(|(claimed, computed)| match claimed { + Some(claimed) => { + if cfg!(feature = "sanity-check") { + assert_eq!(claimed, computed) + } + transcript.common_field_element(&computed).map(|_| computed) + } + None => transcript.write_field_element(&computed).map(|_| computed), + }) + .try_collect::<_, Vec<_>, _>() + }; + + [ + hash_to_transcript(claimed_p_0s, p_0s)?, + hash_to_transcript(claimed_q_0s, q_0s)?, + ] }; - let expression = { - let [p_l, p_r, q_l, q_r] = - &array::from_fn(|idx| Expression::Polynomial(Query::new(idx, Rotation::cur()))); - let eq_xy = &Expression::eq_xy(0); - let gamma = &Expression::Challenge(0); - (p_l * q_r + p_r * q_l + gamma * q_l * q_r) * eq_xy - }; + let expression = sum_check_expression(num_batching); + + let (p_xs, q_xs, x) = layers.iter().rev().fold( + Ok((claimed_p_0s, claimed_q_0s, Vec::new())), + |result, layers| { + let (claimed_p_ys, claimed_q_ys, y) = result?; + + let num_vars = layers[0].num_vars(); + let polys = layers.iter().flat_map(|layer| layer.polys()); - let (p, q, challenges) = - layers - .iter() - .rev() - .fold(Ok((claimed_p, claimed_q, Vec::new())), |result, layer| { - let (claimed_p, claimed_q, y) = result?; - let num_vars = layer.num_vars(); - - let (mut challenges, evals) = if num_vars == 0 { - (vec![], layer.polys().map(|poly| poly[0])) - } else { - let gamma = transcript.squeeze_challenge(); - - let claim = claimed_p + gamma * claimed_q; - let (challenges, evals) = ClassicSumCheck::>::prove( + let (mut x, evals) = if num_vars == 0 { + (vec![], polys.map(|poly| poly[0]).collect_vec()) + } else { + let gamma = transcript.squeeze_challenge(); + + let (x, evals) = { + let claim = sum_check_claim(&claimed_p_ys, &claimed_q_ys, gamma); + SumCheck::prove( &(), num_vars, - VirtualPolynomial::new(&expression, layer.polys(), &[gamma], &[y]), + VirtualPolynomial::new(&expression, polys, &[gamma], &[y]), claim, transcript, - )?; - - (challenges, evals.try_into().unwrap()) + )? }; - transcript.write_field_elements(&evals)?; + (x, evals) + }; + + transcript.write_field_elements(&evals)?; - let mu = transcript.squeeze_challenge(); + let mu = transcript.squeeze_challenge(); - let [p_l, p_r, q_l, q_r] = evals; - let p = p_l + mu * (p_r - p_l); - let q = q_l + mu * (q_r - q_l); - challenges.push(mu); + let (p_xs, q_xs) = layer_down_claim(&evals, mu); + x.push(mu); - Ok((p, q, challenges)) - })?; + Ok((p_xs, q_xs, x)) + }, + )?; if cfg!(feature = "sanity-check") { - let [p_l, p_r, q_l, q_r] = layers[0].polys().map(|poly| poly.evals().to_vec()); - let p_poly = MultilinearPolynomial::new([p_l, p_r].concat()); - let q_poly = MultilinearPolynomial::new([q_l, q_r].concat()); - assert_eq!(p_poly.evaluate(&challenges), p); - assert_eq!(q_poly.evaluate(&challenges), q); + izip!(chain![ps, qs], chain![&p_xs, &q_xs]) + .for_each(|(poly, eval)| assert_eq!(poly.evaluate(&x), *eval)); } - Ok((claimed_p, claimed_q, p, q, challenges)) + Ok((p_xs, q_xs, x)) } +#[allow(clippy::type_complexity)] pub fn verify_fractional_sum( num_vars: usize, - claimed_p: Option, - claimed_q: Option, + claimed_p_0s: impl IntoIterator>, + claimed_q_0s: impl IntoIterator>, transcript: &mut impl FieldTranscriptRead, -) -> Result<(F, F, F, F, Vec), Error> { - let [claimed_p, claimed_q]: [_; 2] = { - [claimed_p, claimed_q] +) -> Result<(Vec, Vec, Vec), Error> { + let claimed_p_0s = claimed_p_0s.into_iter().collect_vec(); + let claimed_q_0s = claimed_q_0s.into_iter().collect_vec(); + let num_batching = claimed_p_0s.len(); + + assert!(num_batching != 0); + assert_eq!(num_batching, claimed_q_0s.len()); + + let [claimed_p_0s, claimed_q_0s]: [_; 2] = { + [claimed_p_0s, claimed_q_0s] .into_iter() - .map(|claimed| match claimed { - Some(claimed) => transcript.common_field_element(&claimed).map(|_| claimed), - None => transcript.read_field_element(), + .map(|claimed| { + claimed + .into_iter() + .map(|claimed| match claimed { + Some(claimed) => transcript.common_field_element(&claimed).map(|_| claimed), + None => transcript.read_field_element(), + }) + .try_collect::<_, Vec<_>, _>() }) .try_collect::<_, Vec<_>, _>()? .try_into() .unwrap() }; - let (p, q, challenges) = (0..num_vars).fold( - Ok((claimed_p, claimed_q, Vec::new())), + let expression = sum_check_expression(num_batching); + + let (p_xs, q_xs, x) = (0..num_vars).fold( + Ok((claimed_p_0s, claimed_q_0s, Vec::new())), |result, num_vars| { - let (claimed_p, claimed_q, y) = result?; + let (claimed_p_ys, claimed_q_ys, y) = result?; - let (mut challenges, evals) = if num_vars == 0 { - let evals: [_; 4] = transcript.read_field_elements(4)?.try_into().unwrap(); - let [p_l, p_r, q_l, q_r] = evals; + let (mut x, evals) = if num_vars == 0 { + let evals = transcript.read_field_elements(4 * num_batching)?; - if claimed_p != p_l * q_r + p_r * q_l || claimed_q != q_l * q_r { - return Err(err_unmatched_sum_check_output()); + for (claimed_p, claimed_q, (&p_l, &p_r, &q_l, &q_r)) in + izip!(claimed_p_ys, claimed_q_ys, evals.iter().tuples()) + { + if claimed_p != p_l * q_r + p_r * q_l || claimed_q != q_l * q_r { + return Err(err_unmatched_sum_check_output()); + } } (Vec::new(), evals) } else { let gamma = transcript.squeeze_challenge(); - let claim = claimed_p + gamma * claimed_q; - let (eval, challenges) = ClassicSumCheck::>::verify( - &(), - num_vars, - 3, - claim, - transcript, - )?; + let (x_eval, x) = { + let claim = sum_check_claim(&claimed_p_ys, &claimed_q_ys, gamma); + SumCheck::verify(&(), num_vars, expression.degree(), claim, transcript)? + }; - let evals: [_; 4] = transcript.read_field_elements(4)?.try_into().unwrap(); - let [p_l, p_r, q_l, q_r] = evals; + let evals = transcript.read_field_elements(4 * num_batching)?; - if eval != (p_l * q_r + p_r * q_l + gamma * q_l * q_r) * eq_xy_eval(&challenges, &y) - { + let eval_by_query = eval_by_query(&evals); + if x_eval != evaluate(&expression, num_vars, &eval_by_query, &[gamma], &[&y], &x) { return Err(err_unmatched_sum_check_output()); } - (challenges, evals) + (x, evals) }; let mu = transcript.squeeze_challenge(); - let [p_l, p_r, q_l, q_r] = evals; - let p = p_l + mu * (p_r - p_l); - let q = q_l + mu * (q_r - q_l); - challenges.push(mu); + let (p_xs, q_xs) = layer_down_claim(&evals, mu); + x.push(mu); - Ok((p, q, challenges)) + Ok((p_xs, q_xs, x)) }, )?; - Ok((claimed_p, claimed_q, p, q, challenges)) + Ok((p_xs, q_xs, x)) +} + +fn sum_check_expression(num_batching: usize) -> Expression { + let exprs = &(0..4 * num_batching) + .map(|idx| Expression::::Polynomial(Query::new(idx, Rotation::cur()))) + .tuples() + .flat_map(|(ref p_l, ref p_r, ref q_l, ref q_r)| [p_l * q_r + p_r * q_l, q_l * q_r]) + .collect_vec(); + let eq_xy = &Expression::eq_xy(0); + let gamma = &Expression::Challenge(0); + Expression::distribute_powers(exprs, gamma) * eq_xy +} + +fn sum_check_claim(claimed_p_ys: &[F], claimed_q_ys: &[F], gamma: F) -> F { + inner_product( + izip!(claimed_p_ys, claimed_q_ys).flat_map(|(p, q)| [p, q]), + &powers(gamma).take(claimed_p_ys.len() * 2).collect_vec(), + ) +} + +fn layer_down_claim(evals: &[F], mu: F) -> (Vec, Vec) { + evals + .iter() + .tuples() + .map(|(&p_l, &p_r, &q_l, &q_r)| (p_l + mu * (p_r - p_l), q_l + mu * (q_r - q_l))) + .unzip() +} + +fn eval_by_query(evals: &[F]) -> HashMap { + izip!( + (0..).map(|idx| Query::new(idx, Rotation::cur())), + evals.iter().cloned() + ) + .collect() } fn err_unmatched_sum_check_output() -> Error { @@ -248,32 +311,48 @@ fn err_unmatched_sum_check_output() -> Error { mod test { use crate::{ piop::gkr::fractional_sum::{prove_fractional_sum, verify_fractional_sum}, + poly::multilinear::MultilinearPolynomial, util::{ + chain, izip_eq, test::{rand_vec, seeded_std_rng}, transcript::{InMemoryTranscript, Keccak256Transcript}, + Itertools, }, }; use halo2_curves::bn256::Fr; + use std::iter; #[test] fn fractional_sum() { + let num_batching = 3; for num_vars in 1..16 { let mut rng = seeded_std_rng(); - let p = rand_vec(1 << num_vars, &mut rng); - let q = rand_vec(1 << num_vars, &mut rng); + let polys = iter::repeat_with(|| rand_vec(1 << num_vars, &mut rng)) + .map(MultilinearPolynomial::new) + .take(2 * num_batching) + .collect_vec(); + let claims = vec![None; 2 * num_batching]; + let (ps, qs) = polys.split_at(num_batching); + let (p_0s, q_0s) = claims.split_at(num_batching); let proof = { let mut transcript = Keccak256Transcript::new(()); - prove_fractional_sum::(None, None, &p, &q, &mut transcript).unwrap(); + prove_fractional_sum::(p_0s.to_vec(), q_0s.to_vec(), ps, qs, &mut transcript) + .unwrap(); transcript.into_proof() }; let result = { let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice()); - verify_fractional_sum::(num_vars, None, None, &mut transcript) + verify_fractional_sum::(num_vars, p_0s.to_vec(), q_0s.to_vec(), &mut transcript) }; - assert_eq!(result.map(|_| ()), Ok(())); + assert_eq!(result.as_ref().map(|_| ()), Ok(())); + + let (p_xs, q_xs, x) = result.unwrap(); + for (poly, eval) in izip_eq!(chain![ps, qs], chain![p_xs, q_xs]) { + assert_eq!(poly.evaluate(&x), eval); + } } } } From 745501e8eb4aaa0939ff348cc4988f375f5c56eb Mon Sep 17 00:00:00 2001 From: han0110 Date: Wed, 27 Sep 2023 11:25:36 +0000 Subject: [PATCH 3/3] chore: rename --- plonkish_backend/src/piop/gkr.rs | 4 +-- ...ctional_sum.rs => fractional_sum_check.rs} | 27 ++++++++++++++----- 2 files changed, 22 insertions(+), 9 deletions(-) rename plonkish_backend/src/piop/gkr/{fractional_sum.rs => fractional_sum_check.rs} (94%) diff --git a/plonkish_backend/src/piop/gkr.rs b/plonkish_backend/src/piop/gkr.rs index 3b5f011f..b26907c7 100644 --- a/plonkish_backend/src/piop/gkr.rs +++ b/plonkish_backend/src/piop/gkr.rs @@ -1,3 +1,3 @@ -mod fractional_sum; +mod fractional_sum_check; -pub use fractional_sum::{prove_fractional_sum, verify_fractional_sum}; +pub use fractional_sum_check::{prove_fractional_sum_check, verify_fractional_sum_check}; diff --git a/plonkish_backend/src/piop/gkr/fractional_sum.rs b/plonkish_backend/src/piop/gkr/fractional_sum_check.rs similarity index 94% rename from plonkish_backend/src/piop/gkr/fractional_sum.rs rename to plonkish_backend/src/piop/gkr/fractional_sum_check.rs index 47c5ce26..5e16213e 100644 --- a/plonkish_backend/src/piop/gkr/fractional_sum.rs +++ b/plonkish_backend/src/piop/gkr/fractional_sum_check.rs @@ -86,7 +86,7 @@ impl Layer { } #[allow(clippy::type_complexity)] -pub fn prove_fractional_sum<'a, F: PrimeField>( +pub fn prove_fractional_sum_check<'a, F: PrimeField>( claimed_p_0s: impl IntoIterator>, claimed_q_0s: impl IntoIterator>, ps: impl IntoIterator>, @@ -190,7 +190,7 @@ pub fn prove_fractional_sum<'a, F: PrimeField>( } #[allow(clippy::type_complexity)] -pub fn verify_fractional_sum( +pub fn verify_fractional_sum_check( num_vars: usize, claimed_p_0s: impl IntoIterator>, claimed_q_0s: impl IntoIterator>, @@ -310,7 +310,9 @@ fn err_unmatched_sum_check_output() -> Error { #[cfg(test)] mod test { use crate::{ - piop::gkr::fractional_sum::{prove_fractional_sum, verify_fractional_sum}, + piop::gkr::fractional_sum_check::{ + prove_fractional_sum_check, verify_fractional_sum_check, + }, poly::multilinear::MultilinearPolynomial, util::{ chain, izip_eq, @@ -323,7 +325,7 @@ mod test { use std::iter; #[test] - fn fractional_sum() { + fn fractional_sum_check() { let num_batching = 3; for num_vars in 1..16 { let mut rng = seeded_std_rng(); @@ -338,14 +340,25 @@ mod test { let proof = { let mut transcript = Keccak256Transcript::new(()); - prove_fractional_sum::(p_0s.to_vec(), q_0s.to_vec(), ps, qs, &mut transcript) - .unwrap(); + prove_fractional_sum_check::( + p_0s.to_vec(), + q_0s.to_vec(), + ps, + qs, + &mut transcript, + ) + .unwrap(); transcript.into_proof() }; let result = { let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice()); - verify_fractional_sum::(num_vars, p_0s.to_vec(), q_0s.to_vec(), &mut transcript) + verify_fractional_sum_check::( + num_vars, + p_0s.to_vec(), + q_0s.to_vec(), + &mut transcript, + ) }; assert_eq!(result.as_ref().map(|_| ()), Ok(()));