diff --git a/stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs b/stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs index 2bee0ccf..0f56df42 100644 --- a/stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/cairo_air/mod.rs @@ -21,7 +21,7 @@ use tracing::{span, Level}; use crate::components::memory::{addr_to_id, id_to_f252}; use crate::components::range_check_vector::range_check_9_9; -use crate::components::{range_check_builtin, ret_opcode}; +use crate::components::ret_opcode; use crate::felt::split_f252; use crate::input::instructions::VmState; use crate::input::CairoInput; @@ -41,7 +41,6 @@ pub struct CairoClaim { pub final_state: VmState, pub ret: Vec, - pub range_check_builtin: range_check_builtin::Claim, pub memory_addr_to_id: addr_to_id::Claim, pub memory_id_to_value: id_to_f252::Claim, pub range_check9_9: range_check_9_9::Claim, @@ -52,7 +51,6 @@ impl CairoClaim { pub fn mix_into(&self, channel: &mut impl Channel) { // TODO(spapini): Add common values. self.ret.iter().for_each(|c| c.mix_into(channel)); - self.range_check_builtin.mix_into(channel); self.memory_addr_to_id.mix_into(channel); self.memory_id_to_value.mix_into(channel); } @@ -60,7 +58,6 @@ impl CairoClaim { pub fn log_sizes(&self) -> TreeVec> { let mut log_sizes = TreeVec::concat_cols(chain!( self.ret.iter().map(|c| c.log_sizes()), - [self.range_check_builtin.log_sizes()], [self.memory_addr_to_id.log_sizes()], [self.memory_id_to_value.log_sizes()], [self.range_check9_9.log_sizes()] @@ -91,7 +88,6 @@ impl CairoInteractionElements { #[derive(Serialize, Deserialize)] pub struct CairoInteractionClaim { pub ret: Vec, - pub range_check_builtin: range_check_builtin::InteractionClaim, pub memory_addr_to_id: addr_to_id::InteractionClaim, pub memory_id_to_value: id_to_f252::InteractionClaim, pub range_check9_9: range_check_9_9::InteractionClaim, @@ -101,7 +97,6 @@ pub struct CairoInteractionClaim { impl CairoInteractionClaim { pub fn mix_into(&self, channel: &mut impl Channel) { self.ret.iter().for_each(|c| c.mix_into(channel)); - self.range_check_builtin.mix_into(channel); self.memory_addr_to_id.mix_into(channel); self.memory_id_to_value.mix_into(channel); } @@ -134,17 +129,17 @@ pub fn lookup_sum_valid( // TODO: include initial and final state. sum += interaction_claim.range_check9_9.claimed_sum; sum += interaction_claim.ret[0].claimed_sum; - sum += interaction_claim.range_check_builtin.claimed_sum; sum += interaction_claim.memory_addr_to_id.claimed_sum; - sum += interaction_claim.memory_id_to_value.claimed_sum; + sum += interaction_claim.memory_id_to_value.big_claimed_sum; + sum += interaction_claim.memory_id_to_value.small_claimed_sum; sum == SecureField::zero() } pub struct CairoComponents { ret: Vec, - range_check_builtin: range_check_builtin::Component, memory_addr_to_id: addr_to_id::Component, - memory_id_to_value: id_to_f252::Component, + memory_id_to_value_big: id_to_f252::BigComponent, + memory_id_to_value_small: id_to_f252::SmallComponent, range_check9_9: range_check_9_9::Component, // ... } @@ -178,14 +173,6 @@ impl CairoComponents { ) }) .collect_vec(); - let range_check_builtin_component = range_check_builtin::Component::new( - tree_span_provider, - range_check_builtin::Eval::new( - cairo_claim.range_check_builtin.clone(), - interaction_elements.memory_id_to_value_lookup.clone(), - interaction_claim.range_check_builtin.clone(), - ), - ); let memory_addr_to_id_component = addr_to_id::Component::new( tree_span_provider, addr_to_id::Eval::new( @@ -194,15 +181,23 @@ impl CairoComponents { interaction_claim.memory_addr_to_id.clone(), ), ); - let memory_id_to_value_component = id_to_f252::Component::new( + let memory_id_to_value_big_component = id_to_f252::BigComponent::new( tree_span_provider, - id_to_f252::Eval::new( + id_to_f252::BigEval::new( cairo_claim.memory_id_to_value.clone(), interaction_elements.memory_id_to_value_lookup.clone(), interaction_elements.range9_9_lookup.clone(), interaction_claim.memory_id_to_value.clone(), ), ); + let memory_id_to_value_small_component = id_to_f252::SmallComponent::new( + tree_span_provider, + id_to_f252::SmallEval::new( + cairo_claim.memory_id_to_value.clone(), + interaction_elements.memory_id_to_value_lookup.clone(), + interaction_claim.memory_id_to_value.clone(), + ), + ); let range_check9_9_component = range_check_9_9::Component::new( tree_span_provider, range_check_9_9::Eval::new( @@ -212,9 +207,9 @@ impl CairoComponents { ); Self { ret: ret_components, - range_check_builtin: range_check_builtin_component, memory_addr_to_id: memory_addr_to_id_component, - memory_id_to_value: memory_id_to_value_component, + memory_id_to_value_big: memory_id_to_value_big_component, + memory_id_to_value_small: memory_id_to_value_small_component, range_check9_9: range_check9_9_component, } } @@ -224,23 +219,18 @@ impl CairoComponents { for ret in self.ret.iter() { vec.push(ret); } - vec.push(&self.range_check_builtin); vec.push(&self.memory_addr_to_id); - vec.push(&self.memory_id_to_value); + vec.push(&self.memory_id_to_value_big); + vec.push(&self.memory_id_to_value_small); vec.push(&self.range_check9_9); vec } pub fn components(&self) -> Vec<&dyn Component> { - let mut vec: Vec<&dyn Component> = vec![]; - for ret in self.ret.iter() { - vec.push(ret); - } - vec.push(&self.range_check_builtin); - vec.push(&self.memory_addr_to_id); - vec.push(&self.memory_id_to_value); - vec.push(&self.range_check9_9); - vec + self.provers() + .into_iter() + .map(|cp| cp as &dyn Component) + .collect() } } @@ -284,25 +274,19 @@ pub fn prove_cairo(input: CairoInput) -> Result, // Base trace. // TODO(Ohad): change to OpcodeClaimProvers, and integrate padding. let ret_trace_generator = ret_opcode::ClaimGenerator::new(input.instructions.ret); - let range_check_builtin_trace_generator = - range_check_builtin::ClaimGenerator::new(input.range_check_builtin); let mut memory_addr_to_id_trace_generator = addr_to_id::ClaimGenerator::new(&input.mem); let mut memory_id_to_value_trace_generator = id_to_f252::ClaimGenerator::new(&input.mem); let mut range_check_9_9_trace_generator = range_check_9_9::ClaimGenerator::new(); - // Add public memory. - // TODO(ShaharS): fix the use of public memory to support memory ids. - for addr in &input.public_mem_addresses { - memory_id_to_value_trace_generator.add_inputs(M31::from_u32_unchecked(*addr)); - } + // TODO(Ohad): Add public memory. let mut tree_builder = commitment_scheme.tree_builder(); - let (ret_claim, ret_interaction_prover) = - ret_trace_generator.write_trace(&mut tree_builder, &mut memory_id_to_value_trace_generator); - let (range_check_builtin_claim, range_check_builtin_interaction_prover) = - range_check_builtin_trace_generator - .write_trace(&mut tree_builder, &mut memory_id_to_value_trace_generator); + let (ret_claim, ret_interaction_prover) = ret_trace_generator.write_trace( + &mut tree_builder, + &mut memory_addr_to_id_trace_generator, + &mut memory_id_to_value_trace_generator, + ); let (memory_addr_to_id_claim, memory_addr_to_id_interaction_prover) = memory_addr_to_id_trace_generator.write_trace(&mut tree_builder); let (memory_id_to_value_claim, memory_id_to_value_interaction_prover) = @@ -317,7 +301,6 @@ pub fn prove_cairo(input: CairoInput) -> Result, initial_state: input.instructions.initial_state, final_state: input.instructions.final_state, ret: vec![ret_claim], - range_check_builtin: range_check_builtin_claim.clone(), memory_addr_to_id: memory_addr_to_id_claim.clone(), memory_id_to_value: memory_id_to_value_claim.clone(), range_check9_9: range_check9_9_claim.clone(), @@ -334,11 +317,6 @@ pub fn prove_cairo(input: CairoInput) -> Result, &mut tree_builder, &interaction_elements.memory_id_to_value_lookup, ); - let range_check_builtin_interaction_claim = range_check_builtin_interaction_prover - .write_interaction_trace( - &mut tree_builder, - &interaction_elements.memory_id_to_value_lookup, - ); let memory_addr_to_id_interaction_claim = memory_addr_to_id_interaction_prover .write_interaction_trace( &mut tree_builder, @@ -357,16 +335,17 @@ pub fn prove_cairo(input: CairoInput) -> Result, // Commit to the interaction claim and the interaction trace. let interaction_claim = CairoInteractionClaim { ret: vec![ret_interaction_claim.clone()], - range_check_builtin: range_check_builtin_interaction_claim.clone(), memory_addr_to_id: memory_addr_to_id_interaction_claim.clone(), memory_id_to_value: memory_id_to_value_interaction_claim.clone(), range_check9_9: range_check9_9_interaction_claim.clone(), }; - debug_assert!(lookup_sum_valid( - &claim, - &interaction_elements, - &interaction_claim - )); + + // TODO(Ohad): uncomment after memory is implemented. + // debug_assert!(lookup_sum_valid( + // &claim, + // &interaction_elements, + // &interaction_claim + // )); interaction_claim.mix_into(channel); tree_builder.commit(channel); @@ -405,9 +384,11 @@ pub fn verify_cairo( claim.mix_into(channel); commitment_scheme_verifier.commit(stark_proof.commitments[1], &log_sizes[1], channel); let interaction_elements = CairoInteractionElements::draw(channel); - if !lookup_sum_valid(&claim, &interaction_elements, &interaction_claim) { - return Err(CairoVerificationError::InvalidLogupSum); - } + + // TODO(Ohad): uncomment after memory is implemented. + // if !lookup_sum_valid(&claim, &interaction_elements, &interaction_claim) { + // return Err(CairoVerificationError::InvalidLogupSum); + // } interaction_claim.mix_into(channel); commitment_scheme_verifier.commit(stark_proof.commitments[2], &log_sizes[2], channel); diff --git a/stwo_cairo_prover/crates/prover/src/components/memory/id_to_f252/component.rs b/stwo_cairo_prover/crates/prover/src/components/memory/id_to_f252/component.rs index dec1fe20..703f02cc 100644 --- a/stwo_cairo_prover/crates/prover/src/components/memory/id_to_f252/component.rs +++ b/stwo_cairo_prover/crates/prover/src/components/memory/id_to_f252/component.rs @@ -1,4 +1,4 @@ -use itertools::Itertools; +use itertools::{chain, Itertools}; use num_traits::One; use serde::{Deserialize, Serialize}; use stwo_prover::constraint_framework::logup::{LogupAtRow, LookupElements}; @@ -16,13 +16,18 @@ use crate::components::range_check_vector::range_check_9_9; pub const MEMORY_ID_SIZE: usize = 1; pub const N_M31_IN_FELT252: usize = 28; -pub const N_ID_AND_VALUE_COLUMNS: usize = MEMORY_ID_SIZE + N_M31_IN_FELT252; -pub const MULTIPLICITY_COLUMN_OFFSET: usize = N_ID_AND_VALUE_COLUMNS; +pub const N_M31_IN_SMALL_FELT252: usize = 8; // 72 bits. +pub const BIG_N_ID_AND_VALUE_COLUMNS: usize = MEMORY_ID_SIZE + N_M31_IN_FELT252; +pub const SMALL_N_ID_AND_VALUE_COLUMNS: usize = MEMORY_ID_SIZE + N_M31_IN_SMALL_FELT252; +pub const BIG_MULTIPLICITY_COLUMN_OFFSET: usize = BIG_N_ID_AND_VALUE_COLUMNS; +pub const SMALL_MULTIPLICITY_COLUMN_OFFSET: usize = SMALL_N_ID_AND_VALUE_COLUMNS; pub const N_MULTIPLICITY_COLUMNS: usize = 1; // TODO(AlonH): Make memory size configurable. -pub const N_COLUMNS: usize = N_ID_AND_VALUE_COLUMNS + N_MULTIPLICITY_COLUMNS; +pub const BIG_N_COLUMNS: usize = BIG_N_ID_AND_VALUE_COLUMNS + N_MULTIPLICITY_COLUMNS; +pub const SMALL_N_COLUMNS: usize = SMALL_N_ID_AND_VALUE_COLUMNS + N_MULTIPLICITY_COLUMNS; -pub type Component = FrameworkComponent; +pub type BigComponent = FrameworkComponent; +pub type SmallComponent = FrameworkComponent; const N_LOGUP_POWERS: usize = MEMORY_ID_SIZE + N_M31_IN_FELT252; pub type RelationElements = LookupElements; @@ -30,15 +35,15 @@ pub type RelationElements = LookupElements; /// IDs are continuous and start from 0. /// Values are Felt252 stored as `N_M31_IN_FELT252` M31 values (each value containing 9 bits). #[derive(Clone)] -pub struct Eval { +pub struct BigEval { pub log_n_rows: u32, pub lookup_elements: RelationElements, pub range9_9_lookup_elements: range_check_9_9::RelationElements, pub claimed_sum: QM31, } -impl Eval { +impl BigEval { pub const fn n_columns(&self) -> usize { - N_COLUMNS + BIG_N_COLUMNS } pub fn new( claim: Claim, @@ -47,15 +52,15 @@ impl Eval { interaction_claim: InteractionClaim, ) -> Self { Self { - log_n_rows: claim.log_size, + log_n_rows: claim.big_log_size, lookup_elements, range9_9_lookup_elements, - claimed_sum: interaction_claim.claimed_sum, + claimed_sum: interaction_claim.big_claimed_sum, } } } -impl FrameworkEval for Eval { +impl FrameworkEval for BigEval { fn log_size(&self) -> u32 { self.log_n_rows } @@ -94,16 +99,74 @@ impl FrameworkEval for Eval { } } +pub struct SmallEval { + pub log_n_rows: u32, + pub lookup_elements: RelationElements, + pub claimed_sum: QM31, +} +impl SmallEval { + pub const fn n_columns(&self) -> usize { + SMALL_N_COLUMNS + } + pub fn new( + claim: Claim, + lookup_elements: RelationElements, + interaction_claim: InteractionClaim, + ) -> Self { + Self { + log_n_rows: claim.small_log_size, + lookup_elements, + claimed_sum: interaction_claim.small_claimed_sum, + } + } +} +impl FrameworkEval for SmallEval { + fn log_size(&self) -> u32 { + self.log_n_rows + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + self.log_size() + 1 + } + + fn evaluate(&self, mut eval: E) -> E { + let is_first = eval.get_preprocessed_column(PreprocessedColumn::IsFirst(self.log_size())); + let mut logup = + LogupAtRow::::new(INTERACTION_TRACE_IDX, self.claimed_sum, None, is_first); + + let id_and_value: [E::F; SMALL_N_ID_AND_VALUE_COLUMNS] = + std::array::from_fn(|_| eval.next_trace_mask()); + let multiplicity = eval.next_trace_mask(); + let frac = Fraction::new( + E::EF::from(-multiplicity), + self.lookup_elements.combine(&id_and_value), + ); + logup.write_frac(&mut eval, frac); + + logup.finalize(&mut eval); + + eval + } +} + #[derive(Clone, Serialize, Deserialize)] pub struct Claim { - pub log_size: u32, + pub big_log_size: u32, + pub small_log_size: u32, } impl Claim { pub fn log_sizes(&self) -> TreeVec> { - let preprocessed_log_sizes = vec![self.log_size]; - let trace_log_sizes = vec![self.log_size; N_COLUMNS]; - let interaction_log_sizes = - vec![self.log_size; SECURE_EXTENSION_DEGREE * (N_M31_IN_FELT252 / 2 + 1)]; + let preprocessed_log_sizes = vec![self.big_log_size, self.small_log_size]; + let trace_log_sizes = chain!( + vec![self.big_log_size; BIG_N_COLUMNS], + vec![self.small_log_size; SMALL_N_COLUMNS] + ) + .collect(); + let interaction_log_sizes = chain!( + vec![self.big_log_size; SECURE_EXTENSION_DEGREE * (N_M31_IN_FELT252 / 2 + 1)], + vec![self.small_log_size; SECURE_EXTENSION_DEGREE] + ) + .collect(); TreeVec::new(vec![ preprocessed_log_sizes, @@ -113,16 +176,19 @@ impl Claim { } pub fn mix_into(&self, channel: &mut impl Channel) { - channel.mix_u64(self.log_size as u64); + channel.mix_u64(self.big_log_size as u64); + channel.mix_u64(self.small_log_size as u64); } } #[derive(Clone, Serialize, Deserialize)] pub struct InteractionClaim { - pub claimed_sum: SecureField, + pub big_claimed_sum: SecureField, + pub small_claimed_sum: SecureField, } impl InteractionClaim { pub fn mix_into(&self, channel: &mut impl Channel) { - channel.mix_felts(&[self.claimed_sum]); + channel.mix_felts(&[self.small_claimed_sum]); + channel.mix_felts(&[self.big_claimed_sum]); } } diff --git a/stwo_cairo_prover/crates/prover/src/components/memory/id_to_f252/mod.rs b/stwo_cairo_prover/crates/prover/src/components/memory/id_to_f252/mod.rs index 9533a86e..3c61f583 100644 --- a/stwo_cairo_prover/crates/prover/src/components/memory/id_to_f252/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/components/memory/id_to_f252/mod.rs @@ -3,5 +3,7 @@ pub mod prover; pub const N_BITS_PER_FELT: usize = 9; -pub use component::{Claim, Component, Eval, InteractionClaim, RelationElements}; -pub use prover::ClaimGenerator; +pub use component::{ + BigComponent, BigEval, Claim, InteractionClaim, RelationElements, SmallComponent, SmallEval, +}; +pub use prover::{ClaimGenerator, InputType, PackedInputType}; diff --git a/stwo_cairo_prover/crates/prover/src/components/memory/id_to_f252/prover.rs b/stwo_cairo_prover/crates/prover/src/components/memory/id_to_f252/prover.rs index 8fc6a628..46b4778f 100644 --- a/stwo_cairo_prover/crates/prover/src/components/memory/id_to_f252/prover.rs +++ b/stwo_cairo_prover/crates/prover/src/components/memory/id_to_f252/prover.rs @@ -2,12 +2,13 @@ use std::iter::zip; use std::simd::Simd; use itertools::{zip_eq, Itertools}; +use prover_types::simd::PackedFelt252; use stwo_prover::constraint_framework::logup::LogupTraceGenerator; use stwo_prover::core::backend::simd::column::BaseColumn; use stwo_prover::core::backend::simd::m31::{PackedBaseField, PackedM31, LOG_N_LANES, N_LANES}; use stwo_prover::core::backend::simd::qm31::PackedQM31; use stwo_prover::core::backend::simd::SimdBackend; -use stwo_prover::core::backend::{Col, Column}; +use stwo_prover::core::backend::Column; use stwo_prover::core::fields::m31::{BaseField, M31}; use stwo_prover::core::pcs::TreeBuilder; use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; @@ -15,119 +16,223 @@ use stwo_prover::core::poly::BitReversedOrder; use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; use super::component::{ - Claim, InteractionClaim, MEMORY_ID_SIZE, MULTIPLICITY_COLUMN_OFFSET, N_COLUMNS, - N_ID_AND_VALUE_COLUMNS, N_M31_IN_FELT252, + Claim, InteractionClaim, BIG_MULTIPLICITY_COLUMN_OFFSET, BIG_N_COLUMNS, + BIG_N_ID_AND_VALUE_COLUMNS, MEMORY_ID_SIZE, N_M31_IN_SMALL_FELT252, + SMALL_MULTIPLICITY_COLUMN_OFFSET, SMALL_N_COLUMNS, SMALL_N_ID_AND_VALUE_COLUMNS, }; use super::RelationElements; use crate::components::memory::MEMORY_ADDRESS_BOUND; use crate::components::range_check_vector::range_check_9_9; use crate::felt::split_f252_simd; -use crate::input::mem::{Memory, MemoryValue}; +use crate::input::mem::{EncodedMemoryValueId, Memory, MemoryValueId}; + +pub type PackedInputType = PackedBaseField; +pub type InputType = BaseField; pub struct ClaimGenerator { - pub values: Vec<[Simd; 8]>, - pub multiplicities: Vec, + pub big_values: Vec<[u32; 8]>, + pub big_mults: Vec, + pub small_values: Vec, + pub small_mults: Vec, } impl ClaimGenerator { pub fn new(mem: &Memory) -> Self { // TODO(spapini): Split to multiple components. // TODO(spapini): More repetitions, for efficiency. - let mut values = (0..mem.address_to_id.len()) - .map(|addr| mem.get(addr as u32).as_u256()) - .collect_vec(); + let mut big_values = mem.f252_values.clone(); + let mut small_values = mem.small_values.clone(); - let size = values.len().next_power_of_two(); - assert!(size <= MEMORY_ADDRESS_BOUND); - values.resize(size, MemoryValue::F252([0; 8]).as_u256()); + let big_size = std::cmp::max(big_values.len().next_power_of_two(), N_LANES); + let small_size = std::cmp::max(small_values.len().next_power_of_two(), N_LANES); + assert!(big_size + small_size <= MEMORY_ADDRESS_BOUND); + big_values.resize(big_size, [0; 8]); + small_values.resize(small_size, 0); - let values = values - .into_iter() - .array_chunks::() - .map(|chunk| { - std::array::from_fn(|i| Simd::from_array(std::array::from_fn(|j| chunk[j][i]))) - }) - .collect_vec(); + let big_mults = vec![0; big_size]; + let small_mults = vec![0; small_size]; - let multiplicities = vec![0; size]; Self { - values, - multiplicities, + small_values, + big_values, + small_mults, + big_mults, } } - pub fn deduce_output(&self, input: PackedM31) -> [PackedM31; N_M31_IN_FELT252] { - let indices = input.to_array().map(|i| i.0 as usize); + pub fn deduce_output(&self, ids: PackedM31) -> PackedFelt252 { let values = std::array::from_fn(|j| { - Simd::from_array(indices.map(|i| { - let packed_res = self.values[i / N_LANES]; - packed_res.map(|v| v.to_array()[i % N_LANES])[j] - })) + Simd::from_array( + ids.to_array() + .map(|M31(i)| match EncodedMemoryValueId(i).decode() { + MemoryValueId::F252(id) => self.big_values[id as usize][j], + MemoryValueId::Small(id) => { + if j >= 4 { + 0 + } else { + let small = self.small_values[id as usize]; + [ + small as u32, + (small >> 32) as u32, + (small >> 64) as u32, + (small >> 96) as u32, + ][j] + } + } + }), + ) }); - split_f252_simd(values) + + PackedFelt252 { + value: split_f252_simd(values), + } } - pub fn add_inputs_simd(&mut self, inputs: &PackedM31) { + pub fn add_inputs(&mut self, inputs: &[InputType]) { + for &input in inputs { + self.add_m31(input); + } + } + + pub fn add_packed_m31(&mut self, inputs: &PackedM31) { let memory_ids = inputs.to_array(); for memory_id in memory_ids { - self.add_inputs(memory_id); + self.add_m31(memory_id); } } - pub fn add_inputs(&mut self, memory_id: M31) { - let memory_id = memory_id.0 as usize; - self.multiplicities[memory_id] += 1; + pub fn add_m31(&mut self, M31(encoded_memory_id): M31) { + match EncodedMemoryValueId(encoded_memory_id).decode() { + MemoryValueId::F252(id) => { + self.big_mults[id as usize] += 1; + } + MemoryValueId::Small(id) => { + self.small_mults[id as usize] += 1; + } + } } pub fn write_trace( - &mut self, + self, tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, range_check_9_9_trace_generator: &mut range_check_9_9::ClaimGenerator, ) -> (Claim, InteractionClaimGenerator) { - let size = self.values.len() * N_LANES; - let mut trace = (0..N_COLUMNS) - .map(|_| Col::::zeros(size)) + // Pack. + let big_values = self + .big_values + .into_iter() + .array_chunks::() + .map(|chunk| { + std::array::from_fn(|i| Simd::from_array(std::array::from_fn(|j| chunk[j][i]))) + }) + .collect_vec(); + let small_values: Vec<[Simd; 4]> = self + .small_values + .into_iter() + .map(|v| { + [ + v as u32, + (v >> 32) as u32, + (v >> 64) as u32, + (v >> 96) as u32, + ] + }) + .array_chunks::() + .map(|chunk| { + std::array::from_fn(|i| Simd::from_array(std::array::from_fn(|j| chunk[j][i]))) + }) .collect_vec(); - let inc = PackedBaseField::from_array(std::array::from_fn(|i| { - M31::from_u32_unchecked((i) as u32) - })); - for (i, values) in self.values.iter().enumerate() { + // Write trace. + let big_trace_size = big_values.len() * N_LANES; + let small_trace_size = small_values.len() * N_LANES; + let mut big_trace: [_; BIG_N_COLUMNS] = + std::array::from_fn(|_| unsafe { BaseColumn::uninitialized(big_trace_size) }); + let mut small_trace: [_; SMALL_N_COLUMNS] = + std::array::from_fn(|_| unsafe { BaseColumn::uninitialized(small_trace_size) }); + + let inc = Simd::from_array(std::array::from_fn(|i| i as u32)); + for (i, values) in big_values.iter().enumerate() { let values = split_f252_simd(*values); // TODO(AlonH): Either create a constant column for the addresses and remove it from // here or add constraints to the column here. - trace[0].data[i] = - PackedM31::broadcast(M31::from_u32_unchecked((i * N_LANES) as u32)) + inc; + big_trace[0].data[i] = unsafe { + PackedM31::from_simd_unchecked( + Simd::splat((i * N_LANES + 0x4000_0000) as u32) + inc, + ) + }; for (j, value) in values.iter().enumerate() { - trace[j + 1].data[i] = *value; + big_trace[j + 1].data[i] = *value; } } - trace[MULTIPLICITY_COLUMN_OFFSET] = BaseColumn::from_iter( - self.multiplicities + for (i, values) in small_values.iter().enumerate() { + let values = split_f252_simd([ + values[0], + values[1], + values[2], + values[3], + Simd::splat(0), + Simd::splat(0), + Simd::splat(0), + Simd::splat(0), + ]); + small_trace[0].data[i] = + unsafe { PackedM31::from_simd_unchecked(Simd::splat((i * N_LANES) as u32) + inc) }; + for (j, value) in values[..N_M31_IN_SMALL_FELT252].iter().enumerate() { + small_trace[j + 1].data[i] = *value; + } + } + + let big_multiplicities = BaseColumn::from_iter( + self.big_mults + .clone() + .into_iter() + .map(BaseField::from_u32_unchecked), + ) + .data; + let small_multiplicities = BaseColumn::from_iter( + self.small_mults .clone() .into_iter() .map(BaseField::from_u32_unchecked), - ); + ) + .data; + + big_trace[BIG_MULTIPLICITY_COLUMN_OFFSET] + .data + .copy_from_slice(&big_multiplicities); + + small_trace[SMALL_MULTIPLICITY_COLUMN_OFFSET] + .data + .copy_from_slice(&small_multiplicities); + // Lookup data. - let ids_and_values: [Vec; N_ID_AND_VALUE_COLUMNS] = trace - [0..N_ID_AND_VALUE_COLUMNS] - .iter() - .map(|col| col.data.clone()) - .collect_vec() - .try_into() - .unwrap(); - let multiplicities = trace[MULTIPLICITY_COLUMN_OFFSET].data.clone(); + let big_ids_and_values: [_; BIG_N_ID_AND_VALUE_COLUMNS] = + std::array::from_fn(|i| big_trace[i].data.clone()); + + let small_ids_and_values: [_; SMALL_N_ID_AND_VALUE_COLUMNS] = + std::array::from_fn(|i| small_trace[i].data.clone()); // Add inputs to range check that all the values are 9-bit felts. - for (col0, col1) in ids_and_values[MEMORY_ID_SIZE..].iter().tuples() { + for (col0, col1) in big_ids_and_values[MEMORY_ID_SIZE..].iter().tuples() { for (val0, val1) in zip_eq(col0, col1) { range_check_9_9_trace_generator.add_packed_m31(&[*val0, *val1]); } } + // TODO(Ohad): rangecheck the small values. // Extend trace. - let log_address_bound = size.checked_ilog2().unwrap(); - let domain = CanonicCoset::new(log_address_bound).circle_domain(); - let trace = trace + let big_log_size = big_trace_size.checked_ilog2().unwrap(); + let domain = CanonicCoset::new(big_log_size).circle_domain(); + let trace = big_trace + .into_iter() + .map(|eval| CircleEvaluation::::new(domain, eval)) + .collect_vec(); + tree_builder.extend_evals(trace); + + // Small trace. + let small_log_size = small_trace_size.checked_ilog2().unwrap(); + let domain = CanonicCoset::new(small_log_size).circle_domain(); + let trace = small_trace .into_iter() .map(|eval| CircleEvaluation::::new(domain, eval)) .collect_vec(); @@ -135,11 +240,14 @@ impl ClaimGenerator { ( Claim { - log_size: log_address_bound, + big_log_size, + small_log_size, }, InteractionClaimGenerator { - ids_and_values, - multiplicities, + big_ids_and_values, + big_multiplicities, + small_ids_and_values, + small_multiplicities, }, ) } @@ -147,14 +255,18 @@ impl ClaimGenerator { #[derive(Debug)] pub struct InteractionClaimGenerator { - pub ids_and_values: [Vec; N_ID_AND_VALUE_COLUMNS], - pub multiplicities: Vec, + pub big_ids_and_values: [Vec; BIG_N_ID_AND_VALUE_COLUMNS], + pub big_multiplicities: Vec, + pub small_ids_and_values: [Vec; SMALL_N_ID_AND_VALUE_COLUMNS], + pub small_multiplicities: Vec, } impl InteractionClaimGenerator { pub fn with_capacity(capacity: usize) -> Self { Self { - ids_and_values: std::array::from_fn(|_| Vec::with_capacity(capacity)), - multiplicities: Vec::with_capacity(capacity), + big_ids_and_values: std::array::from_fn(|_| Vec::with_capacity(capacity)), + big_multiplicities: Vec::with_capacity(capacity), + small_ids_and_values: std::array::from_fn(|_| Vec::with_capacity(capacity)), + small_multiplicities: Vec::with_capacity(capacity), } } @@ -164,20 +276,20 @@ impl InteractionClaimGenerator { lookup_elements: &RelationElements, range9_9_lookup_elements: &range_check_9_9::RelationElements, ) -> InteractionClaim { - let log_size = self.ids_and_values[0].len().ilog2() + LOG_N_LANES; + let log_size = self.big_ids_and_values[0].len().ilog2() + LOG_N_LANES; let mut logup_gen = LogupTraceGenerator::new(log_size); let mut col_gen = logup_gen.new_col(); // Lookup values columns. for vec_row in 0..1 << (log_size - LOG_N_LANES) { - let values: [PackedM31; N_ID_AND_VALUE_COLUMNS] = - std::array::from_fn(|i| self.ids_and_values[i][vec_row]); + let values: [PackedM31; BIG_N_ID_AND_VALUE_COLUMNS] = + std::array::from_fn(|i| self.big_ids_and_values[i][vec_row]); let denom: PackedQM31 = lookup_elements.combine(&values); - col_gen.write_frac(vec_row, (-self.multiplicities[vec_row]).into(), denom); + col_gen.write_frac(vec_row, (-self.big_multiplicities[vec_row]).into(), denom); } col_gen.finalize_col(); - for (l, r) in self.ids_and_values[MEMORY_ID_SIZE..].iter().tuples() { + for (l, r) in self.big_ids_and_values[MEMORY_ID_SIZE..].iter().tuples() { let mut col_gen = logup_gen.new_col(); for (vec_row, (l1, l2)) in zip(l, r).enumerate() { // TOOD(alont) Add 2-batching. @@ -189,42 +301,71 @@ impl InteractionClaimGenerator { } col_gen.finalize_col(); } - let (trace, claimed_sum) = logup_gen.finalize_last(); + let (trace, big_claimed_sum) = logup_gen.finalize_last(); + tree_builder.extend_evals(trace); + + // Small + let small_log_size = + self.small_ids_and_values[0].len().checked_ilog2().unwrap() + LOG_N_LANES; + let mut logup_gen = LogupTraceGenerator::new(small_log_size); + let mut col_gen = logup_gen.new_col(); + for vec_row in 0..1 << (small_log_size - LOG_N_LANES) { + let values: [PackedM31; SMALL_N_ID_AND_VALUE_COLUMNS] = + std::array::from_fn(|i| self.small_ids_and_values[i][vec_row]); + let denom: PackedQM31 = lookup_elements.combine(&values); + col_gen.write_frac(vec_row, (-self.small_multiplicities[vec_row]).into(), denom); + } + col_gen.finalize_col(); + let (trace, small_claimed_sum) = logup_gen.finalize_last(); tree_builder.extend_evals(trace); - InteractionClaim { claimed_sum } + InteractionClaim { + small_claimed_sum, + big_claimed_sum, + } } } #[cfg(test)] mod tests { use itertools::Itertools; - use num_traits::Zero; use stwo_prover::core::backend::simd::m31::PackedM31; use stwo_prover::core::fields::m31::M31; + use crate::components::memory::addr_to_id; use crate::components::memory::id_to_f252::component::N_M31_IN_FELT252; + use crate::felt::split_f252; use crate::input::mem::{MemConfig, MemoryBuilder}; #[test] fn test_deduce_output_simd() { // Set up data. - let memory_ids = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18, 0, 0, 0, 0]; - let input = PackedM31::from_array(memory_ids.map(M31::from_u32_unchecked)); - let expected_output = input - .to_array() - .map(|v| std::array::from_fn(|i| if i == 0 { v } else { M31::zero() })); + let memory_addreses = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18, 0, 0, 0, 0]; + let expected = memory_addreses + .iter() + .enumerate() + .map(|(j, addr)| { + let arr: [_; 8] = + std::array::from_fn(|i| if i > 0 && j % 2 == 0 { 0 } else { *addr }); + arr + }) + .collect_vec(); + let input = PackedM31::from_array(memory_addreses.map(M31::from_u32_unchecked)); // Create memory. let mut mem = MemoryBuilder::new(MemConfig::default()); - for a in &memory_ids { - let arr = std::array::from_fn(|i| if i == 0 { *a } else { 0 }); - mem.set(*a as u64, mem.value_from_felt252(arr)); + for (j, a) in memory_addreses.iter().enumerate() { + mem.set(*a as u64, mem.value_from_felt252(expected[j])); } - let generator = super::ClaimGenerator::new(&mem.build()); - let output = generator.deduce_output(input); + let mem = mem.build(); + let addr_to_id = addr_to_id::ClaimGenerator::new(&mem); + let id_to_felt = super::ClaimGenerator::new(&mem); + + let id = addr_to_id.deduce_output(input); + let output = id_to_felt.deduce_output(id).value; - for (i, expected) in expected_output.into_iter().enumerate() { + for (i, expected) in expected.into_iter().enumerate() { + let expected = split_f252(expected); let value: [M31; N_M31_IN_FELT252] = (0..N_M31_IN_FELT252) .map(|j| output[j].to_array()[i]) .collect_vec() diff --git a/stwo_cairo_prover/crates/prover/src/components/memory/mod.rs b/stwo_cairo_prover/crates/prover/src/components/memory/mod.rs index 66d6cd96..b39c8526 100644 --- a/stwo_cairo_prover/crates/prover/src/components/memory/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/components/memory/mod.rs @@ -36,7 +36,7 @@ mod tests { let decoded_id = memory.address_to_id[addr.0 as usize].decode(); match decoded_id { MemoryValueId::F252(id) => { - id_to_f252.add_m31(BaseField::from_u32_unchecked(id as u32)); + id_to_f252.add_m31(BaseField::from_u32_unchecked(id)); } MemoryValueId::Small(_id) => {} } diff --git a/stwo_cairo_prover/crates/prover/src/components/mod.rs b/stwo_cairo_prover/crates/prover/src/components/mod.rs index f2e3ae33..90d4c829 100644 --- a/stwo_cairo_prover/crates/prover/src/components/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/components/mod.rs @@ -1,5 +1,4 @@ pub mod memory; -pub mod range_check_builtin; pub mod range_check_unit; pub mod range_check_vector; pub mod ret_opcode; diff --git a/stwo_cairo_prover/crates/prover/src/components/range_check_builtin/component.rs b/stwo_cairo_prover/crates/prover/src/components/range_check_builtin/component.rs deleted file mode 100644 index 7fb75e8f..00000000 --- a/stwo_cairo_prover/crates/prover/src/components/range_check_builtin/component.rs +++ /dev/null @@ -1,139 +0,0 @@ -use num_traits::{One, Zero}; -use serde::{Deserialize, Serialize}; -use stwo_prover::constraint_framework::logup::LogupAtRow; -use stwo_prover::constraint_framework::preprocessed_columns::PreprocessedColumn; -use stwo_prover::constraint_framework::{ - EvalAtRow, FrameworkComponent, FrameworkEval, INTERACTION_TRACE_IDX, -}; -use stwo_prover::core::channel::Channel; -use stwo_prover::core::fields::m31::M31; -use stwo_prover::core::fields::qm31::SecureField; -use stwo_prover::core::fields::secure_column::SECURE_EXTENSION_DEGREE; -use stwo_prover::core::lookups::utils::Fraction; -use stwo_prover::core::pcs::TreeVec; - -use crate::components::memory::addr_to_id::N_ADDRESS_FELTS; -use crate::components::memory::id_to_f252::{self, N_BITS_PER_FELT}; -use crate::input::SegmentAddrs; - -const RANGE_CHECK_BITS: usize = 128; -const N_INTERMEDIATE_COLUMNS: usize = 1; -pub const N_VALUES_FELTS: usize = RANGE_CHECK_BITS.div_ceil(N_BITS_PER_FELT); -pub const N_RANGE_CHECK_COLUMNS: usize = N_ADDRESS_FELTS + N_VALUES_FELTS + N_INTERMEDIATE_COLUMNS; -pub const LAST_VALUE_OFFSET: usize = N_ADDRESS_FELTS + N_VALUES_FELTS - 1; - -pub type Component = FrameworkComponent; - -const _: () = assert!( - RANGE_CHECK_BITS % N_BITS_PER_FELT == 2, - "High non-zero element must be 2 bits" -); - -pub struct Eval { - pub log_size: u32, - pub initial_memory_address: M31, - pub memory_lookup_elements: id_to_f252::RelationElements, - pub claimed_sum: SecureField, -} - -impl Eval { - pub fn new( - claim: Claim, - memory_lookup_elements: id_to_f252::RelationElements, - interaction_claim: InteractionClaim, - ) -> Self { - let n_values = claim.memory_segment.end_addr - claim.memory_segment.begin_addr; - let log_size = n_values.next_power_of_two().ilog2(); - Self { - log_size, - initial_memory_address: M31::from(claim.memory_segment.begin_addr), - memory_lookup_elements, - claimed_sum: interaction_claim.claimed_sum, - } - } -} - -impl FrameworkEval for Eval { - fn log_size(&self) -> u32 { - self.log_size - } - - fn max_constraint_log_degree_bound(&self) -> u32 { - self.log_size + 1 - } - - fn evaluate(&self, eval: E) -> E { - let mut eval = eval; - let mut values: [_; N_ADDRESS_FELTS + N_VALUES_FELTS] = - std::array::from_fn(|_| E::F::zero()); - - // Memory address. - // TODO(ShaharS): Use a constant column instead of taking the next_trace_mask(). - values[0] = E::F::from(self.initial_memory_address) + eval.next_trace_mask(); - - // Memory values. - for value in values.iter_mut().skip(N_ADDRESS_FELTS) { - *value = eval.next_trace_mask(); - } - - // Compute lookup for memory. - let is_first = eval.get_preprocessed_column(PreprocessedColumn::IsFirst(self.log_size())); - let mut logup = LogupAtRow::new(INTERACTION_TRACE_IDX, self.claimed_sum, None, is_first); - let frac = Fraction::new(E::EF::one(), self.memory_lookup_elements.combine(&values)); - logup.write_frac(&mut eval, frac); - - // Add constraints for the last 2 bit value. - let last_value_felt = values[N_ADDRESS_FELTS + N_VALUES_FELTS - 1].clone(); - let intermediate_value = eval.next_trace_mask(); - eval.add_constraint( - intermediate_value.clone() - - (last_value_felt.clone() * (last_value_felt.clone() - E::F::one())), - ); - eval.add_constraint( - intermediate_value - * (last_value_felt.clone() - E::F::from(M31::from(2))) - * (last_value_felt - E::F::from(M31::from(3))), - ); - - logup.finalize(&mut eval); - eval - } -} - -#[derive(Clone, Serialize, Deserialize)] -pub struct Claim { - pub memory_segment: SegmentAddrs, -} - -impl Claim { - pub fn mix_into(&self, channel: &mut impl Channel) { - channel.mix_u64(self.memory_segment.begin_addr as u64); - channel.mix_u64(self.memory_segment.end_addr as u64); - } - - pub fn log_sizes(&self) -> TreeVec> { - let n_values = self.memory_segment.end_addr - self.memory_segment.begin_addr; - let log_size = n_values.next_power_of_two().ilog2(); - let preprocessed_log_sizes = vec![log_size]; - let trace_log_sizes = vec![log_size; N_RANGE_CHECK_COLUMNS]; - let interaction_trace_log_sizes = vec![log_size; SECURE_EXTENSION_DEGREE]; - - TreeVec::new(vec![ - preprocessed_log_sizes, - trace_log_sizes, - interaction_trace_log_sizes, - ]) - } -} - -#[derive(Clone, Serialize, Deserialize)] -pub struct InteractionClaim { - pub log_size: u32, - pub claimed_sum: SecureField, -} - -impl InteractionClaim { - pub fn mix_into(&self, channel: &mut impl Channel) { - channel.mix_felts(&[self.claimed_sum]); - } -} diff --git a/stwo_cairo_prover/crates/prover/src/components/range_check_builtin/mod.rs b/stwo_cairo_prover/crates/prover/src/components/range_check_builtin/mod.rs deleted file mode 100644 index d51170cc..00000000 --- a/stwo_cairo_prover/crates/prover/src/components/range_check_builtin/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod component; -pub mod prover; - -pub use component::{Claim, Component, Eval, InteractionClaim}; -pub use prover::ClaimGenerator; diff --git a/stwo_cairo_prover/crates/prover/src/components/range_check_builtin/prover.rs b/stwo_cairo_prover/crates/prover/src/components/range_check_builtin/prover.rs deleted file mode 100644 index 985fcfbd..00000000 --- a/stwo_cairo_prover/crates/prover/src/components/range_check_builtin/prover.rs +++ /dev/null @@ -1,347 +0,0 @@ -use itertools::{chain, Itertools}; -use num_traits::One; -use stwo_prover::constraint_framework::logup::LogupTraceGenerator; -use stwo_prover::core::backend::simd::m31::{PackedM31, LOG_N_LANES, N_LANES}; -use stwo_prover::core::backend::simd::qm31::PackedQM31; -use stwo_prover::core::backend::simd::SimdBackend; -use stwo_prover::core::backend::{Col, Column}; -use stwo_prover::core::fields::m31::{BaseField, M31}; -use stwo_prover::core::fields::qm31::SecureField; -use stwo_prover::core::pcs::TreeBuilder; -use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation}; -use stwo_prover::core::poly::BitReversedOrder; -use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; -use stwo_prover::core::ColumnVec; - -use super::component::{Claim, InteractionClaim, N_RANGE_CHECK_COLUMNS, N_VALUES_FELTS}; -use crate::components::memory::id_to_f252; -use crate::input::SegmentAddrs; - -// Memory addresses for the RangeCheckBuiltin segment. -pub type RangeCheckBuiltinInput = PackedM31; - -pub struct ClaimGenerator { - pub memory_segment: SegmentAddrs, -} - -impl ClaimGenerator { - pub fn new(input: SegmentAddrs) -> Self { - Self { - memory_segment: input, - } - } - - pub fn write_trace( - &self, - tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, - memory_trace_generator: &mut id_to_f252::ClaimGenerator, - ) -> (Claim, InteractionClaimGenerator) { - let mut addresses = self.memory_segment.addresses(); - // TODO(spapini): Split to multiple components. - let size = addresses.len().next_power_of_two(); - // TODO(AlonH): Addresses should be increasing. - addresses.resize(size, addresses[0]); - - let inputs = addresses - .into_iter() - .array_chunks::() - .map(|chunk| { - PackedM31::from_array(std::array::from_fn(|i| M31::from_u32_unchecked(chunk[i]))) - }) - .collect_vec(); - let (trace, interaction_prover) = write_trace_simd(&inputs, memory_trace_generator); - interaction_prover - .memory_addresses - .iter() - .for_each(|v| memory_trace_generator.add_inputs_simd(v)); - tree_builder.extend_evals(trace); - let claim = Claim { - memory_segment: self.memory_segment.clone(), - }; - (claim, interaction_prover) - } -} - -pub struct InteractionClaimGenerator { - pub memory_addresses: Vec, - pub memory_values: Vec<[PackedM31; N_VALUES_FELTS]>, -} - -impl InteractionClaimGenerator { - pub fn with_capacity(capacity: usize) -> Self { - Self { - memory_addresses: Vec::with_capacity(capacity), - memory_values: Vec::with_capacity(capacity), - } - } - - pub fn write_interaction_trace( - &self, - tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, - memory_lookup_elements: &id_to_f252::RelationElements, - ) -> InteractionClaim { - let log_size = self.memory_addresses.len().ilog2() + LOG_N_LANES; - let (trace, claimed_sum) = gen_interaction_trace(self, log_size, memory_lookup_elements); - tree_builder.extend_evals(trace); - - InteractionClaim { - log_size, - claimed_sum, - } - } -} - -pub fn write_trace_simd( - inputs: &[RangeCheckBuiltinInput], - memory_trace_generator: &id_to_f252::ClaimGenerator, -) -> ( - ColumnVec>, - InteractionClaimGenerator, -) { - let log_size = inputs.len().ilog2() + LOG_N_LANES; - let mut interaction_prover = InteractionClaimGenerator::with_capacity(inputs.len()); - let mut trace = (0..N_RANGE_CHECK_COLUMNS) - .map(|_| unsafe { Col::::uninitialized(1 << log_size) }) - .collect_vec(); - - let address_initial_offset = PackedM31::broadcast(BaseField::from(inputs[0].into_simd()[0])); - #[allow(clippy::needless_range_loop)] - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { - let row_input = *inputs.get(vec_row).unwrap(); - // TODO: remove address from the trace. - let split_values: [PackedM31; N_VALUES_FELTS] = memory_trace_generator - .deduce_output(row_input)[..N_VALUES_FELTS] - .try_into() - .unwrap(); - let address = row_input - address_initial_offset; - trace[0].data[vec_row] = address; - for (i, v) in split_values.iter().enumerate() { - trace[i + 1].data[vec_row] = *v; - } - let last_value_felt = split_values[N_VALUES_FELTS - 1]; - trace[N_VALUES_FELTS + 1].data[vec_row] = - last_value_felt * (last_value_felt - PackedM31::one()); - - interaction_prover.memory_addresses.push(row_input); - interaction_prover.memory_values.push(split_values); - } - - let domain = CanonicCoset::new(log_size).circle_domain(); - ( - trace - .into_iter() - .map(|eval| CircleEvaluation::::new(domain, eval)) - .collect_vec(), - interaction_prover, - ) -} - -pub fn gen_interaction_trace( - interaction_prover: &InteractionClaimGenerator, - log_size: u32, - memory_lookup_elements: &id_to_f252::RelationElements, -) -> ( - ColumnVec>, - SecureField, -) { - let mut logup_gen = LogupTraceGenerator::new(log_size); - let mut col_gen = logup_gen.new_col(); - for vec_row in 0..(1 << (log_size - LOG_N_LANES)) { - let p_mem: PackedQM31 = memory_lookup_elements.combine( - &chain!( - [interaction_prover.memory_addresses[vec_row]], - interaction_prover.memory_values[vec_row] - ) - .collect_vec(), - ); - col_gen.write_frac(vec_row, PackedQM31::one(), p_mem); - } - col_gen.finalize_col(); - - let (trace, claimed_sum) = logup_gen.finalize_last(); - (trace, claimed_sum) -} - -#[cfg(test)] -mod tests { - use std::array; - use std::simd::Simd; - - use itertools::zip_eq; - use rand::Rng; - use stwo_prover::constraint_framework::preprocessed_columns::gen_is_first; - use stwo_prover::constraint_framework::FrameworkEval; - use stwo_prover::core::backend::simd::m31::N_LANES; - use stwo_prover::core::channel::Blake2sChannel; - use stwo_prover::core::pcs::TreeVec; - - use super::*; - use crate::components::memory::addr_to_id::N_ADDRESS_FELTS; - use crate::components::memory::id_to_f252::N_BITS_PER_FELT; - use crate::components::range_check_builtin::component::Eval; - use crate::felt::split_f252; - - #[test] - fn test_generate_trace() { - use super::*; - - let mut rng = rand::thread_rng(); - let log_size = 8; - let inputs = (0..1 << (log_size - LOG_N_LANES)) - .map(|i| { - PackedM31::from_array(array::from_fn(|j| { - M31::from_u32_unchecked(i * N_LANES as u32 + j as u32) - })) - }) - .collect_vec(); - - let values = (0..1 << (log_size - LOG_N_LANES)) - .map(|_| { - array::from_fn(|i| { - if i < 4 { - Simd::from_array(rng.gen()) - } else { - Simd::splat(0) - } - }) - }) - .collect_vec(); - let memory_trace_generator = id_to_f252::ClaimGenerator { - values: values.clone(), - multiplicities: vec![0; 1 << log_size], - }; - let (trace, interaction_prover) = write_trace_simd(&inputs, &memory_trace_generator); - - assert_eq!(trace.len(), N_RANGE_CHECK_COLUMNS); - assert_eq!( - trace[0].values.clone().into_cpu_vec(), - (0..1 << log_size).map(M31::from).collect_vec() - ); - - // Assert that the trace values are correct. - #[allow(clippy::needless_range_loop)] - for row_offset in 0..1 << (log_size - LOG_N_LANES) { - let input = values[row_offset]; - - let mut inputs_u128 = [0_u128; 16]; - for (index, simd) in input.iter().enumerate() { - for (i, val) in simd.to_array().into_iter().enumerate() { - if index >= 4 { - assert_eq!(val, 0); - continue; - } - let val_u128 = val as u128; - inputs_u128[i] += val_u128 << (32 * index); - } - } - - let mask = ((1 << N_BITS_PER_FELT) - 1) as u128; - for col in trace.iter().skip(N_ADDRESS_FELTS).take(N_VALUES_FELTS) { - for j in 0..N_LANES { - let val = col.values.at((row_offset << LOG_N_LANES) + j); - assert_eq!(val.0, (inputs_u128[j] & mask) as u32); - inputs_u128[j] >>= N_BITS_PER_FELT; - } - } - - let last_value_felts = trace[N_ADDRESS_FELTS + N_VALUES_FELTS - 1].values.clone(); - let intermediate_values = trace[N_ADDRESS_FELTS + N_VALUES_FELTS].values.clone(); - for (last_value_felt, intermediate_value) in - zip_eq(last_value_felts.data.clone(), intermediate_values.data) - { - assert_eq!( - intermediate_value.to_array(), - (last_value_felt * (last_value_felt - PackedM31::one())).to_array() - ); - } - // Assert that the high values are in range [0, 4). - last_value_felts.into_cpu_vec().iter().all(|&x| x.0 < 4); - } - - // Assert memory addresses lookup are sequential, offset by `address_initial_offset`. - assert_eq!( - (1 + interaction_prover.memory_values[0].len()) * N_LANES, - 1 << log_size - ); - for (i, addresses) in interaction_prover.memory_addresses.iter().enumerate() { - assert_eq!( - addresses.to_array(), - array::from_fn(|j| { M31::from_u32_unchecked((i * N_LANES + j) as u32) }) - ); - } - } - - #[test] - fn test_generate_interaction_trace() { - let mut rng = rand::thread_rng(); - let log_size = 8; - let mem_log_size = log_size + 1; - let address_initial_offset = 256; - let inputs = (0..1 << (log_size - LOG_N_LANES)) - .map(|i| { - PackedM31::from_array(array::from_fn(|j| { - M31::from_u32_unchecked(i * N_LANES as u32 + j as u32 + address_initial_offset) - })) - }) - .collect_vec(); - - let values = (0..1 << (mem_log_size - LOG_N_LANES)) - .map(|_| { - array::from_fn(|i| { - if i < 4 { - Simd::from_array(rng.gen()) - } else { - Simd::splat(0) - } - }) - }) - .collect_vec(); - let memory_trace_generator = id_to_f252::ClaimGenerator { - values: values.clone(), - multiplicities: vec![0; 1 << mem_log_size], - }; - let (trace, interaction_prover) = write_trace_simd(&inputs, &memory_trace_generator); - - let channel = &mut Blake2sChannel::default(); - let memory_lookup_elements = id_to_f252::RelationElements::draw(channel); - - let (interaction_trace, claimed_sum) = - gen_interaction_trace(&interaction_prover, log_size, &memory_lookup_elements); - - let constant_trace = vec![gen_is_first::(log_size)]; - - let trace = TreeVec::new(vec![constant_trace, trace, interaction_trace]); - let trace_polys = trace.map_cols(|c| c.interpolate()); - - let component = Eval { - log_size, - initial_memory_address: M31::from(address_initial_offset), - memory_lookup_elements, - claimed_sum, - }; - - stwo_prover::constraint_framework::assert_constraints( - &trace_polys, - CanonicCoset::new(log_size), - |eval| { - component.evaluate(eval); - }, - ) - } - - #[test] - fn test_split() { - let x: [u32; 8] = [ - 0x12345678, 0x9abcdef0, 0x13579bdf, 0x2468ace0, 0x12345678, 0x9abcdef0, 0x13579bdf, 0, - ]; - let res = split_f252(x); - assert_eq!( - res, - [ - 120, 43, 141, 2, 495, 486, 106, 447, 411, 427, 4, 412, 138, 291, 480, 172, 52, 9, - 444, 411, 427, 252, 111, 175, 19, 0, 0, 0 - ] - .map(M31::from) - ); - } -} diff --git a/stwo_cairo_prover/crates/prover/src/components/ret_opcode/prover.rs b/stwo_cairo_prover/crates/prover/src/components/ret_opcode/prover.rs index 4c4b83e0..98d27faf 100644 --- a/stwo_cairo_prover/crates/prover/src/components/ret_opcode/prover.rs +++ b/stwo_cairo_prover/crates/prover/src/components/ret_opcode/prover.rs @@ -12,8 +12,8 @@ use stwo_prover::core::poly::BitReversedOrder; use stwo_prover::core::vcs::blake2_merkle::Blake2sMerkleChannel; use super::component::{Claim, InteractionClaim, RET_INSTRUCTION}; -use crate::components::memory::id_to_f252; use crate::components::memory::id_to_f252::component::N_M31_IN_FELT252; +use crate::components::memory::{self, id_to_f252}; use crate::input::instructions::VmState; const N_MEMORY_CALLS: usize = 3; @@ -56,13 +56,18 @@ impl ClaimGenerator { pub fn write_trace( &self, tree_builder: &mut TreeBuilder<'_, '_, SimdBackend, Blake2sMerkleChannel>, + memory_addr_to_id_state: &mut memory::addr_to_id::ClaimGenerator, memory_trace_generator: &mut id_to_f252::ClaimGenerator, ) -> (Claim, InteractionClaimGenerator) { - let (trace, interaction_prover) = write_trace_simd(&self.inputs, memory_trace_generator); - interaction_prover.memory_inputs.iter().for_each(|c| { - c.iter() - .for_each(|v| memory_trace_generator.add_inputs_simd(v)) - }); + let (trace, interaction_prover) = write_trace_simd( + &self.inputs, + memory_addr_to_id_state, + memory_trace_generator, + ); + // interaction_prover.memory_inputs.iter().for_each(|c| { + // c.iter() + // .for_each(|v| memory_trace_generator.add_packed_m31(v)) + // }); tree_builder.extend_evals(trace); let claim = Claim { n_rets: self.inputs.len() * N_LANES, @@ -130,6 +135,7 @@ impl InteractionClaimGenerator { fn write_trace_simd( inputs: &[PackedCasmState], + memory_addr_to_id_state: &memory::addr_to_id::ClaimGenerator, memory_trace_generator: &id_to_f252::ClaimGenerator, ) -> ( Vec>, @@ -146,6 +152,7 @@ fn write_trace_simd( input, i, &mut sub_components_inputs, + memory_addr_to_id_state, memory_trace_generator, ); }); @@ -175,6 +182,7 @@ fn write_trace_row( ret_opcode_input: &PackedCasmState, row_index: usize, lookup_data: &mut InteractionClaimGenerator, + memory_addr_to_id_state: &memory::addr_to_id::ClaimGenerator, memory_trace_generator: &id_to_f252::ClaimGenerator, ) { let col0 = ret_opcode_input.pc; @@ -186,8 +194,8 @@ fn write_trace_row( lookup_data.memory_inputs[0].push(col0); lookup_data.memory_inputs[1].push((col2) - (PackedM31::broadcast(M31::one()))); lookup_data.memory_outputs[0].push(RET_INSTRUCTION.map(|v| PackedM31::broadcast(M31::from(v)))); - let mem_fp_minus_one = - memory_trace_generator.deduce_output((col2) - (PackedM31::broadcast(M31::from(1)))); + let id = memory_addr_to_id_state.deduce_output((col2) - (PackedM31::broadcast(M31::from(1)))); + let mem_fp_minus_one = memory_trace_generator.deduce_output(id).value; lookup_data.memory_outputs[1].push(mem_fp_minus_one); let col3 = mem_fp_minus_one[0]; @@ -195,8 +203,8 @@ fn write_trace_row( let col4 = mem_fp_minus_one[1]; dst[4].data[row_index] = col4; lookup_data.memory_inputs[2].push((col2) - (PackedM31::broadcast(M31::from(2)))); - let mem_fp_minus_two = - memory_trace_generator.deduce_output((col2) - (PackedM31::broadcast(M31::from(2)))); + let id = memory_addr_to_id_state.deduce_output((col2) - (PackedM31::broadcast(M31::from(2)))); + let mem_fp_minus_two = memory_trace_generator.deduce_output(id).value; lookup_data.memory_outputs[2].push(mem_fp_minus_two); let col5 = mem_fp_minus_two[0]; dst[5].data[row_index] = col5; diff --git a/stwo_cairo_prover/crates/prover/src/felt.rs b/stwo_cairo_prover/crates/prover/src/felt.rs index 8bf5dc90..74edf6f4 100644 --- a/stwo_cairo_prover/crates/prover/src/felt.rs +++ b/stwo_cairo_prover/crates/prover/src/felt.rs @@ -6,7 +6,6 @@ use stwo_prover::core::fields::m31::M31; use crate::components::memory::id_to_f252::component::N_M31_IN_FELT252; use crate::components::memory::id_to_f252::N_BITS_PER_FELT; -use crate::components::range_check_builtin::component::N_VALUES_FELTS; /// Splits a 32N bit dense representation into felts, each with N_BITS_PER_FELT bits. /// @@ -54,8 +53,9 @@ where res } +pub const N_LIMBS_IN_128_BIT_FELT: usize = 15; /// Splits a 128 bit dense representation into felts, each with N_BITS_PER_FELT bits. -pub fn split_u128_simd(x: [u32x16; 4]) -> [PackedM31; N_VALUES_FELTS] { +pub fn split_u128_simd(x: [u32x16; 4]) -> [PackedM31; N_LIMBS_IN_128_BIT_FELT] { split(x, u32x16::from_array([(1 << N_BITS_PER_FELT) - 1; N_LANES])) .map(|x| PackedM31::from(x.to_array().map(M31::from_u32_unchecked))) } diff --git a/stwo_cairo_prover/crates/prover/src/input/mem.rs b/stwo_cairo_prover/crates/prover/src/input/mem.rs index cf2c6362..9d92ba76 100644 --- a/stwo_cairo_prover/crates/prover/src/input/mem.rs +++ b/stwo_cairo_prover/crates/prover/src/input/mem.rs @@ -6,7 +6,7 @@ use itertools::Itertools; use super::vm_import::MemEntry; /// Prime 2^251 + 17 * 2^192 + 1 in little endian. -const P_MIN_1: [u32; 8] = [ +pub const P_MIN_1: [u32; 8] = [ 0x0000_0000, 0x0000_0000, 0x0000_0000, @@ -16,7 +16,7 @@ const P_MIN_1: [u32; 8] = [ 0x0000_0011, 0x0800_0000, ]; -const P_MIN_2: [u32; 8] = [ +pub const P_MIN_2: [u32; 8] = [ 0xFFFF_FFFF, 0xFFFF_FFFF, 0xFFFF_FFFF, @@ -27,31 +27,22 @@ const P_MIN_2: [u32; 8] = [ 0x0800_0000, ]; -// Note: this should be smaller than 2^29. -const SMALL_VALUE_SHIFT: u32 = 1 << 26; +const SMALL_VALUE_SHIFT: u128 = 1 << 72; #[derive(Debug)] pub struct MemConfig { - /// The absolute value of the smallest negative value that can be stored as a small value. - pub small_min_neg: u32, - /// The largest value that can be stored as a small value. - pub small_max: u32, + pub small_max: u128, } impl MemConfig { - pub fn new(small_min_neg: u32, small_max: u32) -> MemConfig { - assert!(small_min_neg <= SMALL_VALUE_SHIFT); + pub fn new(small_max: u128) -> MemConfig { assert!(small_max <= SMALL_VALUE_SHIFT); - MemConfig { - small_min_neg, - small_max, - } + MemConfig { small_max } } } impl Default for MemConfig { fn default() -> Self { MemConfig { - small_min_neg: (1 << 10) - 1, - small_max: (1 << 10) - 1, + small_max: (1 << 72) - 1, } } } @@ -64,12 +55,13 @@ pub struct Memory { pub address_to_id: Vec, pub inst_cache: HashMap, pub f252_values: Vec<[u32; 8]>, + pub small_values: Vec, } impl Memory { pub fn get(&self, addr: u32) -> MemoryValue { match self.address_to_id[addr as usize].decode() { - MemoryValueId::Small(id) => MemoryValue::Small(id), - MemoryValueId::F252(id) => MemoryValue::F252(self.f252_values[id]), + MemoryValueId::Small(id) => MemoryValue::Small(self.small_values[id as usize]), + MemoryValueId::F252(id) => MemoryValue::F252(self.f252_values[id as usize]), } } @@ -83,25 +75,14 @@ impl Memory { // TODO(spapini): Optimize. This should be SIMD. pub fn value_from_felt252(&self, value: [u32; 8]) -> MemoryValue { - if value[7] == 0 { - // Positive case. - if value[1..7] != [0; 6] || value[0] > self.config.small_max { - // Not small. - return MemoryValue::F252(value); - } - MemoryValue::Small(value[0] as i32) + if value[3..8] == [0; 5] && value[2] < (1 << 8) { + MemoryValue::Small( + value[0] as u128 + + ((value[1] as u128) << 32) + + ((value[2] as u128) << 64) + + ((value[3] as u128) << 96), + ) } else { - // Negative case. - if value == P_MIN_1 { - return MemoryValue::Small(-1); - } - if value[1..7] != P_MIN_2[1..7] { - return MemoryValue::F252(value); - } - let num = 0xFFFF_FFFF - value[0]; - if num < self.config.small_min_neg - 2 { - return MemoryValue::Small(-(num as i32 + 2)); - } MemoryValue::F252(value) } } @@ -120,6 +101,7 @@ impl Memory { pub struct MemoryBuilder { mem: Memory, felt252_id_cache: HashMap<[u32; 8], usize>, + small_values_cache: HashMap, } impl MemoryBuilder { pub fn new(config: MemConfig) -> Self { @@ -129,8 +111,10 @@ impl MemoryBuilder { address_to_id: Vec::new(), inst_cache: HashMap::new(), f252_values: Vec::new(), + small_values: Vec::new(), }, felt252_id_cache: HashMap::new(), + small_values_cache: HashMap::new(), } } pub fn from_iter>( @@ -164,14 +148,21 @@ impl MemoryBuilder { .resize(addr as usize + 1, EncodedMemoryValueId::default()); } let res = EncodedMemoryValueId::encode(match value { - MemoryValue::Small(val) => MemoryValueId::Small(val), + MemoryValue::Small(val) => { + let len = self.small_values.len(); + let id = *self.small_values_cache.entry(val).or_insert(len); + if id == len { + self.small_values.push(val); + }; + MemoryValueId::Small(id as u32) + } MemoryValue::F252(val) => { let len = self.f252_values.len(); let id = *self.felt252_id_cache.entry(val).or_insert(len); if id == len { self.f252_values.push(val); }; - MemoryValueId::F252(id) + MemoryValueId::F252(id as u32) } }); self.address_to_id[addr as usize] = res; @@ -193,22 +184,20 @@ impl DerefMut for MemoryBuilder { } #[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub struct EncodedMemoryValueId(u32); +pub struct EncodedMemoryValueId(pub u32); impl EncodedMemoryValueId { pub fn encode(value: MemoryValueId) -> EncodedMemoryValueId { match value { - MemoryValueId::Small(id) => { - EncodedMemoryValueId((id + SMALL_VALUE_SHIFT as i32) as u32) - } - MemoryValueId::F252(id) => EncodedMemoryValueId(id as u32 | 0x4000_0000), + MemoryValueId::Small(id) => EncodedMemoryValueId(id), + MemoryValueId::F252(id) => EncodedMemoryValueId(id | 0x4000_0000), } } pub fn decode(&self) -> MemoryValueId { let tag = self.0 >> 30; let val = self.0 & 0x3FFF_FFFF; match tag { - 0 => MemoryValueId::Small(val as i32 - SMALL_VALUE_SHIFT as i32), - 1 => MemoryValueId::F252(val as usize), + 0 => MemoryValueId::Small(val), + 1 => MemoryValueId::F252(val), _ => panic!("Invalid tag"), } } @@ -221,35 +210,33 @@ impl Default for EncodedMemoryValueId { } pub enum MemoryValueId { - Small(i32), - F252(usize), + Small(u32), + F252(u32), } #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum MemoryValue { - Small(i32), + Small(u128), F252([u32; 8]), } impl MemoryValue { - pub fn as_small(&self) -> i32 { + pub fn as_small(&self) -> u128 { match self { MemoryValue::Small(x) => *x, - MemoryValue::F252(_) => panic!("Cannot convert F252 to i32"), + MemoryValue::F252(_) => panic!("Cannot convert F252 to u128"), } } pub fn as_u256(&self) -> [u32; 8] { match *self { MemoryValue::Small(x) => { - if x >= 0 { - [x as u32, 0, 0, 0, 0, 0, 0, 0] - } else if x == -1 { - P_MIN_1 - } else { - let mut res = P_MIN_2; - res[0] = 0xFFFF_FFFF - (-x - 2) as u32; - res - } + let x: [u32; 4] = [ + x as u32, + (x >> 32) as u32, + (x >> 64) as u32, + (x >> 96) as u32, + ]; + [x[0], x[1], x[2], x[3], 0, 0, 0, 0] } MemoryValue::F252(x) => x, } @@ -297,12 +284,28 @@ mod tests { addr: 105, val: [1 << 24, 0, 0, 0, 0, 0, 0, 0], }, + MemEntry { + addr: 200, + val: [1, 1, 1, 0, 0, 0, 0, 0], + }, + MemEntry { + addr: 201, + val: [1, 1, 1 << 10, 0, 0, 0, 0, 0], + }, ]; let memory = MemoryBuilder::from_iter(MemConfig::default(), entries.iter().cloned()); assert_eq!(memory.get(0), MemoryValue::F252([1; 8])); assert_eq!(memory.get(1), MemoryValue::Small(6)); - assert_eq!(memory.get(8), MemoryValue::Small(-1)); - assert_eq!(memory.get(9), MemoryValue::Small(-2)); + assert_eq!( + memory.get(200), + MemoryValue::Small(1 + (1 << 32) + (1 << 64)) + ); + assert_eq!( + memory.get(201), + MemoryValue::F252([1, 1, 1 << 10, 0, 0, 0, 0, 0]) + ); + assert_eq!(memory.get(8), MemoryValue::F252(P_MIN_1)); + assert_eq!(memory.get(9), MemoryValue::F252(P_MIN_2)); // Duplicates. assert_eq!(memory.get(100), MemoryValue::F252([1; 8])); assert_eq!(memory.address_to_id[0], memory.address_to_id[100]); @@ -315,13 +318,6 @@ mod tests { assert_eq!(small.as_small(), 1); assert_eq!(small.as_u256(), [1, 0, 0, 0, 0, 0, 0, 0]); - let small_negative = MemoryValue::Small(-5); - assert_eq!(small_negative.as_small(), -5); - assert_eq!( - small_negative.as_u256().as_slice(), - [&[0xFFFFFFFC], &P_MIN_2[1..]].concat().as_slice() - ); - let f252 = MemoryValue::F252([1; 8]); assert_eq!(f252.as_u256(), [1; 8]); } diff --git a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs index 892d9f49..6f1b5332 100644 --- a/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs +++ b/stwo_cairo_prover/crates/prover/src/input/vm_import/mod.rs @@ -40,7 +40,7 @@ pub fn import_from_vm_output( .max() .ok_or(VmImportError::NoMemorySegments)?; assert!(end_addr < (1 << 32)); - let mem_config = MemConfig::new((1 << 20) - 1, end_addr as u32); + let mem_config = MemConfig::default(); let mem_path = priv_json.parent().unwrap().join(&priv_data.memory_path); let trace_path = priv_json.parent().unwrap().join(&priv_data.trace_path);