Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unreduced field arithmetic #116

Merged
merged 1 commit into from
Oct 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion stwo_cairo_verifier/src/fields/cm31.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::num::traits::{One, Zero};
use super::m31::{M31, m31, M31Trait};
use core::ops::{AddAssign, MulAssign, SubAssign};
use super::m31::{M31, M31Impl, m31, M31Trait};

#[derive(Copy, Drop, Debug, PartialEq)]
pub struct CM31 {
Expand All @@ -14,34 +15,104 @@ pub impl CM31Impl of CM31Trait {
let denom_inverse: M31 = (self.a * self.a + self.b * self.b).inverse();
CM31 { a: self.a * denom_inverse, b: -self.b * denom_inverse }
}

/// Computes all `1/arr[i]` with a single call to `inverse()` using Montgomery batch inversion.
fn batch_inverse(arr: Array<CM31>) -> Array<CM31> {
// Construct array `1, z, zy, ..., zy..b`.
let mut prefix_product_rev = array![];
let mut cumulative_product: CM31 = One::one();

let mut i = arr.len();
while i != 0 {
i -= 1;
prefix_product_rev.append(cumulative_product);
cumulative_product *= *arr[i];
};

// Compute `1/zy..a`.
let mut cumulative_product_inv = cumulative_product.inverse();
// Compute all `1/a = zy..b/zy..a, 1/b = zy..c/zy..b, ...`.
let mut inverses = array![];

let mut i = prefix_product_rev.len();
for v in arr {
i -= 1;
inverses.append(cumulative_product_inv * *prefix_product_rev[i]);
cumulative_product_inv *= v;
};

inverses
}

// TODO(andrew): When associated types are supported, support `Mul<CM31, M31>`.
#[inline]
fn mul_m31(self: CM31, rhs: M31) -> CM31 {
CM31 { a: self.a * rhs, b: self.b * rhs }
}

// TODO(andrew): When associated types are supported, support `Sub<CM31, M31>`.
#[inline]
fn sub_m31(self: CM31, rhs: M31) -> CM31 {
CM31 { a: self.a - rhs, b: self.b }
}
}

pub impl CM31Add of core::traits::Add<CM31> {
#[inline]
fn add(lhs: CM31, rhs: CM31) -> CM31 {
CM31 { a: lhs.a + rhs.a, b: lhs.b + rhs.b }
}
}

pub impl CM31Sub of core::traits::Sub<CM31> {
#[inline]
fn sub(lhs: CM31, rhs: CM31) -> CM31 {
CM31 { a: lhs.a - rhs.a, b: lhs.b - rhs.b }
}
}

pub impl CM31Mul of core::traits::Mul<CM31> {
#[inline]
fn mul(lhs: CM31, rhs: CM31) -> CM31 {
CM31 { a: lhs.a * rhs.a - lhs.b * rhs.b, b: lhs.a * rhs.b + lhs.b * rhs.a }
}
}

pub impl CM31AddAssign of AddAssign<CM31, CM31> {
#[inline]
fn add_assign(ref self: CM31, rhs: CM31) {
self = self + rhs
}
}

pub impl CM31SubAssign of SubAssign<CM31, CM31> {
#[inline]
fn sub_assign(ref self: CM31, rhs: CM31) {
self = self - rhs
}
}

pub impl CM31MulAssign of MulAssign<CM31, CM31> {
#[inline]
fn mul_assign(ref self: CM31, rhs: CM31) {
self = self * rhs
}
}

pub impl CM31Zero of Zero<CM31> {
fn zero() -> CM31 {
cm31(0, 0)
}

fn is_zero(self: @CM31) -> bool {
(*self).a.is_zero() && (*self).b.is_zero()
}

fn is_non_zero(self: @CM31) -> bool {
(*self).a.is_non_zero() || (*self).b.is_non_zero()
}
}

