Skip to content

Commit

Permalink
Add unreduced field arithmetic (#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored Oct 20, 2024
1 parent 3948dca commit 8f2bf7f
Show file tree
Hide file tree
Showing 3 changed files with 508 additions and 4 deletions.
84 changes: 83 additions & 1 deletion stwo_cairo_verifier/src/fields/cm31.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::num::traits::{One, Zero};
use super::m31::{M31, m31, M31Trait};
use core::ops::{AddAssign, MulAssign, SubAssign};
use super::m31::{M31, M31Impl, m31, M31Trait};

#[derive(Copy, Drop, Debug, PartialEq)]
pub struct CM31 {
Expand All @@ -14,34 +15,104 @@ pub impl CM31Impl of CM31Trait {
let denom_inverse: M31 = (self.a * self.a + self.b * self.b).inverse();
CM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse }
}

/// Computes all `1/arr[i]` with a single call to `inverse()` using Montgomery batch inversion.
fn batch_inverse(arr: Array<CM31>) -> Array<CM31> {
// Construct array `1, z, zy, ..., zy..b`.
let mut prefix_product_rev = array![];
let mut cumulative_product: CM31 = One::one();

let mut i = arr.len();
while i != 0 {
i -= 1;
prefix_product_rev.append(cumulative_product);
cumulative_product *= *arr[i];
};

// Compute `1/zy..a`.
let mut cumulative_product_inv = cumulative_product.inverse();
// Compute all `1/a = zy..b/zy..a, 1/b = zy..c/zy..b, ...`.
let mut inverses = array![];

let mut i = prefix_product_rev.len();
for v in arr {
i -= 1;
inverses.append(cumulative_product_inv * *prefix_product_rev[i]);
cumulative_product_inv *= v;
};

inverses
}

// TODO(andrew): When associated types are supported, support `Mul<CM31, M31>`.
#[inline]
fn mul_m31(self: CM31, rhs: M31) -> CM31 {
CM31 { a: self.a * rhs, b: self.b * rhs }
}

// TODO(andrew): When associated types are supported, support `Sub<CM31, M31>`.
#[inline]
fn sub_m31(self: CM31, rhs: M31) -> CM31 {
CM31 { a: self.a - rhs, b: self.b }
}
}

pub impl CM31Add of core::traits::Add<CM31> {
#[inline]
fn add(lhs: CM31, rhs: CM31) -> CM31 {
CM31 { a: lhs.a + rhs.a, b: lhs.b + rhs.b }
}
}

pub impl CM31Sub of core::traits::Sub<CM31> {
#[inline]
fn sub(lhs: CM31, rhs: CM31) -> CM31 {
CM31 { a: lhs.a - rhs.a, b: lhs.b - rhs.b }
}
}

pub impl CM31Mul of core::traits::Mul<CM31> {
#[inline]
fn mul(lhs: CM31, rhs: CM31) -> CM31 {
CM31 { a: lhs.a * rhs.a - lhs.b * rhs.b, b: lhs.a * rhs.b + lhs.b * rhs.a }
}
}

pub impl CM31AddAssign of AddAssign<CM31, CM31> {
#[inline]
fn add_assign(ref self: CM31, rhs: CM31) {
self = self + rhs
}
}

pub impl CM31SubAssign of SubAssign<CM31, CM31> {
#[inline]
fn sub_assign(ref self: CM31, rhs: CM31) {
self = self - rhs
}
}

pub impl CM31MulAssign of MulAssign<CM31, CM31> {
#[inline]
fn mul_assign(ref self: CM31, rhs: CM31) {
self = self * rhs
}
}

pub impl CM31Zero of Zero<CM31> {
fn zero() -> CM31 {
cm31(0, 0)
}

fn is_zero(self: @CM31) -> bool {
(*self).a.is_zero() && (*self).b.is_zero()
}

fn is_non_zero(self: @CM31) -> bool {
(*self).a.is_non_zero() || (*self).b.is_non_zero()
}
}

