Skip to content

Commit

Permalink
Optimize Coset::at
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Nov 17, 2024
1 parent 123f2cb commit eec00b8
Show file tree
Hide file tree
Showing 7 changed files with 564 additions and 172 deletions.
136 changes: 58 additions & 78 deletions stwo_cairo_verifier/src/circle.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ use core::num::traits::one::One;
use core::num::traits::zero::Zero;
use core::num::traits::{WrappingAdd, WrappingSub, WrappingMul};
use stwo_cairo_verifier::channel::{Channel, ChannelImpl};
use stwo_cairo_verifier::circle_mul_table::{
M31_CIRCLE_GEN_MUL_TABLE_BITS_24_TO_29, M31_CIRCLE_GEN_MUL_TABLE_BITS_18_TO_23,
M31_CIRCLE_GEN_MUL_TABLE_BITS_12_TO_17, M31_CIRCLE_GEN_MUL_TABLE_BITS_6_TO_11,
M31_CIRCLE_GEN_MUL_TABLE_BITS_0_TO_5
};
use stwo_cairo_verifier::fields::cm31::CM31;
use stwo_cairo_verifier::fields::m31::{M31, M31Impl};
use stwo_cairo_verifier::fields::qm31::{QM31Impl, QM31One, QM31, QM31Trait};
Expand Down Expand Up @@ -72,25 +77,6 @@ pub trait CirclePointTrait<
};
res
}

fn mul(
self: @CirclePoint<F>, scalar: u128
) -> CirclePoint<
F
> {
// TODO: `mut scalar: u128` doesn't work in trait.
let mut scalar = scalar;
let mut result = Self::zero();
let mut cur = *self;
while scalar != 0 {
if scalar & 1 == 1 {
result = result + cur;
}
cur = cur + cur;
scalar /= 2;
};
result
}
}

impl CirclePointAdd<F, +Add<F>, +Sub<F>, +Mul<F>, +Drop<F>, +Copy<F>> of Add<CirclePoint<F>> {
Expand Down Expand Up @@ -242,8 +228,39 @@ pub impl CirclePointIndexImpl of CirclePointIndexTrait {
}

fn to_point(self: @CirclePointIndex) -> CirclePoint<M31> {
const NZ_2_POW_24: NonZero<u32> = 0b1000000000000000000000000;
const NZ_2_POW_18: NonZero<u32> = 0b1000000000000000000;
const NZ_2_POW_12: NonZero<u32> = 0b1000000000000;
const NZ_2_POW_6: NonZero<u32> = 0b1000000;

// No need to call `reduce()`.
M31_CIRCLE_GEN.mul((*self.index).into())
// Start with MSBs since small domains (more common) have LSBs equal 0.
let (bits_24_to_31, bits_0_to_23) = DivRem::div_rem(*self.index, NZ_2_POW_24);
let (bits_30_to_31, bits_24_to_29) = DivRem::div_rem(bits_24_to_31, NZ_2_POW_6);
let mut res = *M31_CIRCLE_GEN_MUL_TABLE_BITS_24_TO_29.span()[bits_24_to_29];
if bits_0_to_23 != 0 {
let (bits_18_to_23, bits_0_to_17) = DivRem::div_rem(bits_0_to_23, NZ_2_POW_18);
res = res + *M31_CIRCLE_GEN_MUL_TABLE_BITS_18_TO_23.span()[bits_18_to_23];
if bits_0_to_17 != 0 {
let (bits_12_to_17, bits_0_to_11) = DivRem::div_rem(bits_0_to_17, NZ_2_POW_12);
res = res + *M31_CIRCLE_GEN_MUL_TABLE_BITS_12_TO_17.span()[bits_12_to_17];
if bits_0_to_11 != 0 {
let (bits_6_to_11, bits_0_to_5) = DivRem::div_rem(bits_0_to_11, NZ_2_POW_6);
res = res + *M31_CIRCLE_GEN_MUL_TABLE_BITS_6_TO_11.span()[bits_6_to_11];
if bits_0_to_5 != 0 {
res = res + *M31_CIRCLE_GEN_MUL_TABLE_BITS_0_TO_5.span()[bits_0_to_5];
}
}
}
}

// Note this applies the appropriate transformation based on the two highest bits.
// The highest bit has no effect. The 30th bit indicates weather to take the antipode.
if bits_30_to_31 == 0b11 || bits_30_to_31 == 0b01 {
res = CirclePoint { x: -res.x, y: -res.y };
}

res
}
}

