From 32c2f420cdaea57ad7dda6fa371a4cc7d07a47c0 Mon Sep 17 00:00:00 2001
From: Ohad Agadi
Date: Wed, 13 Nov 2024 10:01:24 +0200
Subject: [PATCH] packed types for generic opcode

---
 stwo_cairo_prover/Cargo.lock                  |   2 +
 .../crates/prover_types/Cargo.toml            |   4 +
 .../crates/prover_types/src/cpu.rs            |   3 +-
 .../crates/prover_types/src/simd.rs           | 122 +++++++++++++++++-
 4 files changed, 126 insertions(+), 5 deletions(-)

diff --git a/stwo_cairo_prover/Cargo.lock b/stwo_cairo_prover/Cargo.lock
index 4140f6e1..9c4653c9 100644
--- a/stwo_cairo_prover/Cargo.lock
+++ b/stwo_cairo_prover/Cargo.lock
@@ -1186,6 +1186,8 @@ version = "0.1.0"
 dependencies = [
  "bytemuck",
  "itertools 0.12.1",
+ "num-traits",
+ "rand",
  "ruint",
  "serde",
  "starknet-ff",
diff --git a/stwo_cairo_prover/crates/prover_types/Cargo.toml b/stwo_cairo_prover/crates/prover_types/Cargo.toml
index 53409ad7..5caf0459 100644
--- a/stwo_cairo_prover/crates/prover_types/Cargo.toml
+++ b/stwo_cairo_prover/crates/prover_types/Cargo.toml
@@ -10,3 +10,7 @@ ruint.workspace = true
 serde.workspace = true
 itertools.workspace = true
 bytemuck.workspace = true
+num-traits.workspace = true
+
+[dev-dependencies]
+rand.workspace = true
diff --git a/stwo_cairo_prover/crates/prover_types/src/cpu.rs b/stwo_cairo_prover/crates/prover_types/src/cpu.rs
index 7f6b8443..3c1e0efe 100644
--- a/stwo_cairo_prover/crates/prover_types/src/cpu.rs
+++ b/stwo_cairo_prover/crates/prover_types/src/cpu.rs
@@ -433,8 +433,7 @@ impl Felt252 {
         M31::from_u32_unchecked(value)
     }
 
-    pub fn from_limbs(felts: Vec<M31>) -> Self {
-        assert!(felts.len() <= FELT252_N_WORDS, "Invalid number of felts");
+    pub fn from_limbs(felts: &[M31; FELT252_N_WORDS]) -> Self {
         let mut limbs = [0u64; 4];
         for (index, felt) in felts.iter().enumerate() {
             let shift = FELT252_BITS_PER_WORD * index;
diff --git a/stwo_cairo_prover/crates/prover_types/src/simd.rs b/stwo_cairo_prover/crates/prover_types/src/simd.rs
index 89cb1d43..4d12a104 100644
--- a/stwo_cairo_prover/crates/prover_types/src/simd.rs
+++ b/stwo_cairo_prover/crates/prover_types/src/simd.rs
@@ -1,5 +1,5 @@
 use std::mem::transmute;
-use std::ops::{Add, BitAnd, BitOr, BitXor, Rem, Shl, Shr};
+use std::ops::{Add, BitAnd, BitOr, BitXor, Mul, Rem, Shl, Shr, Sub};
 use std::simd::num::SimdUint;
 use std::simd::Simd;
 
@@ -10,7 +10,7 @@ use stwo_prover::core::backend::simd::m31::PackedM31;
 use stwo_prover::core::fields::m31;
 
 use super::cpu::{UInt16, UInt32, UInt64, PRIME};
-use crate::cpu::CasmState;
+use crate::cpu::{CasmState, Felt252};
 
 pub const LOG_N_LANES: u32 = 4;
 
@@ -171,6 +171,24 @@ impl PackedUInt32 {
     pub fn in_m31_range(&self) -> bool {
         all(self.as_array(), |v| v.value < m31::P)
     }
+
+    pub fn from_m31(val: PackedM31) -> Self {
+        Self {
+            simd: val.into_simd(),
+        }
+    }
+
+    pub fn low(&self) -> PackedUInt16 {
+        PackedUInt16 {
+            value: (self.simd & Simd::splat(0xFFFF)).cast(),
+        }
+    }
+
+    pub fn high(&self) -> PackedUInt16 {
+        PackedUInt16 {
+            value: (self.simd >> 16).cast(),
+        }
+    }
 }
 
 impl Rem for PackedUInt32 {
@@ -337,7 +355,8 @@ impl BitXor for PackedUInt64 {
 
 pub const N_M31_IN_FELT252: usize = 28;
 
-// TODO(Ohad): implement ops and change to non-redundant representation.
+use num_traits::identities::Zero;
+// TODO(Ohad): Change to non-redundant representation.
 #[derive(Copy, Clone, Debug)]
 pub struct PackedFelt252 {
     pub value: [PackedM31; N_M31_IN_FELT252],
@@ -346,6 +365,67 @@ impl PackedFelt252 {
     pub fn get_m31(&self, index: usize) -> PackedM31 {
         self.value[index]
     }
+
+    pub fn from_limbs(limbs: [PackedM31; N_M31_IN_FELT252]) -> Self {
+        Self { value: limbs }
+    }
+
+    pub fn from_m31(val: PackedM31) -> Self {
+        Self {
+            value: std::array::from_fn(|i| if i == 0 { val } else { PackedM31::zero() }),
+        }
+    }
+
+    pub fn to_array(&self) -> [Felt252; N_LANES] {
+        let unpacked_limbs = self.value.unpack();
+        unpacked_limbs.map(|limbs| Felt252::from_limbs(&limbs))
+    }
+
+    pub fn from_array(arr: &[Felt252; N_LANES]) -> Self {
+        let limbs = arr.map(|felt| std::array::from_fn(|i| felt.get_m31(i)));
+        let limbs = <_ as Pack>::pack(limbs);
+        Self::from_limbs(limbs)
+    }
+}
+
+// TODO(Ohad): These are very slow, optimize.
+impl Add for PackedFelt252 {
+    type Output = Self;
+
+    fn add(self, rhs: Self) -> Self::Output {
+        let lhs = self.to_array();
+        let rhs = rhs.to_array();
+        let result = std::array::from_fn(|i| lhs[i] + rhs[i]);
+        Self::from_array(&result)
+    }
+}
+
+impl Sub for PackedFelt252 {
+    type Output = Self;
+
+    fn sub(self, rhs: Self) -> Self::Output {
+        let lhs = self.to_array();
+        let rhs = rhs.to_array();
+        let result = std::array::from_fn(|i| lhs[i] - rhs[i]);
+        Self::from_array(&result)
+    }
+}
+
+impl Mul for PackedFelt252 {
+    type Output = Self;
+
+    fn mul(self, rhs: Self) -> Self::Output {
+        let lhs = self.to_array();
+        let rhs = rhs.to_array();
+        let result = std::array::from_fn(|i| lhs[i] * rhs[i]);
+        Self::from_array(&result)
+    }
+}
+
+impl PartialEq for PackedFelt252 {
+    fn eq(&self, other: &Self) -> bool {
+        self.to_array() == other.to_array()
+    }
 }
 
 impl AsRef<[PackedM31; N_M31_IN_FELT252]> for PackedFelt252 {
@@ -397,3 +477,39 @@ impl Unpack for PackedCasmState {
         })
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use rand::rngs::SmallRng;
+    use rand::{Rng, SeedableRng};
+
+    use super::*;
+
+    #[test]
+    fn test_packed_f252_ops() {
+        let mut rng = SmallRng::seed_from_u64(0u64);
+        let mut rand_f252 = || -> Felt252 {
+            Felt252 {
+                limbs: [rng.gen(), rng.gen(), rng.gen(), 0],
+            }
+        };
+        let mut rand_packed_f252 = || -> PackedFelt252 {
+            PackedFelt252::from_array(&std::array::from_fn(|_| rand_f252()))
+        };
+        let a = rand_packed_f252();
+        let b = rand_packed_f252();
+        let unpacked_a = a.to_array();
+        let unpacked_b = b.to_array();
+        let expected_add = std::array::from_fn(|i| unpacked_a[i] + unpacked_b[i]);
+        let expected_sub = std::array::from_fn(|i| unpacked_a[i] - unpacked_b[i]);
+        let expected_mul = std::array::from_fn(|i| unpacked_a[i] * unpacked_b[i]);
+
+        let add = a + b;
+        let sub = a - b;
+        let mul = a * b;
+
+        assert_eq!(add.to_array(), expected_add);
+        assert_eq!(sub.to_array(), expected_sub);
+        assert_eq!(mul.to_array(), expected_mul);
+    }
+}