pub impl CM31One of One<CM31> {
fn one() -> CM31 {
cm31(1, 0)
Expand All @@ -53,17 +124,28 @@ pub impl CM31One of One<CM31> {
(*self).a.is_non_one() || (*self).b.is_non_zero()
}
}

pub impl M31IntoCM31 of core::traits::Into<M31, CM31> {
#[inline]
fn into(self: M31) -> CM31 {
CM31 { a: self, b: m31(0) }
}
}

pub impl CM31Neg of Neg<CM31> {
#[inline]
fn neg(a: CM31) -> CM31 {
CM31 { a: -a.a, b: -a.b }
}
}

impl CM31PartialOrd of PartialOrd<CM31> {
fn lt(lhs: CM31, rhs: CM31) -> bool {
lhs.a < rhs.a || (lhs.a == rhs.a && lhs.b < rhs.b)
}
}

#[inline]
pub fn cm31(a: u32, b: u32) -> CM31 {
CM31 { a: m31(a), b: m31(b) }
}
Expand Down
92 changes: 92 additions & 0 deletions stwo_cairo_verifier/src/fields/m31.cairo
Original file line number Diff line number Diff line change
@@ -1,28 +1,45 @@
use core::num::traits::{WideMul, CheckedSub};
use core::ops::{AddAssign, MulAssign, SubAssign};
use core::option::OptionTrait;
use core::traits::TryInto;

/// Equals `2^31 - 1`.
pub const P: u32 = 0x7fffffff;

/// Equals `2^31 - 1`.
const P32NZ: NonZero<u32> = 0x7fffffff;

/// Equals `2^31 - 1`.
const P64NZ: NonZero<u64> = 0x7fffffff;

/// Equals `2^31 - 1`.
const P128NZ: NonZero<u128> = 0x7fffffff;

#[derive(Copy, Drop, Debug, PartialEq)]
pub struct M31 {
pub inner: u32
}

#[generate_trait]
pub impl M31Impl of M31Trait {
#[inline]
fn reduce_u32(val: u32) -> M31 {
let (_, res) = core::integer::u32_safe_divmod(val, P32NZ);
M31 { inner: res.try_into().unwrap() }
}

#[inline]
fn reduce_u64(val: u64) -> M31 {
let (_, res) = core::integer::u64_safe_divmod(val, P64NZ);
M31 { inner: res.try_into().unwrap() }
}

#[inline]
fn reduce_u128(val: u128) -> M31 {
let (_, res) = core::integer::u128_safe_divmod(val, P128NZ);
M31 { inner: res.try_into().unwrap() }
}

#[inline]
fn sqn(v: M31, n: usize) -> M31 {
if n == 0 {
Expand All @@ -43,45 +60,81 @@ pub impl M31Impl of M31Trait {
}
}
pub impl M31Add of core::traits::Add<M31> {
#[inline]
fn add(lhs: M31, rhs: M31) -> M31 {
let res = lhs.inner + rhs.inner;
let res = res.checked_sub(P).unwrap_or(res);
M31 { inner: res }
}
}

pub impl M31Sub of core::traits::Sub<M31> {
#[inline]
fn sub(lhs: M31, rhs: M31) -> M31 {
lhs + (-rhs)
}
}

pub impl M31Mul of core::traits::Mul<M31> {
#[inline]
fn mul(lhs: M31, rhs: M31) -> M31 {
M31Impl::reduce_u64(lhs.inner.wide_mul(rhs.inner))
}
}

pub impl M31AddAssign of AddAssign<M31, M31> {
#[inline]
fn add_assign(ref self: M31, rhs: M31) {
self = self + rhs
}
}

pub impl M31SubAssign of SubAssign<M31, M31> {
#[inline]
fn sub_assign(ref self: M31, rhs: M31) {
self = self - rhs
}
}

pub impl M31MulAssign of MulAssign<M31, M31> {
#[inline]
fn mul_assign(ref self: M31, rhs: M31) {
self = self * rhs
}
}

pub impl M31Zero of core::num::traits::Zero<M31> {
#[inline]
fn zero() -> M31 {
M31 { inner: 0 }
}

fn is_zero(self: @M31) -> bool {
*self.inner == 0
}

fn is_non_zero(self: @M31) -> bool {
*self.inner != 0
}
}

pub impl M31One of core::num::traits::One<M31> {
#[inline]
fn one() -> M31 {
M31 { inner: 1 }
}

fn is_one(self: @M31) -> bool {
*self.inner == 1
}

fn is_non_one(self: @M31) -> bool {
*self.inner != 1
}
}

pub impl M31Neg of Neg<M31> {
#[inline]
fn neg(a: M31) -> M31 {
if a.inner == 0 {
M31 { inner: 0 }
Expand All @@ -90,16 +143,55 @@ pub impl M31Neg of Neg<M31> {
}
}
}

impl M31IntoFelt252 of Into<M31, felt252> {
#[inline]
fn into(self: M31) -> felt252 {
self.inner.into()
}
}

impl M31PartialOrd of PartialOrd<M31> {
fn ge(lhs: M31, rhs: M31) -> bool {
lhs.inner >= rhs.inner
}

fn lt(lhs: M31, rhs: M31) -> bool {
lhs.inner < rhs.inner
}
}

#[inline]
pub fn m31(val: u32) -> M31 {
M31Impl::reduce_u32(val)
}

#[derive(Copy, Drop, Debug)]
pub struct UnreducedM31 {
pub inner: felt252,
}

pub impl UnreducedM31Sub of Sub<UnreducedM31> {
#[inline]
fn sub(lhs: UnreducedM31, rhs: UnreducedM31) -> UnreducedM31 {
UnreducedM31 { inner: lhs.inner - rhs.inner }
}
}

pub impl UnreducedM31Add of Add<UnreducedM31> {
#[inline]
fn add(lhs: UnreducedM31, rhs: UnreducedM31) -> UnreducedM31 {
UnreducedM31 { inner: lhs.inner + rhs.inner }
}
}

impl M31IntoUnreducedM31 of Into<M31, UnreducedM31> {
#[inline]
fn into(self: M31) -> UnreducedM31 {
UnreducedM31 { inner: self.inner.into() }
}
}

#[cfg(test)]
mod tests {
use super::{m31, P, M31Trait};
Expand Down
Loading

0 comments on commit 8f2bf7f

Please sign in to comment.