From c229d888b71b3470df306c0f9cac7649c15214f6 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Tue, 22 Oct 2024 13:48:33 -0400 Subject: [PATCH] Optimizations --- stwo_cairo_verifier/src/circle.cairo | 1 + stwo_cairo_verifier/src/fri.cairo | 21 +- stwo_cairo_verifier/src/poly/line.cairo | 270 ++++++++++++++++++++--- stwo_cairo_verifier/src/poly/utils.cairo | 7 +- 4 files changed, 257 insertions(+), 42 deletions(-) diff --git a/stwo_cairo_verifier/src/circle.cairo b/stwo_cairo_verifier/src/circle.cairo index 977e7c9f..6c23f68e 100644 --- a/stwo_cairo_verifier/src/circle.cairo +++ b/stwo_cairo_verifier/src/circle.cairo @@ -95,6 +95,7 @@ pub trait CirclePointTrait< impl CirclePointAdd, +Sub, +Mul, +Drop, +Copy> of Add> { /// Performs the operation of the circle as a group with additive notation. + #[inline] fn add(lhs: CirclePoint, rhs: CirclePoint) -> CirclePoint { CirclePoint { x: lhs.x * rhs.x - lhs.y * rhs.y, y: lhs.x * rhs.y + lhs.y * rhs.x } } diff --git a/stwo_cairo_verifier/src/fri.cairo b/stwo_cairo_verifier/src/fri.cairo index c3d6a6eb..7042f7ac 100644 --- a/stwo_cairo_verifier/src/fri.cairo +++ b/stwo_cairo_verifier/src/fri.cairo @@ -176,7 +176,9 @@ impl FriLayerVerifierImpl of FriLayerVerifierTrait { let subline_initial_index = bit_reverse_index(subline_start, self.domain.log_size()); let subline_initial = self.domain.coset.index_at(subline_initial_index); - let subline_domain = LineDomainImpl::new(CosetImpl::new(subline_initial, FOLD_STEP)); + let subline_domain = LineDomainImpl::new_unchecked( + CosetImpl::new(subline_initial, FOLD_STEP) + ); all_subline_evals.append(LineEvaluationImpl::new(subline_domain, subline_evals)); }; @@ -238,7 +240,7 @@ pub impl FriVerifierImpl of FriVerifierTrait { let mut inner_layers = array![]; let mut layer_bound = *max_column_bound - CIRCLE_TO_LINE_FOLD_STEP; - let mut layer_domain = LineDomainImpl::new( + let mut layer_domain = LineDomainImpl::new_unchecked( CosetImpl::half_odds(layer_bound + config.log_blowup_factor) ); 
@@ -316,7 +318,6 @@ pub impl FriVerifierImpl of FriVerifierTrait { self: @FriVerifier, queries: @Queries, decommitted_values: Array ) -> Result<(), FriVerificationError> { assert!(queries.log_domain_size == self.expected_query_log_domain_size); - let (last_layer_queries, last_layer_query_evals) = self .decommit_inner_layers(queries, @decommitted_values)?; self.decommit_last_layer(last_layer_queries, last_layer_query_evals) @@ -397,6 +398,11 @@ pub impl FriVerifierImpl of FriVerifierTrait { ) -> Result<(), FriVerificationError> { let FriVerifier { last_layer_domain, last_layer_poly, .. } = self; + let domain_log_size = last_layer_domain.log_size(); + // TODO(andrew): Note depending on the proof parameters, doing FFT on the last layer poly vs + // pointwise evaluation is less efficient. + let last_layer_evals = last_layer_poly.evaluate(*last_layer_domain).values; + let mut i = 0; loop { if i == queries.positions.len() { @@ -404,10 +410,11 @@ pub impl FriVerifierImpl of FriVerifierTrait { } let query = *queries.positions[i]; - let query_eval = *query_evals[i]; - let x = last_layer_domain.at(bit_reverse_index(query, last_layer_domain.log_size())); + // TODO(andrew): Makes more sense for the proof to provide coeffs in natural order and + // the FFT return evals in bit-reversed order to prevent this unnecessary bit-reverse. 
+ let last_layer_eval_i = bit_reverse_index(query, domain_log_size); - if query_eval != last_layer_poly.eval_at_point(x.into()) { + if query_evals[i] != last_layer_evals[last_layer_eval_i] { break Result::Err(FriVerificationError::LastLayerEvaluationsInvalid); } @@ -514,7 +521,7 @@ pub fn fold_circle_into_line(eval: CircleEvaluation, alpha: QM31) -> LineEvaluat let (f0, f1) = ibutterfly(*f_p, *f_neg_p, p.y.inverse()); values.append(f0 + alpha * f1); }; - LineEvaluation { values, domain: LineDomainImpl::new(domain.half_coset) } + LineEvaluation { values, domain: LineDomainImpl::new_unchecked(domain.half_coset) } } pub fn ibutterfly(v0: QM31, v1: QM31, itwid: M31) -> (QM31, QM31) { diff --git a/stwo_cairo_verifier/src/poly/line.cairo b/stwo_cairo_verifier/src/poly/line.cairo index d75a73dc..7dde69be 100644 --- a/stwo_cairo_verifier/src/poly/line.cairo +++ b/stwo_cairo_verifier/src/poly/line.cairo @@ -1,9 +1,13 @@ -use stwo_cairo_verifier::circle::{Coset, CosetImpl, CirclePointIndexImpl, CirclePointTrait}; -use stwo_cairo_verifier::fields::SecureField; +use core::iter::{Iterator, IntoIterator}; +use stwo_cairo_verifier::circle::{ + CirclePoint, Coset, CosetImpl, CirclePointIndexImpl, CirclePointTrait +}; use stwo_cairo_verifier::fields::m31::{M31, m31}; -use stwo_cairo_verifier::fields::qm31::{QM31, QM31Zero}; +use stwo_cairo_verifier::fields::qm31::{QM31, QM31Impl, QM31Zero}; +use stwo_cairo_verifier::fields::{SecureField, BaseField}; use stwo_cairo_verifier::fri::fold_line; use stwo_cairo_verifier::poly::utils::fold; +use stwo_cairo_verifier::utils::pow; /// A univariate polynomial defined on a [LineDomain]. #[derive(Debug, Drop, Clone)] @@ -30,18 +34,130 @@ pub impl LinePolyImpl of LinePolyTrait { } /// Evaluates the polynomial at a single point. - fn eval_at_point(self: @LinePoly, mut x: SecureField) -> SecureField { + // TODO(andrew): Can remove if only use `Self::evaluate()` in the verifier. 
+ // Note there are tradeoffs depending on the blowup factor and last FRI layer degree bound. + fn eval_at_point(self: @LinePoly, mut x: BaseField) -> SecureField { let mut doublings = array![]; - let mut i = 0; - while i < *self.log_size { - doublings.append(x); - let x_square = x * x; - x = x_square + x_square - m31(1).into(); - i += 1; - }; + for _ in 0 + ..*self + .log_size { + doublings.append(x); + let x_square = x * x; + x = x_square + x_square - m31(1); + }; fold(self.coeffs, @doublings, 0, 0, self.coeffs.len()) } + + fn evaluate(self: @LinePoly, domain: LineDomain) -> LineEvaluation { + assert!(domain.size() >= self.coeffs.len()); + + // The first few FFT layers may just copy coefficients so we do it directly. + // See the docs for `n_skipped_layers` in `line_fft()`. + let log_domain_size = domain.log_size(); + let log_degree_bound = *self.log_size; + let n_skipped_layers = log_domain_size - log_degree_bound; + let duplicity = pow(2, n_skipped_layers); + let coeffs = repeat_value(self.coeffs.span(), duplicity); + + LineEvaluationImpl::new(domain, line_fft(coeffs, domain, n_skipped_layers)) + } +} + +/// Repeats each value sequentially `duplicity` many times. +pub fn repeat_value(values: Span, duplicity: usize) -> Array { + let mut res = array![]; + for v in values { + for _ in 0..duplicity { + res.append(*v) + }; + }; + res +} + +/// Performs an FFT on a univariate polynomial. +/// +/// `values` are the coefficients stored in bit-reversed order. The evaluations of the polynomial +/// over `domain` are returned in natural order. +/// +/// `n_skipped_layers` specifies how many of the initial butterfly layers to skip. This is used for +/// more efficient degree aware FFTs as the butterflies in the first layers of the FFT only involve +/// copying coefficients to different locations (because one or more of the coefficients is zero). +/// This new algorithm is `O(n log d)` vs `O(n log n)` where `n` is the domain size and `d` is the +/// degree of the polynomial. 
+/// +/// Note the algorithm does not operate on coefficients in the standard monomial basis but rather +/// coefficients in a basis relating to the circle's x-coordinate doubling map `pi(x) = 2x^2 - 1` +/// i.e. +/// +/// ```text +/// B = { 1 } * { x } * { pi(x) } * { pi(pi(x)) } * ... +/// = { 1, x, pi(x), pi(x) * x, pi(pi(x)), pi(pi(x)) * x, pi(pi(x)) * pi(x), ... } +/// ``` +/// +/// # Panics +/// +/// Panics if the number of values doesn't match the size of the domain. +#[inline] +fn line_fft( + mut values: Array, mut domain: LineDomain, n_skipped_layers: usize +) -> Array { + let n = values.len(); + assert!(values.len() == domain.size()); + + let mut domains = array![]; + while domain.log_size() != n_skipped_layers { + domains.append(domain); + domain = domain.double(); + }; + let mut domains = domains.span(); + + while let Option::Some(domain) = domains.pop_back() { + let chunk_size = domain.size(); + let twiddles = gen_twiddles(domain).span(); + let n_chunks = n / chunk_size; + let stride = chunk_size / 2; + let values_span = values.span(); + let mut next_values = array![]; + for chunk_i in 0 + ..n_chunks { + let chunk_offset = chunk_i * chunk_size; + let mut chunk_l_vals = values_span.slice(chunk_offset, stride).into_iter(); + let mut chunk_r_vals = values_span.slice(chunk_offset + stride, stride).into_iter(); + let mut next_r_values = array![]; + for twiddle in twiddles { + let v0 = *chunk_l_vals.next().unwrap(); + let v1 = *chunk_r_vals.next().unwrap(); + let (v0, v1) = butterfly(v0, v1, *twiddle); + next_values.append(v0); + next_r_values.append(v1); + }; + next_values.append_span(next_r_values.span()); + }; + values = next_values; + }; + + values +} + +#[inline] +fn gen_twiddles(self: @LineDomain) -> Array { + let mut iter = LineDomainIterator { + cur: self.coset.initial_index.to_point(), + step: self.coset.step_size.to_point(), + remaining: self.size() / 2 + }; + let mut res = array![]; + while let Option::Some(v) = iter.next() { + res.append(v); + }; 
+ res +} + +#[inline] +fn butterfly(v0: QM31, v1: QM31, twid: M31) -> (QM31, QM31) { + let tmp = v1.mul_m31(twid); + (v0 + tmp, v0 - tmp) } /// Domain comprising of the x-coordinates of points in a [Coset]. @@ -72,6 +188,15 @@ pub impl LineDomainImpl of LineDomainTrait { LineDomain { coset: coset } } + /// Returns a domain comprising of the x-coordinates of points in a coset. + /// + /// # Safety + /// + /// All coset points must have unique `x` coordinates. + fn new_unchecked(coset: Coset) -> LineDomain { + LineDomain { coset: coset } + } + /// Returns the `i`th domain element. fn at(self: @LineDomain, index: usize) -> M31 { self.coset.at(index).x @@ -96,6 +221,7 @@ pub impl LineDomainImpl of LineDomainTrait { /// Evaluations of a univariate polynomial on a [LineDomain]. #[derive(Drop)] pub struct LineEvaluation { + /// Evaluations in natural order. pub values: Array, pub domain: LineDomain } @@ -126,21 +252,35 @@ pub impl SparseLineEvaluationImpl of SparseLineEvaluationTrait { } } +#[derive(Drop, Clone)] +struct LineDomainIterator { + pub cur: CirclePoint, + pub step: CirclePoint, + pub remaining: usize, +} + +impl LineDomainIteratorImpl of Iterator { + type Item = M31; + + fn next(ref self: LineDomainIterator) -> Option { + if self.remaining == 0 { + return Option::None; + } + self.remaining -= 1; + let res = self.cur.x; + self.cur = self.cur + self.step; + Option::Some(res) + } +} + #[cfg(test)] mod tests { + use core::iter::{IntoIterator, Iterator}; use stwo_cairo_verifier::circle::{CosetImpl, CirclePointIndexImpl}; use stwo_cairo_verifier::fields::m31::m31; use stwo_cairo_verifier::fields::qm31::qm31; - use super::{LinePoly, LinePolyTrait, LineDomainImpl}; - - #[test] - #[should_panic] - fn bad_line_domain() { - // This coset doesn't have points with unique x-coordinates. 
- let coset = CosetImpl::odds(2); - LineDomainImpl::new(coset); - } + use super::{LinePoly, LinePolyTrait, LineDomain, LineDomainImpl, LineDomainIterator}; #[test] fn line_domain_of_size_two_works() { @@ -165,7 +305,7 @@ mod tests { }; let x = m31(590768354); - let result = line_poly.eval_at_point(x.into()); + let result = line_poly.eval_at_point(x); assert_eq!(result, qm31(515899232, 1030391528, 1006544539, 11142505)); } @@ -177,7 +317,7 @@ mod tests { }; let x = m31(10); - let result = line_poly.eval_at_point(x.into()); + let result = line_poly.eval_at_point(x); assert_eq!(result, qm31(51, 62, 73, 84)); } @@ -186,21 +326,87 @@ mod tests { fn test_eval_at_point_3() { let poly = LinePoly { coeffs: array![ - qm31(1, 2, 3, 4), - qm31(5, 6, 7, 8), - qm31(9, 10, 11, 12), - qm31(13, 14, 15, 16), - qm31(17, 18, 19, 20), - qm31(21, 22, 23, 24), - qm31(25, 26, 27, 28), - qm31(29, 30, 31, 32), + qm31(1, 8, 0, 1), + qm31(2, 7, 1, 2), + qm31(3, 6, 0, 1), + qm31(4, 5, 1, 3), + qm31(5, 4, 0, 1), + qm31(6, 3, 1, 4), + qm31(7, 2, 0, 1), + qm31(8, 1, 1, 5), ], log_size: 3 }; - let x = qm31(2, 5, 7, 11); + let x = m31(10); let result = poly.eval_at_point(x); - assert_eq!(result, qm31(1857853974, 839310133, 939318020, 651207981)); + assert_eq!(result, qm31(1328848956, 239350644, 174242200, 838661589)); + } + + #[test] + fn test_evaluate() { + let log_size = 3; + let domain = LineDomainImpl::new(CosetImpl::half_odds(log_size)); + let poly = LinePoly { + coeffs: array![ + qm31(1, 8, 0, 1), + qm31(2, 7, 1, 2), + qm31(3, 6, 0, 1), + qm31(4, 5, 1, 3), + qm31(5, 4, 0, 1), + qm31(6, 3, 1, 4), + qm31(7, 2, 0, 1), + qm31(8, 1, 1, 5), + ], + log_size, + }; + + let result = poly.evaluate(domain); + let mut result_iter = result.values.into_iter(); + + for x in domain + .into_iter() { + assert_eq!(result_iter.next().unwrap(), poly.eval_at_point(x)); + } + } + + #[test] + fn test_evaluate_with_large_domain() { + let log_size = 3; + let domain = LineDomainImpl::new(CosetImpl::half_odds(log_size + 
2)); + let poly = LinePoly { + coeffs: array![ + qm31(1, 8, 0, 1), + qm31(2, 7, 1, 2), + qm31(3, 6, 0, 1), + qm31(4, 5, 1, 3), + qm31(5, 4, 0, 1), + qm31(6, 3, 1, 4), + qm31(7, 2, 0, 1), + qm31(8, 1, 1, 5), + ], + log_size, + }; + + let result = poly.evaluate(domain); + let mut result_iter = result.values.into_iter(); + + for x in domain + .into_iter() { + assert_eq!(result_iter.next().unwrap(), poly.eval_at_point(x)); + } + } + + impl LineDomainIntoIterator of IntoIterator { + type IntoIter = LineDomainIterator; + + fn into_iter(self: LineDomain) -> LineDomainIterator { + LineDomainIterator { + cur: self.coset.initial_index.to_point(), + step: self.coset.step_size.to_point(), + remaining: self.size(), + } + } } } diff --git a/stwo_cairo_verifier/src/poly/utils.cairo b/stwo_cairo_verifier/src/poly/utils.cairo index 31e0ca10..48ae86e5 100644 --- a/stwo_cairo_verifier/src/poly/utils.cairo +++ b/stwo_cairo_verifier/src/poly/utils.cairo @@ -1,4 +1,5 @@ -use stwo_cairo_verifier::fields::SecureField; +use stwo_cairo_verifier::fields::qm31::QM31Impl; +use stwo_cairo_verifier::fields::{SecureField, BaseField}; /// Folds values recursively in `O(n)` by a hierarchical application of folding factors. /// @@ -20,7 +21,7 @@ use stwo_cairo_verifier::fields::SecureField; /// factors is provided. pub fn fold( values: @Array, - folding_factors: @Array, + folding_factors: @Array, index: usize, level: usize, n: usize @@ -31,5 +32,5 @@ pub fn fold( let lhs_val = fold(values, folding_factors, index, level + 1, n / 2); let rhs_val = fold(values, folding_factors, index + n / 2, level + 1, n / 2); - return lhs_val + rhs_val * *folding_factors[level]; + return lhs_val + rhs_val.mul_m31(*folding_factors[level]); }