Skip to content

Commit

Permalink
Refactor FRI implementation (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored Oct 20, 2024
1 parent 8095ddb commit 3948dca
Show file tree
Hide file tree
Showing 10 changed files with 187 additions and 358 deletions.
17 changes: 7 additions & 10 deletions stwo_cairo_verifier/src/channel.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,12 @@ pub impl ChannelImpl of ChannelTrait {
fn draw_random_bytes(ref self: Channel) -> Array<u8> {
let mut cur: u256 = self.draw_felt252().into();
let mut bytes = array![];
let mut i: usize = 0;
while i < 31 {
let (q, r) = DivRem::div_rem(cur, 256);
bytes.append(r.try_into().unwrap());
cur = q;
i += 1;
};
for _ in 0_usize
..31 {
let (q, r) = DivRem::div_rem(cur, 256);
bytes.append(r.try_into().unwrap());
cur = q;
};
bytes
}
}
Expand Down Expand Up @@ -224,9 +223,7 @@ mod tests {
let initial_digest = 0;
let mut channel = ChannelTrait::new(initial_digest);

let mut n: usize = 10;
while n > 0 {
n -= 1;
for _ in 0_usize..10 {
channel.draw_felt();
};

Expand Down
16 changes: 12 additions & 4 deletions stwo_cairo_verifier/src/circle.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -247,16 +247,15 @@ mod tests {

#[test]
fn test_add_1() {
let i = CirclePoint { x: m31(0), y: m31(1) };
let result = i + i;

assert_eq!(result, CirclePoint { x: -m31(1), y: m31(0) });
let g4 = CirclePoint { x: m31(0), y: m31(1) };
assert_eq!(g4 + g4, CirclePoint { x: -m31(1), y: m31(0) });
}

#[test]
fn test_add_2() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };
let point_2 = CirclePoint { x: m31(1737427771), y: m31(309481134) };

let result = point_1 + point_2;

assert_eq!(result, CirclePoint { x: m31(1476625263), y: m31(1040927458) });
Expand All @@ -273,6 +272,7 @@ mod tests {
fn test_zero_2() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };
let point_2 = CirclePointM31Impl::zero();

let result = point_1 + point_2;

assert_eq!(result, point_1.clone());
Expand All @@ -281,6 +281,7 @@ mod tests {
#[test]
fn test_mul_1() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };

let result = point_1.mul(5);

assert_eq!(result, point_1 + point_1 + point_1 + point_1 + point_1);
Expand All @@ -289,6 +290,7 @@ mod tests {
#[test]
fn test_mul_2() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };

let result = point_1.mul(8);

assert_eq!(
Expand All @@ -299,6 +301,7 @@ mod tests {
#[test]
fn test_mul_3() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };

let result = point_1.mul(418776494);

assert_eq!(result, CirclePoint { x: m31(1987283985), y: m31(1500510905) });
Expand All @@ -307,6 +310,7 @@ mod tests {
#[test]
fn test_generator_order() {
let half_order = M31_CIRCLE_ORDER / 2;

let mut result = M31_CIRCLE_GEN.mul(half_order.into());

// Assert `M31_CIRCLE_GEN^{2^30}` equals `-1`.
Expand All @@ -327,6 +331,7 @@ mod tests {
log_size: 5,
step_size: CirclePointIndexImpl::new(67108864)
};

let result = coset.index_at(8);

assert_eq!(result, CirclePointIndexImpl::new(553648128));
Expand All @@ -353,6 +358,7 @@ mod tests {
step_size: CirclePointIndexImpl::new(67108864),
log_size: 5
};

let result = coset.double();

assert_eq!(
Expand All @@ -372,6 +378,7 @@ mod tests {
step_size: CirclePointIndexImpl::new(67108864),
log_size: 5
};

let result = coset.at(17);

assert_eq!(result, CirclePoint { x: m31(7144319), y: m31(1742797653) });
Expand All @@ -384,6 +391,7 @@ mod tests {
step_size: CirclePointIndexImpl::new(67108864),
log_size: 5
};

let result = coset.size();

assert_eq!(result, 32);
Expand Down
22 changes: 2 additions & 20 deletions stwo_cairo_verifier/src/fields/qm31.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,11 @@ pub impl QM31Impl of QM31Trait {
QM31 { a: CM31 { a: a, b: b }, b: CM31 { a: c, b: d } }
}

fn from_u32(arr: [u32; 4]) -> QM31 {
let [a, b, c, d] = arr;
let a_mod_p = M31Impl::reduce_u32(a);
let b_mod_p = M31Impl::reduce_u32(b);
let c_mod_p = M31Impl::reduce_u32(c);
let d_mod_p = M31Impl::reduce_u32(d);

QM31 { a: CM31 { a: a_mod_p, b: b_mod_p }, b: CM31 { a: c_mod_p, b: d_mod_p } }
}
#[inline]
fn to_array(self: QM31) -> [M31; 4] {
[self.a.a, self.a.b, self.b.a, self.b.b]
}

fn inverse(self: QM31) -> QM31 {
assert_ne!(self, Zero::zero());
let b2 = self.b * self.b;
Expand All @@ -43,6 +35,7 @@ pub impl QM31Impl of QM31Trait {
let denom_inverse = denom.inverse();
QM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse }
}

fn mul_m31(self: QM31, multiplier: M31) -> QM31 {
QM31 {
a: CM31 { a: self.a.a * multiplier, b: self.a.b * multiplier },
Expand Down Expand Up @@ -139,15 +132,4 @@ mod tests {
assert_eq!(qm1 * m.inverse().into(), qm1 * qm.inverse());
assert_eq!(qm1.mul_m31(m), qm1 * m.into());
}

#[test]
fn test_qm31_from_u32() {
let arr = [2147483648, 2, 3, 4];
let felt = QM31Impl::from_u32(arr);
let expected_felt = QM31 {
a: CM31 { a: m31(1), b: m31(2) }, b: CM31 { a: m31(3), b: m31(4) }
};

assert_eq!(felt, expected_felt)
}
}
93 changes: 45 additions & 48 deletions stwo_cairo_verifier/src/fri.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use stwo_cairo_verifier::channel::{Channel, ChannelTrait};
use stwo_cairo_verifier::circle::CosetImpl;
use stwo_cairo_verifier::fields::m31::M31;
use stwo_cairo_verifier::fields::m31::M31Trait;
use stwo_cairo_verifier::fields::qm31::{QM31, QM31Trait};
use stwo_cairo_verifier::fields::qm31::{QM31, QM31Zero, QM31Trait};
use stwo_cairo_verifier::poly::circle::CircleDomainImpl;
use stwo_cairo_verifier::poly::circle::{
CircleEvaluation, SparseCircleEvaluation, SparseCircleEvaluationImpl
Expand All @@ -15,7 +15,7 @@ use stwo_cairo_verifier::poly::line::{LineDomain, LineDomainImpl};
use stwo_cairo_verifier::poly::line::{LinePoly, LinePolyImpl};
use stwo_cairo_verifier::queries::SparseSubCircleDomain;
use stwo_cairo_verifier::queries::{Queries, QueriesImpl};
use stwo_cairo_verifier::utils::{bit_reverse_index, pow, pow_qm31, qm31_zero_array, find};
use stwo_cairo_verifier::utils::{bit_reverse_index, ArrayImpl, pow, pow_qm31, find};
use stwo_cairo_verifier::vcs::hasher::PoseidonMerkleHasher;
use stwo_cairo_verifier::vcs::verifier::{MerkleDecommitment, MerkleVerifier, MerkleVerifierTrait};

Expand Down Expand Up @@ -48,7 +48,7 @@ impl FriLayerVerifierImpl of FriLayerVerifierTrait {
) -> Result<(Queries, Array<QM31>), FriVerificationError> {
let commitment = self.proof.commitment;

let sparse_evaluation = @self.extract_evaluation(queries, evals_at_queries)?;
let sparse_evaluation = self.extract_evaluation(queries, evals_at_queries)?;
let mut column_0: Array<M31> = array![];
let mut column_1: Array<M31> = array![];
let mut column_2: Array<M31> = array![];
Expand Down Expand Up @@ -335,7 +335,7 @@ pub impl FriVerifierImpl of FriVerifierTrait {
let circle_poly_alpha_sq = *circle_poly_alpha * *circle_poly_alpha;

let mut layer_queries = queries.fold(CIRCLE_TO_LINE_FOLD_STEP);
let mut layer_query_evals = qm31_zero_array(layer_queries.len());
let mut layer_query_evals = ArrayImpl::new_repeated(layer_queries.len(), QM31Zero::zero());

let mut inner_layers_index = 0;
let mut column_bound_index = 0;
Expand Down Expand Up @@ -398,23 +398,23 @@ pub impl FriVerifierImpl of FriVerifierTrait {
fn decommit_last_layer(
self: @FriVerifier, queries: Queries, query_evals: Array<QM31>,
) -> Result<(), FriVerificationError> {
let mut failed = false;
let FriVerifier { last_layer_domain, last_layer_poly, .. } = self;

let mut i = 0;
while i < queries.positions.len() {
let query = queries.positions[i];
let domain = self.last_layer_domain;
let x = self.last_layer_domain.at(bit_reverse_index(*query, domain.log_size()));
loop {
if i == queries.positions.len() {
break Result::Ok(());
}

if *query_evals[i] != self.last_layer_poly.eval_at_point(x.into()) {
failed = true;
break;
let query = *queries.positions[i];
let query_eval = *query_evals[i];
let x = last_layer_domain.at(bit_reverse_index(query, last_layer_domain.log_size()));

if query_eval != last_layer_poly.eval_at_point(x.into()) {
break Result::Err(FriVerificationError::LastLayerEvaluationsInvalid);
}

i += 1;
};
if failed {
return Result::Err(FriVerificationError::LastLayerEvaluationsInvalid);
} else {
Result::Ok(())
}
}

Expand Down Expand Up @@ -483,43 +483,41 @@ fn get_opening_positions(
/// element and `pi(x) = 2x^2 - 1` be the circle's x-coordinate doubling map. This function
/// returns `f' = f0 + alpha * f1` evaluated on `pi(E)` such that `2f(x) = f0(pi(x)) + x *
/// f1(pi(x))`.
pub fn fold_line(eval: @LineEvaluation, alpha: QM31) -> LineEvaluation {
pub fn fold_line(eval: LineEvaluation, alpha: QM31) -> LineEvaluation {
let domain = eval.domain;
let mut values = array![];
let mut i = 0;
while i < eval.values.len() / 2 {
let x = domain.at(bit_reverse_index(i * pow(2, FOLD_STEP), domain.log_size()));
let f_x = eval.values[2 * i];
let f_neg_x = eval.values[2 * i + 1];
let (f0, f1) = ibutterfly(*f_x, *f_neg_x, x.inverse());
values.append(f0 + alpha * f1);
i += 1;
};
for i in 0
..eval.values.len()
/ 2 {
let x = domain.at(bit_reverse_index(i * pow(2, FOLD_STEP), domain.log_size()));
let f_x = eval.values[2 * i];
let f_neg_x = eval.values[2 * i + 1];
let (f0, f1) = ibutterfly(*f_x, *f_neg_x, x.inverse());
values.append(f0 + alpha * f1);
};
LineEvaluationImpl::new(domain.double(), values)
}

/// Folds and accumulates a degree `d` circle polynomial into a degree `d/2` univariate
/// polynomial.
/// Folds and accumulates a degree `d` circle polynomial into a degree `d/2` univariate polynomial.
///
/// Let `src` be the evaluation of a circle polynomial `f` on a
/// [`CircleDomain`] `E`. This function computes evaluations of `f' = f0
/// + alpha * f1` on the x-coordinates of `E` such that `2f(p) = f0(px) + py * f1(px)`. The
/// evaluations of `f'` are accumulated into `dst` by the formula `dst = dst * alpha^2 +
/// f'`.
pub fn fold_circle_into_line(eval: @CircleEvaluation, alpha: QM31) -> LineEvaluation {
/// Let `src` be the evaluation of a circle polynomial `f` on a [`CircleDomain`] `E`. This function
/// computes evaluations of `f' = f0 + alpha * f1` on the x-coordinates of `E` such that `2f(p) =
/// f0(px) + py * f1(px)`. The evaluations of `f'` are accumulated into `dst` by the formula
/// `dst = dst * alpha^2 + f'`.
pub fn fold_circle_into_line(eval: CircleEvaluation, alpha: QM31) -> LineEvaluation {
let domain = eval.domain;
let fold_factor = pow(2, CIRCLE_TO_LINE_FOLD_STEP);
let mut values = array![];
let mut i = 0;
while i < eval.values.len() / 2 {
let p = domain
.at(bit_reverse_index(i * pow(2, CIRCLE_TO_LINE_FOLD_STEP), domain.log_size()));
let f_p = eval.values[2 * i];
let f_neg_p = eval.values[2 * i + 1];
let (f0, f1) = ibutterfly(*f_p, *f_neg_p, p.y.inverse());
values.append(f0 + alpha * f1);
i += 1;
};
LineEvaluation { values, domain: LineDomainImpl::new(*domain.half_coset) }
for i in 0
..eval.bit_reversed_values.len()
/ 2 {
let p = domain.at(bit_reverse_index(i * fold_factor, domain.log_size()));
let f_p = eval.bit_reversed_values[2 * i];
let f_neg_p = eval.bit_reversed_values[2 * i + 1];
let (f0, f1) = ibutterfly(*f_p, *f_neg_p, p.y.inverse());
values.append(f0 + alpha * f1);
};
LineEvaluation { values, domain: LineDomainImpl::new(domain.half_coset) }
}

pub fn ibutterfly(v0: QM31, v1: QM31, itwid: M31) -> (QM31, QM31) {
Expand Down Expand Up @@ -639,11 +637,10 @@ mod test {
#[test]
fn proof_with_removed_layer_fails_verification() {
let (config, proof, bounds, _queries, _decommitted_values) = proof_with_mixed_degree_1();

let mut invalid_config = config;
invalid_config.log_last_layer_degree_bound -= 1;

let mut channel = ChannelTrait::new(0x00);

let result = FriVerifierImpl::commit(ref channel, invalid_config, proof, bounds);

match result {
Expand Down
1 change: 0 additions & 1 deletion stwo_cairo_verifier/src/lib.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ mod fields;
mod fri;
mod poly;
mod queries;
mod sort;
mod utils;
mod vcs;

Expand Down
34 changes: 14 additions & 20 deletions stwo_cairo_verifier/src/poly/circle.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,21 @@ pub impl CircleDomainImpl of CircleDomainTrait {
/// The values are ordered according to the [CircleDomain] ordering.
#[derive(Debug, Drop, Clone, PartialEq)]
pub struct CircleEvaluation {
pub values: Array<QM31>,
pub bit_reversed_values: Array<QM31>,
pub domain: CircleDomain,
_eval_order: BitReversedOrder
}

#[derive(Debug, Drop, Clone, PartialEq)]
struct BitReversedOrder {}

#[generate_trait]
pub impl CircleEvaluationImpl of CircleEvaluationTrait {
fn new(domain: CircleDomain, values: Array<QM31>) -> CircleEvaluation {
CircleEvaluation { values: values, domain: domain, _eval_order: BitReversedOrder {} }
fn new(domain: CircleDomain, bit_reversed_values: Array<QM31>) -> CircleEvaluation {
CircleEvaluation { bit_reversed_values, domain }
}
}

/// Holds a foldable subset of circle polynomial evaluations.
#[derive(Drop, Clone, Debug, PartialEq)]
pub struct SparseCircleEvaluation {
pub subcircle_evals: Array<CircleEvaluation>
pub subcircle_evals: Array<CircleEvaluation>,
}

#[generate_trait]
Expand All @@ -92,10 +88,10 @@ pub impl SparseCircleEvaluationImpl of SparseCircleEvaluationImplTrait {
let lhs = self.subcircle_evals[i];
let rhs = rhs.subcircle_evals[i];
let mut values = array![];
assert_eq!(lhs.values.len(), rhs.values.len());
assert_eq!(lhs.bit_reversed_values.len(), rhs.bit_reversed_values.len());
let mut j = 0;
while j < lhs.values.len() {
values.append(*lhs.values[j] * alpha + *rhs.values[j]);
while j < lhs.bit_reversed_values.len() {
values.append(*lhs.bit_reversed_values[j] * alpha + *rhs.bit_reversed_values[j]);
j += 1;
};
subcircle_evals.append(CircleEvaluationImpl::new(*lhs.domain, values));
Expand All @@ -105,15 +101,13 @@ pub impl SparseCircleEvaluationImpl of SparseCircleEvaluationImplTrait {
SparseCircleEvaluation { subcircle_evals }
}

fn fold(self: @SparseCircleEvaluation, alpha: QM31) -> Array<QM31> {
let mut i = 0;
let mut res: Array<QM31> = array![];
while i < self.subcircle_evals.len() {
let circle_evaluation = fold_circle_into_line(self.subcircle_evals[i], alpha);
res.append(*circle_evaluation.values.at(0));
i += 1;
};
return res;
fn fold(self: SparseCircleEvaluation, alpha: QM31) -> Array<QM31> {
let mut res = array![];
for eval in self
.subcircle_evals {
res.append(*fold_circle_into_line(eval, alpha).values[0])
};
res
}
}

Expand Down
Loading

0 comments on commit 3948dca

Please sign in to comment.