Skip to content

Commit

Permalink
Refactor Coset with CirclePointIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Oct 14, 2024
1 parent 5f4751e commit b51e6fd
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 183 deletions.
239 changes: 117 additions & 122 deletions stwo_cairo_verifier/src/circle.cairo
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
use core::num::traits::one::One;
use core::num::traits::zero::Zero;
use core::num::traits::{WrappingAdd, WideMul};
use core::num::traits::{WrappingAdd, WrappingSub, WrappingMul};
use stwo_cairo_verifier::fields::cm31::CM31;
use stwo_cairo_verifier::fields::m31::{M31, M31Impl};
use stwo_cairo_verifier::fields::qm31::{QM31Impl, QM31, QM31Trait};
use super::utils::pow;

/// A generator for the circle group over [`M31`].
pub const M31_CIRCLE_GEN: CirclePoint<M31> =
CirclePoint { x: M31 { inner: 2 }, y: M31 { inner: 1268011823 }, };

pub const M31_CIRCLE_LOG_ORDER: u32 = 31;

/// Equals `2^31`.
pub const M31_CIRCLE_ORDER: u32 = 2147483648;

/// Equals `2^31 - 1`.
pub const M31_CIRCLE_ORDER_BIT_MASK: u32 = 0x7fffffff;

/// A generator for the circle group over [`QM31`].
pub const QM31_CIRCLE_GEN: CirclePoint<QM31> =
CirclePoint {
x: QM31 {
Expand All @@ -21,16 +31,8 @@ pub const QM31_CIRCLE_GEN: CirclePoint<QM31> =
},
};

pub const CIRCLE_LOG_ORDER: u32 = 31;

// `CIRCLE_ORDER` equals 2^31
pub const CIRCLE_ORDER: u32 = 2147483648;

// `CIRCLE_ORDER_BIT_MASK` equals 2^31 - 1
pub const CIRCLE_ORDER_BIT_MASK: u32 = 0x7fffffff;

// `U32_BIT_MASK` equals 2^32 - 1
pub const U32_BIT_MASK: u64 = 0xffffffff;
/// Order of [`QM31_CIRCLE_GEN`].
pub const QM31_CIRCLE_ORDER: u128 = 21267647892944572736998860269687930880;

/// A point on the complex circle. Treated as an additive group.
#[derive(Drop, Copy, Debug, PartialEq)]
Expand Down Expand Up @@ -71,18 +73,20 @@ pub trait CirclePointTrait<
}