Expand Down Expand Up @@ -274,14 +291,31 @@ impl CirclePointIndexPartialEx of PartialEq<CirclePointIndex> {
#[cfg(test)]
mod tests {
use stwo_cairo_verifier::fields::m31::m31;
use stwo_cairo_verifier::fields::qm31::{QM31One, qm31};
use stwo_cairo_verifier::utils::pow;
use stwo_cairo_verifier::fields::qm31::QM31One;
use super::{
M31_CIRCLE_GEN, CirclePointQM31Impl, QM31_CIRCLE_GEN, M31_CIRCLE_ORDER, CirclePoint,
CirclePointM31Impl, CirclePointIndexImpl, Coset, CosetImpl, QM31_CIRCLE_ORDER,
M31_CIRCLE_GEN, CirclePointQM31Impl, QM31_CIRCLE_GEN, CirclePoint, CirclePointM31Impl,
CirclePointIndex, CirclePointIndexImpl, Coset, CosetImpl,
CirclePointQM31AddCirclePointM31Impl
};


#[test]
fn test_to_point() {
let index = CirclePointIndex { index: 0b01111111111111111111111111111111 };
assert_eq!(index.to_point(), -M31_CIRCLE_GEN);
let index = CirclePointIndex { index: 0b00111111111111111111111111111111 };
assert_eq!(index.to_point(), CirclePoint { x: -M31_CIRCLE_GEN.x, y: M31_CIRCLE_GEN.y });
}


#[test]
fn test_to_point_with_unreduced_index() {
// All 32 bits are `1`.
let index = CirclePointIndex { index: 0b11111111111111111111111111111111 };

assert_eq!(index.to_point(), -M31_CIRCLE_GEN);
}

#[test]
fn test_add_1() {
let g4 = CirclePoint { x: m31(0), y: m31(1) };
Expand Down Expand Up @@ -315,52 +349,6 @@ mod tests {
assert_eq!(result, point_1.clone());
}

#[test]
fn test_mul_1() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };

let result = point_1.mul(5);

assert_eq!(result, point_1 + point_1 + point_1 + point_1 + point_1);
}

#[test]
fn test_mul_2() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };

let result = point_1.mul(8);

assert_eq!(
result, point_1 + point_1 + point_1 + point_1 + point_1 + point_1 + point_1 + point_1
);
}

#[test]
fn test_mul_3() {
let point_1 = CirclePoint { x: m31(750649172), y: m31(1991648574) };

let result = point_1.mul(418776494);

assert_eq!(result, CirclePoint { x: m31(1987283985), y: m31(1500510905) });
}

#[test]
fn test_generator_order() {
let half_order = M31_CIRCLE_ORDER / 2;

let mut result = M31_CIRCLE_GEN.mul(half_order.into());

// Assert `M31_CIRCLE_GEN^{2^30}` equals `-1`.
assert_eq!(result, CirclePoint { x: -m31(1), y: m31(0) });
}

#[test]
fn test_generator() {
let mut result = M31_CIRCLE_GEN.mul(pow(2, 30).into());

assert_eq!(result, CirclePoint { x: -m31(1), y: m31(0) });
}

#[test]
fn test_coset_index_at() {
let coset = Coset {
Expand Down Expand Up @@ -434,14 +422,6 @@ mod tests {
assert_eq!(result, 32);
}

#[test]
fn test_qm31_circle_gen() {
assert_eq!(
QM31_CIRCLE_GEN.mul(QM31_CIRCLE_ORDER / 2),
CirclePoint { x: -qm31(1, 0, 0, 0), y: qm31(0, 0, 0, 0) }
);
}

#[test]
fn test_add_circle_point_m31() {
assert_eq!(
Expand Down
Loading

0 comments on commit eec00b8

Please sign in to comment.