diff --git a/stwo_cairo_verifier/src/channel.cairo b/stwo_cairo_verifier/src/channel.cairo index 4fb25ed9..408359b5 100644 --- a/stwo_cairo_verifier/src/channel.cairo +++ b/stwo_cairo_verifier/src/channel.cairo @@ -129,13 +129,12 @@ pub impl ChannelImpl of ChannelTrait { fn draw_random_bytes(ref self: Channel) -> Array { let mut cur: u256 = self.draw_felt252().into(); let mut bytes = array![]; - let mut i: usize = 0; - while i < 31 { - let (q, r) = DivRem::div_rem(cur, 256); - bytes.append(r.try_into().unwrap()); - cur = q; - i += 1; - }; + for _ in 0_usize + ..31 { + let (q, r) = DivRem::div_rem(cur, 256); + bytes.append(r.try_into().unwrap()); + cur = q; + }; bytes } } @@ -224,9 +223,7 @@ mod tests { let initial_digest = 0; let mut channel = ChannelTrait::new(initial_digest); - let mut n: usize = 10; - while n > 0 { - n -= 1; + for _ in 0_usize..10 { channel.draw_felt(); }; diff --git a/stwo_cairo_verifier/src/circle.cairo b/stwo_cairo_verifier/src/circle.cairo index 34848c49..1027715a 100644 --- a/stwo_cairo_verifier/src/circle.cairo +++ b/stwo_cairo_verifier/src/circle.cairo @@ -247,16 +247,15 @@ mod tests { #[test] fn test_add_1() { - let i = CirclePoint { x: m31(0), y: m31(1) }; - let result = i + i; - - assert_eq!(result, CirclePoint { x: -m31(1), y: m31(0) }); + let g4 = CirclePoint { x: m31(0), y: m31(1) }; + assert_eq!(g4 + g4, CirclePoint { x: -m31(1), y: m31(0) }); } #[test] fn test_add_2() { let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) }; let point_2 = CirclePoint { x: m31(1737427771), y: m31(309481134) }; + let result = point_1 + point_2; assert_eq!(result, CirclePoint { x: m31(1476625263), y: m31(1040927458) }); @@ -273,6 +272,7 @@ mod tests { fn test_zero_2() { let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) }; let point_2 = CirclePointM31Impl::zero(); + let result = point_1 + point_2; assert_eq!(result, point_1.clone()); @@ -281,6 +281,7 @@ mod tests { #[test] fn test_mul_1() { let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) }; + let result = point_1.mul(5); assert_eq!(result, point_1 + point_1 + point_1 + point_1 + point_1); @@ -289,6 +290,7 @@ mod tests { #[test] fn test_mul_2() { let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) }; + let result = point_1.mul(8); assert_eq!( @@ -299,6 +301,7 @@ mod tests { #[test] fn test_mul_3() { let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) }; + let result = point_1.mul(418776494); assert_eq!(result, CirclePoint { x: m31(1987283985), y: m31(1500510905) }); @@ -307,6 +310,7 @@ mod tests { #[test] fn test_generator_order() { let half_order = M31_CIRCLE_ORDER / 2; + let mut result = M31_CIRCLE_GEN.mul(half_order.into()); // Assert `M31_CIRCLE_GEN^{2^30}` equals `-1`. @@ -327,6 +331,7 @@ mod tests { log_size: 5, step_size: CirclePointIndexImpl::new(67108864) }; + let result = coset.index_at(8); assert_eq!(result, CirclePointIndexImpl::new(553648128)); @@ -353,6 +358,7 @@ mod tests { step_size: CirclePointIndexImpl::new(67108864), log_size: 5 }; + let result = coset.double(); assert_eq!( @@ -372,6 +378,7 @@ mod tests { step_size: CirclePointIndexImpl::new(67108864), log_size: 5 }; + let result = coset.at(17); assert_eq!(result, CirclePoint { x: m31(7144319), y: m31(1742797653) }); @@ -384,6 +391,7 @@ mod tests { step_size: CirclePointIndexImpl::new(67108864), log_size: 5 }; + let result = coset.size(); assert_eq!(result, 32); diff --git a/stwo_cairo_verifier/src/fields/qm31.cairo b/stwo_cairo_verifier/src/fields/qm31.cairo index a7533769..f8eed1ff 100644 --- a/stwo_cairo_verifier/src/fields/qm31.cairo +++ b/stwo_cairo_verifier/src/fields/qm31.cairo @@ -22,19 +22,11 @@ pub impl QM31Impl of QM31Trait { QM31 { a: CM31 { a: a, b: b }, b: CM31 { a: c, b: d } } } - fn from_u32(arr: [u32; 4]) -> QM31 { - let [a, b, c, d] = arr; - let a_mod_p = M31Impl::reduce_u32(a); - let b_mod_p = M31Impl::reduce_u32(b); - let c_mod_p = M31Impl::reduce_u32(c); - let d_mod_p = M31Impl::reduce_u32(d); - - QM31 { a: CM31 { a: a_mod_p, b: b_mod_p }, b: CM31 { a: c_mod_p, b: d_mod_p } } - } #[inline] fn to_array(self: QM31) -> [M31; 4] { [self.a.a, self.a.b, self.b.a, self.b.b] } + fn inverse(self: QM31) -> QM31 { assert_ne!(self, Zero::zero()); let b2 = self.b * self.b; @@ -43,6 +35,7 @@ pub impl QM31Impl of QM31Trait { let denom_inverse = denom.inverse(); QM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse } } + fn mul_m31(self: QM31, multiplier: M31) -> QM31 { QM31 { a: CM31 { a: self.a.a * multiplier, b: self.a.b * multiplier }, @@ -139,15 +132,4 @@ mod tests { assert_eq!(qm1 * m.inverse().into(), qm1 * qm.inverse()); assert_eq!(qm1.mul_m31(m), qm1 * m.into()); } - - #[test] - fn test_qm31_from_u32() { - let arr = [2147483648, 2, 3, 4]; - let felt = QM31Impl::from_u32(arr); - let expected_felt = QM31 { - a: CM31 { a: m31(1), b: m31(2) }, b: CM31 { a: m31(3), b: m31(4) } - }; - - assert_eq!(felt, expected_felt) - } } diff --git a/stwo_cairo_verifier/src/fri.cairo b/stwo_cairo_verifier/src/fri.cairo index 8ae733d7..11fabf29 100644 --- a/stwo_cairo_verifier/src/fri.cairo +++ b/stwo_cairo_verifier/src/fri.cairo @@ -3,7 +3,7 @@ use stwo_cairo_verifier::channel::{Channel, ChannelTrait}; use stwo_cairo_verifier::circle::CosetImpl; use stwo_cairo_verifier::fields::m31::M31; use stwo_cairo_verifier::fields::m31::M31Trait; -use stwo_cairo_verifier::fields::qm31::{QM31, QM31Trait}; +use stwo_cairo_verifier::fields::qm31::{QM31, QM31Zero, QM31Trait}; use stwo_cairo_verifier::poly::circle::CircleDomainImpl; use stwo_cairo_verifier::poly::circle::{ CircleEvaluation, SparseCircleEvaluation, SparseCircleEvaluationImpl @@ -15,7 +15,7 @@ use stwo_cairo_verifier::poly::line::{LineDomain, LineDomainImpl}; use stwo_cairo_verifier::poly::line::{LinePoly, LinePolyImpl}; use stwo_cairo_verifier::queries::SparseSubCircleDomain; use stwo_cairo_verifier::queries::{Queries, QueriesImpl}; -use stwo_cairo_verifier::utils::{bit_reverse_index, pow, pow_qm31, qm31_zero_array, find}; +use stwo_cairo_verifier::utils::{bit_reverse_index, ArrayImpl, pow, pow_qm31, find}; use stwo_cairo_verifier::vcs::hasher::PoseidonMerkleHasher; use stwo_cairo_verifier::vcs::verifier::{MerkleDecommitment, MerkleVerifier, MerkleVerifierTrait}; @@ -48,7 +48,7 @@ impl FriLayerVerifierImpl of FriLayerVerifierTrait { ) -> Result<(Queries, Array), FriVerificationError> { let commitment = self.proof.commitment; - let sparse_evaluation = @self.extract_evaluation(queries, evals_at_queries)?; + let sparse_evaluation = self.extract_evaluation(queries, evals_at_queries)?; let mut column_0: Array = array![]; let mut column_1: Array = array![]; let mut column_2: Array = array![]; @@ -335,7 +335,7 @@ pub impl FriVerifierImpl of FriVerifierTrait { let circle_poly_alpha_sq = *circle_poly_alpha * *circle_poly_alpha; let mut layer_queries = queries.fold(CIRCLE_TO_LINE_FOLD_STEP); - let mut layer_query_evals = qm31_zero_array(layer_queries.len()); + let mut layer_query_evals = ArrayImpl::new_repeated(layer_queries.len(), QM31Zero::zero()); let mut inner_layers_index = 0; let mut column_bound_index = 0; @@ -398,23 +398,23 @@ pub impl FriVerifierImpl of FriVerifierTrait { fn decommit_last_layer( self: @FriVerifier, queries: Queries, query_evals: Array, ) -> Result<(), FriVerificationError> { - let mut failed = false; + let FriVerifier { last_layer_domain, last_layer_poly, .. } = self; + let mut i = 0; - while i < queries.positions.len() { - let query = queries.positions[i]; - let domain = self.last_layer_domain; - let x = self.last_layer_domain.at(bit_reverse_index(*query, domain.log_size())); + loop { + if i == queries.positions.len() { + break Result::Ok(()); + } - if *query_evals[i] != self.last_layer_poly.eval_at_point(x.into()) { - failed = true; - break; + let query = *queries.positions[i]; + let query_eval = *query_evals[i]; + let x = last_layer_domain.at(bit_reverse_index(query, last_layer_domain.log_size())); + + if query_eval != last_layer_poly.eval_at_point(x.into()) { + break Result::Err(FriVerificationError::LastLayerEvaluationsInvalid); } + i += 1; - }; - if failed { - return Result::Err(FriVerificationError::LastLayerEvaluationsInvalid); - } else { - Result::Ok(()) } } @@ -483,43 +483,41 @@ fn get_opening_positions( /// element and `pi(x) = 2x^2 - 1` be the circle's x-coordinate doubling map. This function /// returns `f' = f0 + alpha * f1` evaluated on `pi(E)` such that `2f(x) = f0(pi(x)) + x * /// f1(pi(x))`. -pub fn fold_line(eval: @LineEvaluation, alpha: QM31) -> LineEvaluation { +pub fn fold_line(eval: LineEvaluation, alpha: QM31) -> LineEvaluation { let domain = eval.domain; let mut values = array![]; - let mut i = 0; - while i < eval.values.len() / 2 { - let x = domain.at(bit_reverse_index(i * pow(2, FOLD_STEP), domain.log_size())); - let f_x = eval.values[2 * i]; - let f_neg_x = eval.values[2 * i + 1]; - let (f0, f1) = ibutterfly(*f_x, *f_neg_x, x.inverse()); - values.append(f0 + alpha * f1); - i += 1; - }; + for i in 0 + ..eval.values.len() + / 2 { + let x = domain.at(bit_reverse_index(i * pow(2, FOLD_STEP), domain.log_size())); + let f_x = eval.values[2 * i]; + let f_neg_x = eval.values[2 * i + 1]; + let (f0, f1) = ibutterfly(*f_x, *f_neg_x, x.inverse()); + values.append(f0 + alpha * f1); + }; LineEvaluationImpl::new(domain.double(), values) } -/// Folds and accumulates a degree `d` circle polynomial into a degree `d/2` univariate -/// polynomial. +/// Folds and accumulates a degree `d` circle polynomial into a degree `d/2` univariate polynomial. /// -/// Let `src` be the evaluation of a circle polynomial `f` on a -/// [`CircleDomain`] `E`. This function computes evaluations of `f' = f0 -/// + alpha * f1` on the x-coordinates of `E` such that `2f(p) = f0(px) + py * f1(px)`. The -/// evaluations of `f'` are accumulated into `dst` by the formula `dst = dst * alpha^2 + -/// f'`. -pub fn fold_circle_into_line(eval: @CircleEvaluation, alpha: QM31) -> LineEvaluation { +/// Let `src` be the evaluation of a circle polynomial `f` on a [`CircleDomain`] `E`. This function +/// computes evaluations of `f' = f0 + alpha * f1` on the x-coordinates of `E` such that `2f(p) = +/// f0(px) + py * f1(px)`. The evaluations of `f'` are accumulated into `dst` by the formula +/// `dst = dst * alpha^2 + f'`. +pub fn fold_circle_into_line(eval: CircleEvaluation, alpha: QM31) -> LineEvaluation { let domain = eval.domain; + let fold_factor = pow(2, CIRCLE_TO_LINE_FOLD_STEP); let mut values = array![]; - let mut i = 0; - while i < eval.values.len() / 2 { - let p = domain - .at(bit_reverse_index(i * pow(2, CIRCLE_TO_LINE_FOLD_STEP), domain.log_size())); - let f_p = eval.values[2 * i]; - let f_neg_p = eval.values[2 * i + 1]; - let (f0, f1) = ibutterfly(*f_p, *f_neg_p, p.y.inverse()); - values.append(f0 + alpha * f1); - i += 1; - }; - LineEvaluation { values, domain: LineDomainImpl::new(*domain.half_coset) } + for i in 0 + ..eval.bit_reversed_values.len() + / 2 { + let p = domain.at(bit_reverse_index(i * fold_factor, domain.log_size())); + let f_p = eval.bit_reversed_values[2 * i]; + let f_neg_p = eval.bit_reversed_values[2 * i + 1]; + let (f0, f1) = ibutterfly(*f_p, *f_neg_p, p.y.inverse()); + values.append(f0 + alpha * f1); + }; + LineEvaluation { values, domain: LineDomainImpl::new(domain.half_coset) } } pub fn ibutterfly(v0: QM31, v1: QM31, itwid: M31) -> (QM31, QM31) { @@ -639,11 +637,10 @@ mod test { #[test] fn proof_with_removed_layer_fails_verification() { let (config, proof, bounds, _queries, _decommitted_values) = proof_with_mixed_degree_1(); - let mut invalid_config = config; invalid_config.log_last_layer_degree_bound -= 1; - let mut channel = ChannelTrait::new(0x00); + let result = FriVerifierImpl::commit(ref channel, invalid_config, proof, bounds); match result { diff --git a/stwo_cairo_verifier/src/lib.cairo b/stwo_cairo_verifier/src/lib.cairo index 1f422bbc..dda065f0 100644 --- a/stwo_cairo_verifier/src/lib.cairo +++ b/stwo_cairo_verifier/src/lib.cairo @@ -4,7 +4,6 @@ mod fields; mod fri; mod poly; mod queries; -mod sort; mod utils; mod vcs; diff --git a/stwo_cairo_verifier/src/poly/circle.cairo b/stwo_cairo_verifier/src/poly/circle.cairo index de646bc6..f803c44b 100644 --- a/stwo_cairo_verifier/src/poly/circle.cairo +++ b/stwo_cairo_verifier/src/poly/circle.cairo @@ -55,25 +55,21 @@ pub impl CircleDomainImpl of CircleDomainTrait { /// The values are ordered according to the [CircleDomain] ordering. #[derive(Debug, Drop, Clone, PartialEq)] pub struct CircleEvaluation { - pub values: Array, + pub bit_reversed_values: Array, pub domain: CircleDomain, - _eval_order: BitReversedOrder } -#[derive(Debug, Drop, Clone, PartialEq)] -struct BitReversedOrder {} - #[generate_trait] pub impl CircleEvaluationImpl of CircleEvaluationTrait { - fn new(domain: CircleDomain, values: Array) -> CircleEvaluation { - CircleEvaluation { values: values, domain: domain, _eval_order: BitReversedOrder {} } + fn new(domain: CircleDomain, bit_reversed_values: Array) -> CircleEvaluation { + CircleEvaluation { bit_reversed_values, domain } } } /// Holds a foldable subset of circle polynomial evaluations. #[derive(Drop, Clone, Debug, PartialEq)] pub struct SparseCircleEvaluation { - pub subcircle_evals: Array + pub subcircle_evals: Array, } #[generate_trait] @@ -92,10 +88,10 @@ pub impl SparseCircleEvaluationImpl of SparseCircleEvaluationImplTrait { let lhs = self.subcircle_evals[i]; let rhs = rhs.subcircle_evals[i]; let mut values = array![]; - assert_eq!(lhs.values.len(), rhs.values.len()); + assert_eq!(lhs.bit_reversed_values.len(), rhs.bit_reversed_values.len()); let mut j = 0; - while j < lhs.values.len() { - values.append(*lhs.values[j] * alpha + *rhs.values[j]); + while j < lhs.bit_reversed_values.len() { + values.append(*lhs.bit_reversed_values[j] * alpha + *rhs.bit_reversed_values[j]); j += 1; }; subcircle_evals.append(CircleEvaluationImpl::new(*lhs.domain, values)); @@ -105,15 +101,13 @@ pub impl SparseCircleEvaluationImpl of SparseCircleEvaluationImplTrait { SparseCircleEvaluation { subcircle_evals } } - fn fold(self: @SparseCircleEvaluation, alpha: QM31) -> Array { - let mut i = 0; - let mut res: Array = array![]; - while i < self.subcircle_evals.len() { - let circle_evaluation = fold_circle_into_line(self.subcircle_evals[i], alpha); - res.append(*circle_evaluation.values.at(0)); - i += 1; - }; - return res; + fn fold(self: SparseCircleEvaluation, alpha: QM31) -> Array { + let mut res = array![]; + for eval in self + .subcircle_evals { + res.append(*fold_circle_into_line(eval, alpha).values[0]) + }; + res } } diff --git a/stwo_cairo_verifier/src/poly/line.cairo b/stwo_cairo_verifier/src/poly/line.cairo index 7ea626ab..1f0bef4d 100644 --- a/stwo_cairo_verifier/src/poly/line.cairo +++ b/stwo_cairo_verifier/src/poly/line.cairo @@ -117,13 +117,10 @@ pub struct SparseLineEvaluation { #[generate_trait] pub impl SparseLineEvaluationImpl of SparseLineEvaluationTrait { - fn fold(self: @SparseLineEvaluation, alpha: QM31) -> Array { - let mut i = 0; - let mut res: Array = array![]; - while i < self.subline_evals.len() { - let line_evaluation = fold_line(self.subline_evals[i], alpha); - res.append(*line_evaluation.values.at(0)); - i += 1; + fn fold(self: SparseLineEvaluation, alpha: QM31) -> Array { + let mut res = array![]; + for eval in self.subline_evals { + res.append(*fold_line(eval, alpha).values[0]); }; res } @@ -142,23 +139,18 @@ mod tests { fn bad_line_domain() { // This coset doesn't have points with unique x-coordinates. let coset = CosetImpl::odds(2); - LineDomainImpl::new(coset); } #[test] fn line_domain_of_size_two_works() { - let LOG_SIZE: u32 = 1; - let coset = CosetImpl::new(CirclePointIndexImpl::new(0), LOG_SIZE); - + let coset = CosetImpl::new(CirclePointIndexImpl::new(0), 1); LineDomainImpl::new(coset); } #[test] fn line_domain_of_size_one_works() { - let LOG_SIZE: u32 = 0; - let coset = CosetImpl::new(CirclePointIndexImpl::new(0), LOG_SIZE); - + let coset = CosetImpl::new(CirclePointIndexImpl::new(0), 0); LineDomainImpl::new(coset); } @@ -172,9 +164,10 @@ mod tests { log_size: 1 }; let x = m31(590768354); + let result = line_poly.eval_at_point(x.into()); - let expected_result = qm31(515899232, 1030391528, 1006544539, 11142505); - assert_eq!(expected_result, result); + + assert_eq!(result, qm31(515899232, 1030391528, 1006544539, 11142505)); } #[test] @@ -183,9 +176,10 @@ mod tests { coeffs: array![qm31(1, 2, 3, 4), qm31(5, 6, 7, 8)], log_size: 1 }; let x = m31(10); + let result = line_poly.eval_at_point(x.into()); - let expected_result = qm31(51, 62, 73, 84); - assert_eq!(expected_result, result); + + assert_eq!(result, qm31(51, 62, 73, 84)); } #[test] @@ -207,7 +201,6 @@ mod tests { let result = poly.eval_at_point(x); - let expected_result = qm31(1857853974, 839310133, 939318020, 651207981); - assert_eq!(expected_result, result); + assert_eq!(result, qm31(1857853974, 839310133, 939318020, 651207981)); } } diff --git a/stwo_cairo_verifier/src/queries.cairo b/stwo_cairo_verifier/src/queries.cairo index 67d64bff..56560cd2 100644 --- a/stwo_cairo_verifier/src/queries.cairo +++ b/stwo_cairo_verifier/src/queries.cairo @@ -1,8 +1,7 @@ use stwo_cairo_verifier::channel::{Channel, ChannelTrait}; use stwo_cairo_verifier::circle::CosetImpl; use stwo_cairo_verifier::poly::circle::{CircleDomain, CircleDomainImpl}; -use stwo_cairo_verifier::sort::MinimumToMaximumSortedIterator; -use super::utils::{pow, bit_reverse_index, find}; +use super::utils::{pow, bit_reverse_index, find, ArrayImpl}; /// An ordered set of query indices over a bit reversed [CircleDomain]. @@ -16,7 +15,7 @@ pub struct Queries { pub impl QueriesImpl of QueriesImplTrait { /// Randomizes a set of query indices uniformly over the range [0, 2^`log_query_size`). fn generate(ref channel: Channel, log_domain_size: u32, n_queries: usize) -> Queries { - let mut nonsorted_positions = array![]; + let mut unsorted_positions = array![]; let max_query = pow(2, log_domain_size) - 1; let mut finished = false; loop { @@ -27,9 +26,9 @@ pub impl QueriesImpl of QueriesImplTrait { let b1: u32 = (*random_bytes[i + 1]).into(); let b2: u32 = (*random_bytes[i + 2]).into(); let b3: u32 = (*random_bytes[i + 3]).into(); - nonsorted_positions - .append((b0 + 256 * b1 + 65536 * b2 + 16777216 * b3) & max_query); - if nonsorted_positions.len() == n_queries { + let position = (((b3 * 0x100 + b2) * 0x100 + b1) * 0x100 + b0) & max_query; + unsorted_positions.append(position); + if unsorted_positions.len() == n_queries { finished = true; break; } @@ -40,13 +39,7 @@ pub impl QueriesImpl of QueriesImplTrait { } }; - let mut positions = array![]; - let mut iterator = MinimumToMaximumSortedIterator::iterate(nonsorted_positions.span()); - while let Option::Some((_, x)) = iterator.next_deduplicated() { - positions.append(x); - }; - - Queries { positions, log_domain_size } + Queries { positions: unsorted_positions.sort_ascending().dedup(), log_domain_size } } fn len(self: @Queries) -> usize { @@ -98,7 +91,7 @@ pub impl QueriesImpl of QueriesImplTrait { } } -/// Represents a circle domain relative to a larger circle domain. The `initial_index` is the bit +/// Represents a circle domain relative to a larger circle domain. The `coset_index` is the bit /// reversed query index in the larger domain. #[derive(Drop, Debug, Copy)] pub struct SubCircleDomain { @@ -111,13 +104,12 @@ pub struct SubCircleDomain { pub impl SubCircleDomainImpl of SubCircleDomainTrait { /// Calculates the decommitment positions needed for each query given the fri step size. fn to_decommitment_positions(self: @SubCircleDomain) -> Array { + let sub_circle_size = pow(2, *self.log_size); + let start = *self.coset_index * sub_circle_size; + let end = start + sub_circle_size; let mut res = array![]; - let start = *self.coset_index * pow(2, *self.log_size); - let end = (*self.coset_index + 1) * pow(2, *self.log_size); - let mut i = start; - while i < end { + for i in start..end { res.append(i); - i = i + 1; }; res } @@ -142,16 +134,13 @@ pub struct SparseSubCircleDomain { pub impl SparseSubCircleDomainImpl of SparseSubCircleDomainTrait { fn flatten(self: @SparseSubCircleDomain) -> Array { let mut res = array![]; - let mut i = 0; - while i < self.domains.len() { - let positions = self.domains[i].to_decommitment_positions(); - let mut j = 0; - while j < positions.len() { - res.append(*positions[j]); - j = j + 1; + for domain in self + .domains + .span() { + for position in domain.to_decommitment_positions() { + res.append(position); + }; }; - i = i + 1; - }; res } } diff --git a/stwo_cairo_verifier/src/sort.cairo b/stwo_cairo_verifier/src/sort.cairo deleted file mode 100644 index 9f4851c8..00000000 --- a/stwo_cairo_verifier/src/sort.cairo +++ /dev/null @@ -1,196 +0,0 @@ -use core::array::ArrayTrait; -use core::array::ToSpanTrait; -use core::option::OptionTrait; - -trait Compare { - fn compare(self: @C, a: T, b: T) -> bool; -} - -#[derive(Drop, Copy)] -pub struct LowerThan {} - -impl LowerThanCompare> of Compare { - fn compare(self: @LowerThan, a: T, b: T) -> bool { - return a < b; - } -} - -#[derive(Drop, Copy)] -pub struct GreaterThan {} - -impl GreaterThanCompare, +Copy, +Drop> of Compare { - fn compare(self: @GreaterThan, a: T, b: T) -> bool { - return a > b; - } -} - -#[derive(Drop)] -pub struct SortedIterator { - comparer: C, - array: Span, - last_index: Option, -} - -trait SortedIteratorTrait< - T, C, +PartialOrd, +PartialEq, +Copy, +Drop, +Compare, +Drop, +Copy -> { - fn iterate(array_to_iterate: Span) -> SortedIterator; - - fn next_deduplicated( - ref self: SortedIterator - ) -> Option<(u32, T)> { - next_deduplicated::(ref self) - } - - fn next( - ref self: SortedIterator - ) -> Option< - (u32, T) - > { - if self.last_index.is_some() { - let last_index = self.last_index.unwrap(); - let last_value = *self.array[last_index]; - let mut is_repeated = false; - - let mut i = last_index + 1; - while i < self.array.len() { - if *self.array[i] == last_value { - is_repeated = true; - self.last_index = Option::Some(i); - break; - } - i += 1; - }; - - if is_repeated { - return Option::Some((self.last_index.unwrap(), last_value)); - } - } - next_deduplicated::(ref self) - } -} - -fn next_deduplicated< - T, C, +PartialOrd, +PartialEq, +Copy, +Drop, +Compare, +Drop, +Copy ->( - ref self: SortedIterator -) -> Option<(u32, T)> { - let mut candidate_index = Option::None; - let mut candidate_value = Option::None; - - let last_value = if let Option::Some(last_index) = self.last_index { - Option::Some(*self.array[last_index]) - } else { - Option::None - }; - - let mut i = 0; - while i < self.array.len() { - let is_better_than_last = if let Option::Some(last_value) = last_value { - self.comparer.compare(last_value, *self.array[i]) - } else { - true - }; - let is_nearer_than_candidate = if let Option::Some(candidate_value) = candidate_value { - self.comparer.compare(*self.array[i], candidate_value) - } else { - true - }; - if is_better_than_last && is_nearer_than_candidate { - candidate_index = Option::Some(i); - candidate_value = Option::Some(*self.array[i]); - } - i += 1; - }; - - if candidate_value.is_none() { - Option::None - } else { - self.last_index = candidate_index; - Option::Some((candidate_index.unwrap(), candidate_value.unwrap())) - } -} - -pub impl MaximumToMinimumSortedIterator< - T, +PartialOrd, +PartialEq, +Copy, +Drop -> of SortedIteratorTrait { - fn iterate(array_to_iterate: Span) -> SortedIterator { - SortedIterator { - comparer: GreaterThan {}, array: array_to_iterate, last_index: Option::None - } - } -} - -pub impl MinimumToMaximumSortedIterator< - T, +PartialOrd, +PartialEq, +Copy, +Drop -> of SortedIteratorTrait { - fn iterate(array_to_iterate: Span) -> SortedIterator { - SortedIterator { comparer: LowerThan {}, array: array_to_iterate, last_index: Option::None } - } -} - - -#[test] -fn test_sort_lowest_to_greatest() { - let my_array: Array = array![3, 5, 2, 4]; - let expected_array: Array = array![2, 3, 4, 5]; - - let mut sorted_array = array![]; - - let mut iterator = MinimumToMaximumSortedIterator::iterate(my_array.span()); - while let Option::Some((_index, value)) = iterator.next_deduplicated() { - sorted_array.append(value); - }; - - assert_eq!(expected_array, sorted_array); -} - -#[test] -fn test_sort_greatest_to_lowest() { - let my_array: Array = array![3, 5, 2, 4]; - let expected_array: Array = array![5, 4, 3, 2]; - - let mut sorted_array = array![]; - - let mut iterator = MaximumToMinimumSortedIterator::iterate(my_array.span()); - while let Option::Some((_index, value)) = iterator.next_deduplicated() { - sorted_array.append(value); - }; - - assert_eq!(expected_array, sorted_array); -} - -#[test] -fn test_sort_indexes_are_correct() { - let my_array: Array = array![3, 5, 2, 4]; - let expected_indexes: Array = array![2, 0, 3, 1]; - - let mut sorted_indexes = array![]; - - let mut iterator = MinimumToMaximumSortedIterator::iterate(my_array.span()); - while let Option::Some((index, _value)) = iterator.next_deduplicated() { - sorted_indexes.append(index); - }; - - assert_eq!(expected_indexes, sorted_indexes); -} - -#[test] -fn test_sort_with_duplicates() { - let my_array: Array = array![3, 5, 2, 3, 4, 3, 4]; - let expected_indexes: Array = array![2, 0, 3, 5, 4, 6, 1]; - let expected_array: Array = array![2, 3, 3, 3, 4, 4, 5]; - - let mut sorted_indexes = array![]; - let mut sorted_array = array![]; - - let mut iterator = MinimumToMaximumSortedIterator::iterate(my_array.span()); - while let Option::Some((index, value)) = iterator.next() { - sorted_array.append(value); - sorted_indexes.append(index); - }; - - assert_eq!(expected_indexes, sorted_indexes); - assert_eq!(expected_array, sorted_array); -} - diff --git a/stwo_cairo_verifier/src/utils.cairo b/stwo_cairo_verifier/src/utils.cairo index ea791d39..400910e0 100644 --- a/stwo_cairo_verifier/src/utils.cairo +++ b/stwo_cairo_verifier/src/utils.cairo @@ -3,6 +3,8 @@ use core::box::BoxTrait; use core::dict::Felt252Dict; use core::dict::Felt252DictEntryTrait; use core::dict::Felt252DictTrait; +use core::iter::IntoIterator; +use core::iter::Iterator; use core::num::traits::BitSize; use core::traits::DivRem; use core::traits::PanicDestruct; @@ -42,9 +44,68 @@ pub impl ArrayImpl, +Drop> of ArrayExTrait { }; res } + fn max<+PartialOrd>(mut self: @Array) -> Option<@T> { self.span().max() } + + /// Sorts an array in ascending order. Uses quicksort algorithm. + fn sort_ascending<+PartialOrd>(self: Array) -> Array { + if self.len() <= 1 { + return self; + } + + let mut lhs = array![]; + let mut rhs = array![]; + let mut iter = self.into_iter(); + let pivot = iter.next().unwrap(); + + for v in iter { + if v > pivot { + rhs.append(v); + } else { + lhs.append(v); + } + }; + + let mut res = lhs.sort_ascending(); + res.append(pivot); + + for v in rhs.sort_ascending() { + res.append(v); + }; + + res + } + + /// Removes consecutive repeated elements. + /// + /// If the vector is sorted, this removes all duplicates. + fn dedup<+PartialEq>(self: Array) -> Array { + if self.len() == 0 { + return array![]; + } + + let mut iter = self.into_iter(); + let mut last_value = iter.next().unwrap(); + let mut res = array![last_value]; + for value in iter { + if value != last_value { + res.append(value); + last_value = value; + } + }; + + res + } + + fn new_repeated(n: usize, v: T) -> Array { + let mut res = array![]; + for _ in 0..n { + res.append(v); + }; + res + } } #[generate_trait] @@ -142,19 +203,9 @@ pub fn pow_qm31(base: QM31, mut exponent: u32) -> QM31 { result } -pub fn qm31_zero_array(n: u32) -> Array { - let mut result = array![]; - let mut i = 0; - while i < n { - result.append(qm31(0, 0, 0, 0)); - i += 1; - }; - result -} - #[cfg(test)] mod tests { - use super::{pow, pow_qm31, qm31, bit_reverse_index}; + use super::{pow, pow_qm31, qm31, bit_reverse_index, ArrayImpl}; #[test] fn test_pow() { @@ -210,5 +261,20 @@ mod tests { let expected_result = qm31(1394542587, 260510989, 997191897, 2127074080); assert_eq!(expected_result, result) } + + #[test] + fn test_sort_ascending() { + assert_eq!(array![6_usize, 5, 1, 4, 2, 3].sort_ascending(), array![1, 2, 3, 4, 5, 6]); + } + + #[test] + fn test_dedup() { + assert_eq!(array![1_usize, 1, 1, 2, 2, 3, 4, 5, 5, 5].dedup(), array![1, 2, 3, 4, 5]); + } + + #[test] + fn test_array_new_repeated() { + assert_eq!(ArrayImpl::new_repeated(5, 3_usize), array![3, 3, 3, 3, 3]); + } }