pub impl CM31One of One<CM31> {
fn one() -> CM31 {
cm31(1, 0)
Expand All @@ -53,17 +124,28 @@ pub impl CM31One of One<CM31> {
(*self).a.is_non_one() || (*self).b.is_non_zero()
}
}

pub impl M31IntoCM31 of core::traits::Into<M31, CM31> {
#[inline]
fn into(self: M31) -> CM31 {
CM31 { a: self, b: m31(0) }
}
}

pub impl CM31Neg of Neg<CM31> {
#[inline]
fn neg(a: CM31) -> CM31 {
CM31 { a: -a.a, b: -a.b }
}
}

impl CM31PartialOrd of PartialOrd<CM31> {
fn lt(lhs: CM31, rhs: CM31) -> bool {
lhs.a < rhs.a || (lhs.a == rhs.a && lhs.b < rhs.b)
}
}

#[inline]
pub fn cm31(a: u32, b: u32) -> CM31 {
CM31 { a: m31(a), b: m31(b) }
}
Expand Down
92 changes: 92 additions & 0 deletions stwo_cairo_verifier/src/fields/m31.cairo
Original file line number Diff line number Diff line change
@@ -1,28 +1,45 @@
use core::num::traits::{WideMul, CheckedSub};
use core::ops::{AddAssign, MulAssign, SubAssign};
use core::option::OptionTrait;
use core::traits::TryInto;

/// Equals `2^31 - 1`.
pub const P: u32 = 0x7fffffff;

/// Equals `2^31 - 1`.
const P32NZ: NonZero<u32> = 0x7fffffff;

/// Equals `2^31 - 1`.
const P64NZ: NonZero<u64> = 0x7fffffff;

/// Equals `2^31 - 1`.
const P128NZ: NonZero<u128> = 0x7fffffff;

#[derive(Copy, Drop, Debug, PartialEq)]
pub struct M31 {
pub inner: u32
}

#[generate_trait]
pub impl M31Impl of M31Trait {
#[inline]
fn reduce_u32(val: u32) -> M31 {
let (_, res) = core::integer::u32_safe_divmod(val, P32NZ);
M31 { inner: res.try_into().unwrap() }
}

#[inline]
fn reduce_u64(val: u64) -> M31 {
let (_, res) = core::integer::u64_safe_divmod(val, P64NZ);
M31 { inner: res.try_into().unwrap() }
}

#[inline]
fn reduce_u128(val: u128) -> M31 {
let (_, res) = core::integer::u128_safe_divmod(val, P128NZ);
M31 { inner: res.try_into().unwrap() }
}

#[inline]
fn sqn(v: M31, n: usize) -> M31 {
if n == 0 {
Expand All @@ -43,45 +60,81 @@ pub impl M31Impl of M31Trait {
}
}
pub impl M31Add of core::traits::Add<M31> {
#[inline]
fn add(lhs: M31, rhs: M31) -> M31 {
let res = lhs.inner + rhs.inner;
let res = res.checked_sub(P).unwrap_or(res);
M31 { inner: res }
}
}

pub impl M31Sub of core::traits::Sub<M31> {
#[inline]
fn sub(lhs: M31, rhs: M31) -> M31 {
lhs + (-rhs)
}
}

pub impl M31Mul of core::traits::Mul<M31> {
#[inline]
fn mul(lhs: M31, rhs: M31) -> M31 {
M31Impl::reduce_u64(lhs.inner.wide_mul(rhs.inner))
}
}

pub impl M31AddAssign of AddAssign<M31, M31> {
#[inline]
fn add_assign(ref self: M31, rhs: M31) {
self = self + rhs
}
}

pub impl M31SubAssign of SubAssign<M31, M31> {
#[inline]
fn sub_assign(ref self: M31, rhs: M31) {
self = self - rhs
}
}

pub impl M31MulAssign of MulAssign<M31, M31> {
#[inline]
fn mul_assign(ref self: M31, rhs: M31) {
self = self * rhs
}
}

pub impl M31Zero of core::num::traits::Zero<M31> {
#[inline]
fn zero() -> M31 {
M31 { inner: 0 }
}

fn is_zero(self: @M31) -> bool {
*self.inner == 0
}

fn is_non_zero(self: @M31) -> bool {
*self.inner != 0
}
}

pub impl M31One of core::num::traits::One<M31> {
#[inline]
fn one() -> M31 {
M31 { inner: 1 }
}

fn is_one(self: @M31) -> bool {
*self.inner == 1
}

fn is_non_one(self: @M31) -> bool {
*self.inner != 1
}
}

pub impl M31Neg of Neg<M31> {
#[inline]
fn neg(a: M31) -> M31 {
if a.inner == 0 {
M31 { inner: 0 }
Expand All @@ -90,16 +143,55 @@ pub impl M31Neg of Neg<M31> {
}
}
}

impl M31IntoFelt252 of Into<M31, felt252> {
#[inline]
fn into(self: M31) -> felt252 {
self.inner.into()
}
}

impl M31PartialOrd of PartialOrd<M31> {
fn ge(lhs: M31, rhs: M31) -> bool {
lhs.inner >= rhs.inner
}

fn lt(lhs: M31, rhs: M31) -> bool {
lhs.inner < rhs.inner
}
}

#[inline]
pub fn m31(val: u32) -> M31 {
M31Impl::reduce_u32(val)
}

#[derive(Copy, Drop, Debug)]
pub struct UnreducedM31 {
pub inner: felt252,
}

pub impl UnreducedM31Sub of Sub<UnreducedM31> {
#[inline]
fn sub(lhs: UnreducedM31, rhs: UnreducedM31) -> UnreducedM31 {
UnreducedM31 { inner: lhs.inner - rhs.inner }
}
}

pub impl UnreducedM31Add of Add<UnreducedM31> {
#[inline]
fn add(lhs: UnreducedM31, rhs: UnreducedM31) -> UnreducedM31 {
UnreducedM31 { inner: lhs.inner + rhs.inner }
}
}

impl M31IntoUnreducedM31 of Into<M31, UnreducedM31> {
#[inline]
fn into(self: M31) -> UnreducedM31 {
UnreducedM31 { inner: self.inner.into() }
}
}

#[cfg(test)]
mod tests {
use super::{m31, P, M31Trait};
Expand Down
Loading
Loading