fn mul(
self: @CirclePoint<F>, ref scalar: u128
self: @CirclePoint<F>, scalar: u128
) -> CirclePoint<
F
> {
// TODO: `mut scalar: u128` doesn't work in trait.
let mut scalar = scalar;
let mut result = Self::zero();
let mut cur = *self;
while scalar > 0 {
while scalar != 0 {
if scalar & 1 == 1 {
result = result + cur;
}
cur = cur + cur;
scalar = scalar / 2;
scalar /= 2;
};
result
}
Expand All @@ -109,44 +113,34 @@ pub impl ComplexConjugateImpl of ComplexConjugateTrait {
/// Represents the coset `initial + <step>`.
#[derive(Copy, Clone, Debug, PartialEq, Drop)]
pub struct Coset {
// This is an index in the range [0, 2^31)
pub initial_index: usize,
pub step_size: usize,
pub initial_index: CirclePointIndex,
pub step_size: CirclePointIndex,
pub log_size: u32,
}

#[generate_trait]
pub impl CosetImpl of CosetTrait {
fn new(initial_index: usize, log_size: u32) -> Coset {
assert!(initial_index < CIRCLE_ORDER);
let step_size = pow(2, CIRCLE_LOG_ORDER - log_size);
fn new(initial_index: CirclePointIndex, log_size: u32) -> Coset {
let step_size = CirclePointIndexImpl::subgroup_gen(log_size);
Coset { initial_index, step_size, log_size }
}

fn index_at(self: @Coset, index: usize) -> usize {
let index_times_step = ((*self.step_size).wide_mul(index) & U32_BIT_MASK)
.try_into()
.unwrap();
let result = (*self.initial_index).wrapping_add(index_times_step);
result & CIRCLE_ORDER_BIT_MASK
fn index_at(self: @Coset, index: usize) -> CirclePointIndex {
*self.initial_index + self.step_size.mul(index)
}

fn double(self: @Coset) -> Coset {
assert!(*self.log_size > 0);
let double_initial_index = *self.initial_index * 2;
let double_step_size = *self.step_size * 2;
let new_log_size = *self.log_size - 1;

Coset {
initial_index: double_initial_index & CIRCLE_ORDER_BIT_MASK,
step_size: double_step_size & CIRCLE_ORDER_BIT_MASK,
log_size: new_log_size
initial_index: *self.initial_index + *self.initial_index,
step_size: *self.step_size + *self.step_size,
log_size: *self.log_size - 1
}
}

#[inline]
fn at(self: @Coset, index: usize) -> CirclePoint<M31> {
let mut scalar = self.index_at(index).into();
M31_CIRCLE_GEN.mul(ref scalar)
self.index_at(index).to_point()
}

/// Returns the size of the coset.
Expand All @@ -158,35 +152,95 @@ pub impl CosetImpl of CosetTrait {
///
/// For example, for `n=8`, we get the point indices `[1,3,5,7,9,11,13,15]`.
fn odds(log_size: u32) -> Coset {
let subgroup_generator_index = Self::subgroup_generator_index(log_size);
Self::new(subgroup_generator_index, log_size)
Self::new(CirclePointIndexImpl::subgroup_gen(log_size + 1), log_size)
}

/// Creates a coset of the form `G_4n + <G_n>`.
///
///
/// For example, for `n=8`, we get the point indices `[1,5,9,13,17,21,25,29]`.
/// Its conjugate will be `[3,7,11,15,19,23,27,31]`.
fn half_odds(log_size: u32) -> Coset {
Self::new(Self::subgroup_generator_index(log_size + 2), log_size)
Self::new(CirclePointIndexImpl::subgroup_gen(log_size + 2), log_size)
}
}

/// Integer `i` that represent the circle point `i * M31_CIRCLE_GEN`.
///
/// Treated as an additive ring modulo `1 << M31_CIRCLE_LOG_ORDER`.
#[derive(Copy, Debug, Drop)]
pub struct CirclePointIndex {
/// The index, stored as an unreduced `u32` for performance reasons.
index: u32,
}

#[generate_trait]
pub impl CirclePointIndexImpl of CirclePointIndexTrait {
fn new(index: u32) -> CirclePointIndex {
CirclePointIndex { index }
}

fn zero() -> CirclePointIndex {
CirclePointIndex { index: 0 }
}

fn generator() -> CirclePointIndex {
CirclePointIndex { index: 1 }
}

fn reduce(self: @CirclePointIndex) -> CirclePointIndex {
CirclePointIndex { index: *self.index & M31_CIRCLE_ORDER_BIT_MASK }
}

fn subgroup_gen(log_size: u32) -> CirclePointIndex {
assert!(log_size <= M31_CIRCLE_LOG_ORDER);
CirclePointIndex { index: pow(2, M31_CIRCLE_LOG_ORDER - log_size) }
}

// TODO(andrew): When associated types are supported, support the Mul<Self, u32>.
fn mul(self: @CirclePointIndex, scalar: u32) -> CirclePointIndex {
CirclePointIndex { index: (*self.index).wrapping_mul(scalar) }
}

fn index(self: @CirclePointIndex) -> u32 {
self.reduce().index
}

fn to_point(self: @CirclePointIndex) -> CirclePoint<M31> {
// No need to call `reduce()`.
M31_CIRCLE_GEN.mul((*self.index).into())
}
}

impl CirclePointIndexAdd of Add<CirclePointIndex> {
#[inline]
fn add(lhs: CirclePointIndex, rhs: CirclePointIndex) -> CirclePointIndex {
CirclePointIndex { index: lhs.index.wrapping_add(rhs.index) }
}
}

fn subgroup_generator_index(log_size: u32) -> u32 {
assert!(log_size <= CIRCLE_LOG_ORDER);
pow(2, CIRCLE_LOG_ORDER - log_size)
impl CirclePointIndexNeg of Neg<CirclePointIndex> {
#[inline]
fn neg(a: CirclePointIndex) -> CirclePointIndex {
CirclePointIndex { index: M31_CIRCLE_ORDER.wrapping_sub(a.index) }
}
}

impl CirclePointIndexPartialEx of PartialEq<CirclePointIndex> {
fn eq(lhs: @CirclePointIndex, rhs: @CirclePointIndex) -> bool {
lhs.index() == rhs.index()
}

fn ne(lhs: @CirclePointIndex, rhs: @CirclePointIndex) -> bool {
lhs.index() != rhs.index()
}
}

#[cfg(test)]
mod tests {
use core::array::ArrayTrait;
use core::option::OptionTrait;
use core::traits::TryInto;
use stwo_cairo_verifier::fields::m31::m31;
use stwo_cairo_verifier::fields::qm31::QM31One;
use stwo_cairo_verifier::fields::qm31::{QM31One, qm31};
use stwo_cairo_verifier::utils::pow;
use super::{CirclePointQM31Impl, QM31_CIRCLE_GEN};
use super::{M31_CIRCLE_GEN, CIRCLE_ORDER, CirclePoint, CirclePointM31Impl, Coset, CosetImpl};
use super::{M31_CIRCLE_GEN, CirclePointQM31Impl, QM31_CIRCLE_GEN, M31_CIRCLE_ORDER, CirclePoint, CirclePointM31Impl, CirclePointIndexImpl, Coset, CosetImpl, QM31_CIRCLE_ORDER};

#[test]
fn test_add_1() {
Expand Down Expand Up @@ -224,17 +278,15 @@ mod tests {
#[test]
fn test_mul_1() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };
let mut scalar = 5;
let result = point_1.mul(ref scalar);
let result = point_1.mul(5);

assert_eq!(result, point_1 + point_1 + point_1 + point_1 + point_1);
}

#[test]
fn test_mul_2() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };
let mut scalar = 8;
let result = point_1.mul(ref scalar);
let result = point_1.mul(8);

assert_eq!(
result, point_1 + point_1 + point_1 + point_1 + point_1 + point_1 + point_1 + point_1
Expand All @@ -244,126 +296,69 @@ mod tests {
#[test]
fn test_mul_3() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };
let mut scalar = 418776494;
let result = point_1.mul(ref scalar);
let result = point_1.mul(418776494);

assert_eq!(result, CirclePoint { x: m31(1987283985), y: m31(1500510905) });
}

#[test]
fn test_generator_order() {
let half_order = CIRCLE_ORDER / 2;
let mut scalar = half_order.into();
let mut result = M31_CIRCLE_GEN.mul(ref scalar);
let half_order = M31_CIRCLE_ORDER / 2;
let mut result = M31_CIRCLE_GEN.mul(half_order.into());

// Assert `M31_CIRCLE_GEN^{2^30}` equals `-1`.
assert_eq!(result, CirclePoint { x: -m31(1), y: m31(0) });
}

#[test]
fn test_generator() {
let mut scalar = pow(2, 30).try_into().unwrap();
let mut result = M31_CIRCLE_GEN.mul(ref scalar);
let mut result = M31_CIRCLE_GEN.mul(pow(2, 30).into());

assert_eq!(result, CirclePoint { x: -m31(1), y: m31(0) });
}

#[test]
fn test_coset_index_at() {
let coset = Coset { initial_index: 16777216, log_size: 5, step_size: 67108864 };
let coset = Coset { initial_index: CirclePointIndexImpl::new(16777216), log_size: 5, step_size: CirclePointIndexImpl::new(67108864) };
let result = coset.index_at(8);

assert_eq!(result, 553648128);
assert_eq!(result, CirclePointIndexImpl::new(553648128));
}

#[test]
fn test_coset_constructor() {
let result = CosetImpl::new(16777216, 5);
let result = CosetImpl::new(CirclePointIndexImpl::new(16777216), 5);

assert_eq!(result, Coset { initial_index: 16777216, log_size: 5, step_size: 67108864 });
assert_eq!(result, Coset { initial_index: CirclePointIndexImpl::new(16777216), log_size: 5, step_size: CirclePointIndexImpl::new(67108864) });
}

#[test]
fn test_coset_double() {
let coset = Coset { initial_index: 16777216, step_size: 67108864, log_size: 5 };
let coset = Coset { initial_index: CirclePointIndexImpl::new(16777216), step_size: CirclePointIndexImpl::new(67108864), log_size: 5 };
let result = coset.double();

assert_eq!(result, Coset { initial_index: 33554432, step_size: 134217728, log_size: 4 });
assert_eq!(result, Coset { initial_index: CirclePointIndexImpl::new(33554432), step_size: CirclePointIndexImpl::new(134217728), log_size: 4 });
}

#[test]
fn test_coset_at() {
let coset = Coset { initial_index: 16777216, step_size: 67108864, log_size: 5 };
let coset = Coset { initial_index: CirclePointIndexImpl::new(16777216), step_size: CirclePointIndexImpl::new(67108864), log_size: 5 };
let result = coset.at(17);

assert_eq!(result, CirclePoint { x: m31(7144319), y: m31(1742797653) });
}

#[test]
fn test_coset_size() {
let coset = Coset { initial_index: 16777216, step_size: 67108864, log_size: 5 };
let coset = Coset { initial_index: CirclePointIndexImpl::new(16777216), step_size: CirclePointIndexImpl::new(67108864), log_size: 5 };
let result = coset.size();

assert_eq!(result, 32);
}

#[test]
fn test_qm31_circle_gen() {
let P4: u128 = 21267647892944572736998860269687930881;

let first_prime = 2;
let last_prime = 368140581013;
let prime_factors: Array<(u128, u32)> = array![
(first_prime, 33),
(3, 2),
(5, 1),
(7, 1),
(11, 1),
(31, 1),
(151, 1),
(331, 1),
(733, 1),
(1709, 1),
(last_prime, 1),
];

let product = iter_product(first_prime, @prime_factors, last_prime);

assert_eq!(product, P4 - 1);

assert_eq!(
QM31_CIRCLE_GEN.x * QM31_CIRCLE_GEN.x + QM31_CIRCLE_GEN.y * QM31_CIRCLE_GEN.y,
QM31One::one()
);

let mut scalar = P4 - 1;
assert_eq!(QM31_CIRCLE_GEN.mul(ref scalar), CirclePointQM31Impl::zero());

let mut i = 0;
while i < prime_factors.len() {
let (p, _) = *prime_factors.at(i);
let mut scalar = (P4 - 1) / p.into();
assert_ne!(QM31_CIRCLE_GEN.mul(ref scalar), CirclePointQM31Impl::zero());

i = i + 1;
}
}

fn iter_product(
first_prime: u128, prime_factors: @Array<(u128, u32)>, last_prime: u128
) -> u128 {
let mut accum_product: u128 = 1;
accum_product = accum_product
* pow(first_prime.try_into().unwrap(), 31).into()
* 4; // * 2^33
let mut i = 1;
while i < prime_factors.len() - 1 {
let (prime, exponent): (u128, u32) = *prime_factors.at(i);
accum_product = accum_product * pow(prime.try_into().unwrap(), exponent).into();
i = i + 1;
};
accum_product = accum_product * last_prime;
accum_product
assert_eq!(QM31_CIRCLE_GEN.mul(QM31_CIRCLE_ORDER / 2), CirclePoint { x: -qm31(1, 0, 0, 0), y: qm31(0, 0, 0, 0) });
}
}

Loading

0 comments on commit b51e6fd

Please sign in to comment.