Skip to content

Commit

Permalink
packed types for generic opcode
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Nov 18, 2024
1 parent db85303 commit 32c2f42
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 5 deletions.
2 changes: 2 additions & 0 deletions stwo_cairo_prover/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions stwo_cairo_prover/crates/prover_types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@ ruint.workspace = true
serde.workspace = true
itertools.workspace = true
bytemuck.workspace = true
num-traits.workspace = true

[dev-dependencies]
rand.workspace = true
3 changes: 1 addition & 2 deletions stwo_cairo_prover/crates/prover_types/src/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,7 @@ impl Felt252 {
M31::from_u32_unchecked(value)
}

pub fn from_limbs(felts: Vec<M31>) -> Self {
assert!(felts.len() <= FELT252_N_WORDS, "Invalid number of felts");
pub fn from_limbs(felts: &[M31; FELT252_N_WORDS]) -> Self {
let mut limbs = [0u64; 4];
for (index, felt) in felts.iter().enumerate() {
let shift = FELT252_BITS_PER_WORD * index;
Expand Down
122 changes: 119 additions & 3 deletions stwo_cairo_prover/crates/prover_types/src/simd.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::mem::transmute;
use std::ops::{Add, BitAnd, BitOr, BitXor, Rem, Shl, Shr};
use std::ops::{Add, BitAnd, BitOr, BitXor, Mul, Rem, Shl, Shr, Sub};
use std::simd::num::SimdUint;
use std::simd::Simd;

Expand All @@ -10,7 +10,7 @@ use stwo_prover::core::backend::simd::m31::PackedM31;
use stwo_prover::core::fields::m31;

use super::cpu::{UInt16, UInt32, UInt64, PRIME};
use crate::cpu::CasmState;
use crate::cpu::{CasmState, Felt252};

pub const LOG_N_LANES: u32 = 4;

Expand Down Expand Up @@ -171,6 +171,24 @@ impl PackedUInt32 {
pub fn in_m31_range(&self) -> bool {
all(self.as_array(), |v| v.value < m31::P)
}

pub fn from_m31(val: PackedM31) -> Self {
Self {
simd: val.into_simd(),
}
}

pub fn low(&self) -> PackedUInt16 {
PackedUInt16 {
value: (self.simd & Simd::splat(0xFFFF)).cast(),
}
}

pub fn high(&self) -> PackedUInt16 {
PackedUInt16 {
value: (self.simd >> 16).cast(),
}
}
}

impl Rem for PackedUInt32 {
Expand Down Expand Up @@ -337,7 +355,8 @@ impl BitXor for PackedUInt64 {

pub const N_M31_IN_FELT252: usize = 28;

// TODO(Ohad): implement ops and change to non-redundant representation.
use num_traits::identities::Zero;
// TODO(Ohad): Change to non-redundant representation.
#[derive(Copy, Clone, Debug)]
pub struct PackedFelt252 {
pub value: [PackedM31; N_M31_IN_FELT252],
Expand All @@ -346,6 +365,67 @@ impl PackedFelt252 {
pub fn get_m31(&self, index: usize) -> PackedM31 {
self.value[index]
}

pub fn from_limbs(limbs: [PackedM31; N_M31_IN_FELT252]) -> Self {
Self { value: limbs }
}

pub fn from_m31(val: PackedM31) -> Self {
Self {
value: std::array::from_fn(|i| if i == 0 { val } else { PackedM31::zero() }),
}
}

pub fn to_array(&self) -> [Felt252; N_LANES] {
let unpacked_limbs = self.value.unpack();
unpacked_limbs.map(|limbs| Felt252::from_limbs(&limbs))
}

pub fn from_array(arr: &[Felt252; N_LANES]) -> Self {
let limbs = arr.map(|felt| std::array::from_fn(|i| felt.get_m31(i)));
let limbs = <_ as Pack>::pack(limbs);
Self::from_limbs(limbs)
}
}

// TODO(Ohad): These are very slow, optimize.
impl Add for PackedFelt252 {
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
let lhs = self.to_array();
let rhs = rhs.to_array();
let result = std::array::from_fn(|i| lhs[i] + rhs[i]);
Self::from_array(&result)
}
}

impl Sub for PackedFelt252 {
type Output = Self;

fn sub(self, rhs: Self) -> Self::Output {
let lhs = self.to_array();
let rhs = rhs.to_array();
let result = std::array::from_fn(|i| lhs[i] - rhs[i]);
Self::from_array(&result)
}
}

impl Mul for PackedFelt252 {
type Output = Self;

fn mul(self, rhs: Self) -> Self::Output {
let lhs = self.to_array();
let rhs = rhs.to_array();
let result = std::array::from_fn(|i| lhs[i] * rhs[i]);
Self::from_array(&result)
}
}

impl PartialEq for PackedFelt252 {
fn eq(&self, other: &Self) -> bool {
self.to_array() == other.to_array()
}
}

impl AsRef<[PackedM31; N_M31_IN_FELT252]> for PackedFelt252 {
Expand Down Expand Up @@ -397,3 +477,39 @@ impl Unpack for PackedCasmState {
})
}
}

#[cfg(test)]
mod tests {
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use super::*;

#[test]
fn test_packed_f252_ops() {
let mut rng = SmallRng::seed_from_u64(0u64);
let mut rand_f252 = || -> Felt252 {
Felt252 {
limbs: [rng.gen(), rng.gen(), rng.gen(), 0],
}
};
let mut rand_packed_f252 = || -> PackedFelt252 {
PackedFelt252::from_array(&std::array::from_fn(|_| rand_f252()))
};
let a = rand_packed_f252();
let b = rand_packed_f252();
let unpacked_a = a.to_array();
let unpacked_b = b.to_array();
let expected_add = std::array::from_fn(|i| unpacked_a[i] + unpacked_b[i]);
let expected_sub = std::array::from_fn(|i| unpacked_a[i] - unpacked_b[i]);
let expected_mul = std::array::from_fn(|i| unpacked_a[i] * unpacked_b[i]);

let add = a + b;
let sub = a - b;
let mul = a * b;

assert_eq!(add.to_array(), expected_add);
assert_eq!(sub.to_array(), expected_sub);
assert_eq!(mul.to_array(), expected_mul);
}
}

0 comments on commit 32c2f42

Please sign in to comment.