diff --git a/Cargo.lock b/Cargo.lock index 51f737b3ec..b75a8e5669 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4859,11 +4859,9 @@ dependencies = [ "num-bigint 0.4.5", "p3-baby-bear", "p3-field", - "p3-symmetric", "rand", "serde", "serde_json", - "sp1-core", "sp1-recursion-compiler", "tempfile", ] diff --git a/core/Cargo.toml b/core/Cargo.toml index 548fcde6cd..df02550f1a 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -58,7 +58,6 @@ web-time = "1.1.0" rayon-scan = "0.1.1" thiserror = "1.0.60" num-bigint = { version = "0.4.3", default-features = false } -rand = "0.8.5" [dev-dependencies] tiny-keccak = { version = "2.0.2", features = ["keccak"] } diff --git a/core/src/air/builder.rs b/core/src/air/builder.rs index d957f43cd9..eba567d920 100644 --- a/core/src/air/builder.rs +++ b/core/src/air/builder.rs @@ -308,7 +308,6 @@ pub trait AluAirBuilder: BaseAirBuilder { c: Word>, shard: impl Into, channel: impl Into, - nonce: impl Into, multiplicity: impl Into, ) { let values = once(opcode.into()) @@ -317,7 +316,6 @@ pub trait AluAirBuilder: BaseAirBuilder { .chain(c.0.into_iter().map(Into::into)) .chain(once(shard.into())) .chain(once(channel.into())) - .chain(once(nonce.into())) .collect(); self.send(AirInteraction::new( @@ -337,7 +335,6 @@ pub trait AluAirBuilder: BaseAirBuilder { c: Word>, shard: impl Into, channel: impl Into, - nonce: impl Into, multiplicity: impl Into, ) { let values = once(opcode.into()) @@ -346,7 +343,6 @@ pub trait AluAirBuilder: BaseAirBuilder { .chain(c.0.into_iter().map(Into::into)) .chain(once(shard.into())) .chain(once(channel.into())) - .chain(once(nonce.into())) .collect(); self.receive(AirInteraction::new( @@ -363,7 +359,6 @@ pub trait AluAirBuilder: BaseAirBuilder { shard: impl Into + Clone, channel: impl Into + Clone, clk: impl Into + Clone, - nonce: impl Into + Clone, syscall_id: impl Into + Clone, arg1: impl Into + Clone, arg2: impl Into + Clone, @@ -374,7 +369,6 @@ pub trait AluAirBuilder: BaseAirBuilder { shard.clone().into(), channel.clone().into(), clk.clone().into(), - nonce.clone().into(), syscall_id.clone().into(), arg1.clone().into(), arg2.clone().into(), @@ -391,7 +385,6 @@ pub trait AluAirBuilder: BaseAirBuilder { shard: impl Into + Clone, channel: impl Into + Clone, clk: impl Into + Clone, - nonce: impl Into + Clone, syscall_id: impl Into + Clone, arg1: impl Into + Clone, arg2: impl Into + Clone, @@ -402,7 +395,6 @@ pub trait AluAirBuilder: BaseAirBuilder { shard.clone().into(), channel.clone().into(), clk.clone().into(), - nonce.clone().into(), syscall_id.clone().into(), arg1.clone().into(), arg2.clone().into(), diff --git a/core/src/alu/add_sub/mod.rs b/core/src/alu/add_sub/mod.rs index 3179d4d775..2321427c53 100644 --- a/core/src/alu/add_sub/mod.rs +++ b/core/src/alu/add_sub/mod.rs @@ -1,8 +1,7 @@ use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; -use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::AbstractField; +use p3_air::{Air, BaseAir}; use p3_field::PrimeField; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; @@ -39,9 +38,6 @@ pub struct AddSubCols { /// The channel number, used for byte lookup table. pub channel: T, - /// The nonce of the operation. - pub nonce: T, - /// Instance of `AddOperation` to handle addition logic in `AddSubChip`'s ALU operations. /// It's result will be `a` for the add operation and `b` for the sub operation. pub add_operation: AddOperation, @@ -133,13 +129,6 @@ impl MachineAir for AddSubChip { // Pad the trace to a power of two. pad_to_power_of_two::(&mut trace.values); - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut AddSubCols = - trace.values[i * NUM_ADD_SUB_COLS..(i + 1) * NUM_ADD_SUB_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - trace } @@ -162,14 +151,6 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &AddSubCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &AddSubCols = (*next).borrow(); - - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); // Evaluate the addition operation. AddOperation::::eval( @@ -191,7 +172,6 @@ where local.operand_2, local.shard, local.channel, - local.nonce, local.is_add, ); @@ -203,7 +183,6 @@ where local.operand_2, local.shard, local.channel, - local.nonce, local.is_sub, ); diff --git a/core/src/alu/bitwise/mod.rs b/core/src/alu/bitwise/mod.rs index 81163b11e1..3e7227b709 100644 --- a/core/src/alu/bitwise/mod.rs +++ b/core/src/alu/bitwise/mod.rs @@ -1,9 +1,7 @@ use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; -use p3_air::AirBuilder; use p3_air::{Air, BaseAir}; -use p3_field::AbstractField; use p3_field::PrimeField; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; @@ -33,9 +31,6 @@ pub struct BitwiseCols { /// The channel number, used for byte lookup table. pub channel: T, - /// The nonce of the operation. - pub nonce: T, - /// The output operand. pub a: Word, @@ -116,12 +111,6 @@ impl MachineAir for BitwiseChip { // Pad the trace to a power of two. pad_to_power_of_two::(&mut trace.values); - for i in 0..trace.height() { - let cols: &mut BitwiseCols = - trace.values[i * NUM_BITWISE_COLS..(i + 1) * NUM_BITWISE_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - trace } @@ -144,14 +133,6 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &BitwiseCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &BitwiseCols = (*next).borrow(); - - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); // Get the opcode for the operation. let opcode = local.is_xor * ByteOpcode::XOR.as_field::() @@ -185,7 +166,6 @@ where local.c, local.shard, local.channel, - local.nonce, local.is_xor + local.is_or + local.is_and, ); diff --git a/core/src/alu/divrem/mod.rs b/core/src/alu/divrem/mod.rs index 84198b16bb..b54206469f 100644 --- a/core/src/alu/divrem/mod.rs +++ b/core/src/alu/divrem/mod.rs @@ -64,7 +64,6 @@ mod utils; use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; -use std::collections::HashMap; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::AbstractField; @@ -73,10 +72,11 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use sp1_derive::AlignedBorrow; +use self::utils::eval_abs_value; use crate::air::MachineAir; use crate::air::{SP1AirBuilder, Word}; use crate::alu::divrem::utils::{get_msb, get_quotient_and_remainder, is_signed_operation}; -use crate::alu::{create_alu_lookups, AluEvent}; +use crate::alu::AluEvent; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; use crate::disassembler::WORD_SIZE; @@ -107,9 +107,6 @@ pub struct DivRemCols { /// The channel number, used for byte lookup table. pub channel: T, - /// The nonce of the operation. - pub nonce: T, - /// The output operand. pub a: Word, @@ -187,23 +184,6 @@ pub struct DivRemCols { /// Flag to indicate whether `c` is negative. pub c_neg: T, - /// The lower nonce of the operation. - pub lower_nonce: T, - - /// The upper nonce of the operation. - pub upper_nonce: T, - - /// The absolute nonce of the operation. - pub abs_nonce: T, - - /// Selector to determine whether an ALU Event is sent for absolute value computation of `c`. - pub abs_c_alu_event: T, - pub abs_c_alu_event_nonce: T, - - /// Selector to determine whether an ALU Event is sent for absolute value computation of `rem`. - pub abs_rem_alu_event: T, - pub abs_rem_alu_event_nonce: T, - /// Selector to know whether this row is enabled. pub is_real: T, @@ -279,24 +259,6 @@ impl MachineAir for DivRemChip { cols.max_abs_c_or_1 = Word::from(u32::max(1, event.c)); } - // Set the `alu_event` flags. - cols.abs_c_alu_event = cols.c_neg * cols.is_real; - cols.abs_c_alu_event_nonce = F::from_canonical_u32( - input - .nonce_lookup - .get(&event.sub_lookups[4]) - .copied() - .unwrap_or_default(), - ); - cols.abs_rem_alu_event = cols.rem_neg * cols.is_real; - cols.abs_rem_alu_event_nonce = F::from_canonical_u32( - input - .nonce_lookup - .get(&event.sub_lookups[5]) - .copied() - .unwrap_or_default(), - ); - // Insert the MSB lookup events. { let words = [event.b, event.c, remainder]; @@ -319,7 +281,7 @@ impl MachineAir for DivRemChip { // Calculate the modified multiplicity { - cols.remainder_check_multiplicity = cols.is_real * (F::one() - cols.is_c_0.result); + cols.remainder_check_multiplicity = cols.is_real * cols.is_c_0.result; } // Calculate c * quotient + remainder. @@ -359,40 +321,6 @@ impl MachineAir for DivRemChip { // mul and LT upon which div depends. This ordering is critical as mul and LT // require all the mul and LT events be added before we can call generate_trace. { - // Insert the absolute value computation events. - { - let mut add_events: Vec = vec![]; - if cols.abs_c_alu_event == F::one() { - add_events.push(AluEvent { - lookup_id: event.sub_lookups[4], - shard: event.shard, - channel: event.channel, - clk: event.clk, - opcode: Opcode::ADD, - a: 0, - b: event.c, - c: (event.c as i32).abs() as u32, - sub_lookups: create_alu_lookups(), - }) - } - if cols.abs_rem_alu_event == F::one() { - add_events.push(AluEvent { - lookup_id: event.sub_lookups[5], - shard: event.shard, - channel: event.channel, - clk: event.clk, - opcode: Opcode::ADD, - a: 0, - b: remainder, - c: (remainder as i32).abs() as u32, - sub_lookups: create_alu_lookups(), - }) - } - let mut alu_events = HashMap::new(); - alu_events.insert(Opcode::ADD, add_events); - output.add_alu_events(alu_events); - } - let mut lower_word = 0; for i in 0..WORD_SIZE { lower_word += (c_times_quotient[i] as u32) << (i * BYTE_SIZE); @@ -404,7 +332,6 @@ impl MachineAir for DivRemChip { } let lower_multiplication = AluEvent { - lookup_id: event.sub_lookups[0], shard: event.shard, channel: event.channel, clk: event.clk, @@ -412,19 +339,10 @@ impl MachineAir for DivRemChip { a: lower_word, c: event.c, b: quotient, - sub_lookups: create_alu_lookups(), }; - cols.lower_nonce = F::from_canonical_u32( - input - .nonce_lookup - .get(&event.sub_lookups[0]) - .copied() - .unwrap_or_default(), - ); output.add_mul_event(lower_multiplication); let upper_multiplication = AluEvent { - lookup_id: event.sub_lookups[1], shard: event.shard, channel: event.channel, clk: event.clk, @@ -438,45 +356,22 @@ impl MachineAir for DivRemChip { a: upper_word, c: event.c, b: quotient, - sub_lookups: create_alu_lookups(), }; - cols.upper_nonce = F::from_canonical_u32( - input - .nonce_lookup - .get(&event.sub_lookups[1]) - .copied() - .unwrap_or_default(), - ); + output.add_mul_event(upper_multiplication); + let lt_event = if is_signed_operation(event.opcode) { - cols.abs_nonce = F::from_canonical_u32( - input - .nonce_lookup - .get(&event.sub_lookups[2]) - .copied() - .unwrap_or_default(), - ); AluEvent { - lookup_id: event.sub_lookups[2], shard: event.shard, channel: event.channel, - opcode: Opcode::SLTU, + opcode: Opcode::SLT, a: 1, b: (remainder as i32).abs() as u32, c: u32::max(1, (event.c as i32).abs() as u32), clk: event.clk, - sub_lookups: create_alu_lookups(), } } else { - cols.abs_nonce = F::from_canonical_u32( - input - .nonce_lookup - .get(&event.sub_lookups[3]) - .copied() - .unwrap_or_default(), - ); AluEvent { - lookup_id: event.sub_lookups[3], shard: event.shard, channel: event.channel, opcode: Opcode::SLTU, @@ -484,10 +379,8 @@ impl MachineAir for DivRemChip { b: remainder, c: u32::max(1, event.c), clk: event.clk, - sub_lookups: create_alu_lookups(), } }; - if cols.remainder_check_multiplicity == F::one() { output.add_lt_event(lt_event); } @@ -537,13 +430,6 @@ impl MachineAir for DivRemChip { trace.values[i] = padded_row_template[i % NUM_DIVREM_COLS]; } - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut DivRemCols = - trace.values[i * NUM_DIVREM_COLS..(i + 1) * NUM_DIVREM_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - trace } @@ -566,18 +452,10 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &DivRemCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &DivRemCols = (*next).borrow(); let base = AB::F::from_canonical_u32(1 << 8); let one: AB::Expr = AB::F::one().into(); let zero: AB::Expr = AB::F::zero().into(); - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); - // Calculate whether b, remainder, and c are negative. { // Negative if and only if op code is signed & MSB = 1. @@ -612,7 +490,6 @@ where local.c, local.shard, local.channel, - local.lower_nonce, local.is_real, ); @@ -638,7 +515,6 @@ where local.c, local.shard, local.channel, - local.upper_nonce, local.is_real, ); } @@ -783,37 +659,18 @@ where // Range check remainder. (i.e., |remainder| < |c| when not is_c_0) { - // For each of `c` and `rem`, assert that the absolute value is equal to the original value, - // if the original value is non-negative or the minimum i32. - for i in 0..WORD_SIZE { - builder - .when_not(local.c_neg) - .assert_eq(local.c[i], local.abs_c[i]); - builder - .when_not(local.rem_neg) - .assert_eq(local.remainder[i], local.abs_remainder[i]); - } - // In the case that `c` or `rem` is negative, instead check that their sum is zero by - // sending an AddEvent. - builder.send_alu( - AB::Expr::from_canonical_u32(Opcode::ADD as u32), - Word([zero.clone(), zero.clone(), zero.clone(), zero.clone()]), - local.c, - local.abs_c, - local.shard, - local.channel, - local.abs_c_alu_event_nonce, - local.abs_c_alu_event, + eval_abs_value( + builder, + local.remainder.borrow(), + local.abs_remainder.borrow(), + local.rem_neg.borrow(), ); - builder.send_alu( - AB::Expr::from_canonical_u32(Opcode::ADD as u32), - Word([zero.clone(), zero.clone(), zero.clone(), zero.clone()]), - local.remainder, - local.abs_remainder, - local.shard, - local.channel, - local.abs_rem_alu_event_nonce, - local.abs_rem_alu_event, + + eval_abs_value( + builder, + local.c.borrow(), + local.abs_c.borrow(), + local.c_neg.borrow(), ); // max(abs(c), 1) = abs(c) * (1 - is_c_0) + 1 * is_c_0 @@ -834,31 +691,29 @@ where builder.assert_eq(local.max_abs_c_or_1[i], max_abs_c_or_1[i].clone()); } - // Handle cases: - // - If is_real == 0 then remainder_check_multiplicity == 0 is forced. - // - If is_real == 1 then is_c_0_result must be the expected one, so - // remainder_check_multiplicity = (1 - is_c_0_result) * is_real. + let opcode = { + let is_signed = local.is_div + local.is_rem; + let is_unsigned = local.is_divu + local.is_remu; + let slt = AB::Expr::from_canonical_u32(Opcode::SLT as u32); + let sltu = AB::Expr::from_canonical_u32(Opcode::SLTU as u32); + is_signed * slt + is_unsigned * sltu + }; + + // Check that the event multiplicity column is computed correctly. builder.assert_eq( - (AB::Expr::one() - local.is_c_0.result) * local.is_real, local.remainder_check_multiplicity, + local.is_c_0.result * local.is_real, ); - // the cleaner idea is simply remainder_check_multiplicity == (1 - is_c_0_result) * is_real - - // Check that the absolute value selector columns are computed correctly. - builder.assert_eq(local.abs_c_alu_event, local.c_neg * local.is_real); - builder.assert_eq(local.abs_rem_alu_event, local.rem_neg * local.is_real); - // Dispatch abs(remainder) < max(abs(c), 1), this is equivalent to abs(remainder) < // abs(c) if not division by 0. builder.send_alu( - AB::Expr::from_canonical_u32(Opcode::SLTU as u32), + opcode, Word([one.clone(), zero.clone(), zero.clone(), zero.clone()]), local.abs_remainder, local.max_abs_c_or_1, local.shard, local.channel, - local.abs_nonce, local.remainder_check_multiplicity, ); } @@ -928,8 +783,6 @@ where local.rem_neg, local.c_neg, local.is_real, - local.abs_c_alu_event, - local.abs_rem_alu_event, ]; for flag in bool_flags.iter() { @@ -964,7 +817,6 @@ where local.c, local.shard, local.channel, - local.nonce, local.is_real, ); } diff --git a/core/src/alu/divrem/utils.rs b/core/src/alu/divrem/utils.rs index d71c35aad6..f3a7b7070f 100644 --- a/core/src/alu/divrem/utils.rs +++ b/core/src/alu/divrem/utils.rs @@ -1,3 +1,7 @@ +use p3_air::AirBuilder; +use p3_field::AbstractField; + +use crate::air::{SP1AirBuilder, Word, WORD_SIZE}; use crate::runtime::Opcode; /// Returns `true` if the given `opcode` is a signed operation. @@ -28,3 +32,47 @@ pub fn get_quotient_and_remainder(b: u32, c: u32, opcode: Opcode) -> (u32, u32) pub const fn get_msb(a: u32) -> u8 { ((a >> 31) & 1) as u8 } + +/// Verifies that `abs_value = abs(value)` using `is_negative` as a flag. +/// +/// `abs(value) + value = 0` if `value` is negative. `abs(value) = value` otherwise. +/// +/// In two's complement arithmetic, the negation involves flipping its bits and adding 1. Therefore, +/// for a negative number, `abs(value) + value` equals 0. This is because `abs(value)` is the two's +/// complement (negation) of `value`. For a positive number, `abs(value)` is the same as `value`. +/// +/// The function iterates over each limb of the `value` and `abs_value`, checking the following +/// conditions: +/// +/// 1. If `value` is non-negative, it checks that each limb in `value` and `abs_value` is identical. +/// 2. If `value` is negative, it checks that the sum of each corresponding limb in `value` and +/// `abs_value` equals the expected sum for a two's complement representation. The least +/// significant limb (first limb) should add up to `0xff + 1` (to account for the +1 in two's +/// complement negation), and other limbs should add up to `0xff` (as the rest of the limbs just +/// have their bits flipped). +pub fn eval_abs_value( + builder: &mut AB, + value: &Word, + abs_value: &Word, + is_negative: &AB::Var, +) where + AB: SP1AirBuilder, +{ + for i in 0..WORD_SIZE { + let exp_sum_if_negative = AB::Expr::from_canonical_u32({ + if i == 0 { + 0xff + 1 + } else { + 0xff + } + }); + + builder + .when(*is_negative) + .assert_eq(value[i] + abs_value[i], exp_sum_if_negative.clone()); + + builder + .when_not(*is_negative) + .assert_eq(value[i], abs_value[i]); + } +} diff --git a/core/src/alu/lt/mod.rs b/core/src/alu/lt/mod.rs index 54d5768c2c..91b504181c 100644 --- a/core/src/alu/lt/mod.rs +++ b/core/src/alu/lt/mod.rs @@ -34,9 +34,6 @@ pub struct LtCols { /// The channel number, used for byte lookup table. pub channel: T, - /// The nonce of the operation. - pub nonce: T, - /// If the opcode is SLT. pub is_slt: T, @@ -223,13 +220,6 @@ impl MachineAir for LtChip { // Pad the trace to a power of two. pad_to_power_of_two::(&mut trace.values); - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut LtCols = - trace.values[i * NUM_LT_COLS..(i + 1) * NUM_LT_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - trace } @@ -252,14 +242,6 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &LtCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &LtCols = (*next).borrow(); - - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); let is_real = local.is_slt + local.is_sltu; @@ -449,7 +431,6 @@ where local.c, local.shard, local.channel, - local.nonce, is_real, ); } diff --git a/core/src/alu/mod.rs b/core/src/alu/mod.rs index a67d1ff909..c667c612c8 100644 --- a/core/src/alu/mod.rs +++ b/core/src/alu/mod.rs @@ -11,7 +11,6 @@ pub use bitwise::*; pub use divrem::*; pub use lt::*; pub use mul::*; -use rand::Rng; pub use sll::*; pub use sr::*; @@ -22,9 +21,6 @@ use crate::runtime::Opcode; /// A standard format for describing ALU operations that need to be proven. #[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub struct AluEvent { - /// The lookup id of the event. - pub lookup_id: usize, - /// The shard number, used for byte lookup table. pub shard: u32, @@ -45,15 +41,12 @@ pub struct AluEvent { // The second input operand. pub c: u32, - - pub sub_lookups: [usize; 6], } impl AluEvent { /// Creates a new `AluEvent`. pub fn new(shard: u32, channel: u32, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32) -> Self { Self { - lookup_id: 0, shard, channel, clk, @@ -61,24 +54,6 @@ impl AluEvent { a, b, c, - sub_lookups: create_alu_lookups(), } } } - -pub fn create_alu_lookup_id() -> usize { - let mut rng = rand::thread_rng(); - rng.gen() -} - -pub fn create_alu_lookups() -> [usize; 6] { - let mut rng = rand::thread_rng(); - [ - rng.gen(), - rng.gen(), - rng.gen(), - rng.gen(), - rng.gen(), - rng.gen(), - ] -} diff --git a/core/src/alu/mul/mod.rs b/core/src/alu/mul/mod.rs index 1351e78c38..c30a59c4f4 100644 --- a/core/src/alu/mul/mod.rs +++ b/core/src/alu/mul/mod.rs @@ -79,9 +79,6 @@ pub struct MulCols { /// The channel number, used for byte lookup table. pub channel: T, - /// The nonce of the operation. - pub nonce: T, - /// The output operand. pub a: Word, @@ -273,13 +270,6 @@ impl MachineAir for MulChip { // Pad the trace to a power of two. pad_to_power_of_two::(&mut trace.values); - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut MulCols = - trace.values[i * NUM_MUL_COLS..(i + 1) * NUM_MUL_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - trace } @@ -302,20 +292,12 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &MulCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &MulCols = (*next).borrow(); let base = AB::F::from_canonical_u32(1 << 8); let zero: AB::Expr = AB::F::zero().into(); let one: AB::Expr = AB::F::one().into(); let byte_mask = AB::F::from_canonical_u8(BYTE_MASK); - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); - // Calculate the MSBs. let (b_msb, c_msb) = { let msb_pairs = [ @@ -430,6 +412,14 @@ where .when(local.c_sign_extend) .assert_eq(local.c_msb, one.clone()); + // If the opcode doesn't allow sign extension for an operand, we must not extend their sign. + builder + .when(local.is_mul + local.is_mulhu) + .assert_zero(local.b_sign_extend + local.c_sign_extend); + builder + .when(local.is_mul + local.is_mulhsu + local.is_mulhsu) + .assert_zero(local.c_sign_extend); + // Calculate the opcode. let opcode = { // Exactly one of the op codes must be on. @@ -465,7 +455,6 @@ where local.c, local.shard, local.channel, - local.nonce, local.is_real, ); } diff --git a/core/src/alu/sll/mod.rs b/core/src/alu/sll/mod.rs index b5a711542b..d87ee780d6 100644 --- a/core/src/alu/sll/mod.rs +++ b/core/src/alu/sll/mod.rs @@ -67,9 +67,6 @@ pub struct ShiftLeftCols { /// The channel number, used for byte lookup table. pub channel: T, - /// The nonce of the operation. - pub nonce: T, - /// The output operand. pub a: Word, @@ -202,12 +199,6 @@ impl MachineAir for ShiftLeft { trace.values[i] = padded_row_template[i % NUM_SHIFT_LEFT_COLS]; } - for i in 0..trace.height() { - let cols: &mut ShiftLeftCols = - trace.values[i * NUM_SHIFT_LEFT_COLS..(i + 1) * NUM_SHIFT_LEFT_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - trace } @@ -230,19 +221,11 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &ShiftLeftCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &ShiftLeftCols = (*next).borrow(); let zero: AB::Expr = AB::F::zero().into(); let one: AB::Expr = AB::F::one().into(); let base: AB::Expr = AB::F::from_canonical_u32(1 << BYTE_SIZE).into(); - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); - // We first "bit shift" and next we "byte shift". Then we compare the results with a. // Finally, we perform some misc checks. @@ -371,7 +354,6 @@ where local.c, local.shard, local.channel, - local.nonce, local.is_real, ); } diff --git a/core/src/alu/sr/mod.rs b/core/src/alu/sr/mod.rs index bd7a91d52e..8f9ea721e5 100644 --- a/core/src/alu/sr/mod.rs +++ b/core/src/alu/sr/mod.rs @@ -85,9 +85,6 @@ pub struct ShiftRightCols { /// The channel number, used for byte lookup table. pub channel: T, - /// The nonce of the operation. - pub nonce: T, - /// The output operand. pub a: Word, @@ -286,13 +283,6 @@ impl MachineAir for ShiftRightChip { trace.values[i] = padded_row_template[i % NUM_SHIFT_RIGHT_COLS]; } - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut ShiftRightCols = - trace.values[i * NUM_SHIFT_RIGHT_COLS..(i + 1) * NUM_SHIFT_RIGHT_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - trace } @@ -315,17 +305,9 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &ShiftRightCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &ShiftRightCols = (*next).borrow(); let zero: AB::Expr = AB::F::zero().into(); let one: AB::Expr = AB::F::one().into(); - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); - // Check that the MSB of most_significant_byte matches local.b_msb using lookup. { let byte = local.b[WORD_SIZE - 1]; @@ -482,9 +464,6 @@ where for shift_by_n_bit in local.shift_by_n_bits.iter() { builder.assert_bool(*shift_by_n_bit); } - for bit in local.c_least_sig_byte.iter() { - builder.assert_bool(*bit); - } } // Range check bytes. @@ -506,9 +485,6 @@ where builder.assert_bool(local.is_sra); builder.assert_bool(local.is_real); - // Check that is_real is the sum of the two operation flags. - builder.assert_eq(local.is_srl + local.is_sra, local.is_real); - // Receive the arguments. builder.receive_alu( local.is_srl * AB::F::from_canonical_u32(Opcode::SRL as u32) @@ -518,7 +494,6 @@ where local.c, local.shard, local.channel, - local.nonce, local.is_real, ); } diff --git a/core/src/bytes/mod.rs b/core/src/bytes/mod.rs index 2cf8c8fb1c..f6d5bc482c 100644 --- a/core/src/bytes/mod.rs +++ b/core/src/bytes/mod.rs @@ -24,7 +24,7 @@ use crate::bytes::trace::NUM_ROWS; pub const NUM_BYTE_OPS: usize = 9; /// The number of different byte lookup channels. -pub const NUM_BYTE_LOOKUP_CHANNELS: u32 = 16; +pub const NUM_BYTE_LOOKUP_CHANNELS: u32 = 4; /// A chip for computing byte operations. /// diff --git a/core/src/bytes/trace.rs b/core/src/bytes/trace.rs index d6ba3921b6..39f2b72ff5 100644 --- a/core/src/bytes/trace.rs +++ b/core/src/bytes/trace.rs @@ -28,7 +28,7 @@ impl MachineAir for ByteChip { } fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option> { - // OPT: We should be able to make this a constant. Also, trace / map should be separate. + // TODO: We should be able to make this a constant. Also, trace / map should be separate. // Since we only need the trace and not the map, we can just pass 0 as the shard. let (trace, _) = Self::trace_and_map(0); diff --git a/core/src/cpu/air/branch.rs b/core/src/cpu/air/branch.rs index 60bba4175e..fad654de35 100644 --- a/core/src/cpu/air/branch.rs +++ b/core/src/cpu/air/branch.rs @@ -3,7 +3,6 @@ use p3_field::AbstractField; use crate::air::{BaseAirBuilder, SP1AirBuilder, Word, WordAirBuilder}; use crate::cpu::columns::{CpuCols, OpcodeSelectorCols}; -use crate::operations::BabyBearWordRangeChecker; use crate::{cpu::CpuChip, runtime::Opcode}; impl CpuChip { @@ -58,20 +57,6 @@ impl CpuChip { .when(local.branching) .assert_eq(branch_cols.next_pc.reduce::(), local.next_pc); - // Range check branch_cols.pc and branch_cols.next_pc. - BabyBearWordRangeChecker::::range_check( - builder, - branch_cols.pc, - branch_cols.pc_range_checker, - is_branch_instruction.clone(), - ); - BabyBearWordRangeChecker::::range_check( - builder, - branch_cols.next_pc, - branch_cols.next_pc_range_checker, - is_branch_instruction.clone(), - ); - // When we are branching, calculate branch_cols.next_pc <==> branch_cols.pc + c. builder.send_alu( Opcode::ADD.as_field::(), @@ -80,7 +65,6 @@ impl CpuChip { local.op_c_val(), local.shard, local.channel, - branch_cols.next_pc_nonce, local.branching, ); @@ -99,21 +83,15 @@ impl CpuChip { .when(local.is_real) .when(local.not_branching) .assert_eq(local.pc + AB::Expr::from_canonical_u8(4), local.next_pc); + } - // Assert that either we are branching or not branching when the instruction is a branch. - builder - .when(is_branch_instruction.clone()) - .assert_one(local.branching + local.not_branching); + // Evaluate branching value constraints. + { + // Assert that local.is_branching is a bit. builder .when(is_branch_instruction.clone()) .assert_bool(local.branching); - builder - .when(is_branch_instruction.clone()) - .assert_bool(local.not_branching); - } - // Evaluate branching value constraints. - { // When the opcode is BEQ and we are branching, assert that a_eq_b is true. builder .when(local.selectors.is_beq * local.branching) @@ -168,11 +146,6 @@ impl CpuChip { .when(is_branch_instruction.clone() * branch_cols.a_eq_b) .assert_word_eq(local.op_a_val(), local.op_b_val()); - // To prevent this ALU send to be arbitrarily large when is_branch_instruction is false. - builder - .when_not(is_branch_instruction.clone()) - .assert_zero(local.branching); - // Calculate a_lt_b <==> a < b (using appropriate signedness). let use_signed_comparison = local.selectors.is_blt + local.selectors.is_bge; builder.send_alu( @@ -184,7 +157,6 @@ impl CpuChip { local.op_b_val(), local.shard, local.channel, - branch_cols.a_lt_b_nonce, is_branch_instruction.clone(), ); @@ -197,7 +169,6 @@ impl CpuChip { local.op_a_val(), local.shard, local.channel, - branch_cols.a_gt_b_nonce, is_branch_instruction.clone(), ); } diff --git a/core/src/cpu/air/ecall.rs b/core/src/cpu/air/ecall.rs index 870513b83e..506b2c7b75 100644 --- a/core/src/cpu/air/ecall.rs +++ b/core/src/cpu/air/ecall.rs @@ -35,19 +35,14 @@ impl CpuChip { let syscall_id = syscall_code[0]; let send_to_table = syscall_code[1]; - // Handle cases: - // - is_ecall_instruction = 1 => ecall_mul_send_to_table == send_to_table - // - is_ecall_instruction = 0 => ecall_mul_send_to_table == 0 - builder.assert_eq( - local.ecall_mul_send_to_table, - send_to_table * is_ecall_instruction.clone(), - ); - + // When is_ecall_instruction == true AND sent_to_table == true, ecall_mul_send_to_table should be true. + builder + .when(is_ecall_instruction.clone()) + .assert_eq(send_to_table, local.ecall_mul_send_to_table); builder.send_syscall( local.shard, local.channel, local.clk, - ecall_cols.syscall_nonce, syscall_id, local.op_b_val().reduce::(), local.op_c_val().reduce::(), diff --git a/core/src/cpu/air/memory.rs b/core/src/cpu/air/memory.rs index 707a50ff95..6ac1a07c11 100644 --- a/core/src/cpu/air/memory.rs +++ b/core/src/cpu/air/memory.rs @@ -5,7 +5,6 @@ use crate::air::{BaseAirBuilder, SP1AirBuilder, Word, WordAirBuilder}; use crate::cpu::columns::{CpuCols, MemoryColumns, OpcodeSelectorCols}; use crate::cpu::CpuChip; use crate::memory::MemoryCols; -use crate::operations::BabyBearWordRangeChecker; use crate::runtime::{MemoryAccessPosition, Opcode}; impl CpuChip { @@ -67,15 +66,6 @@ impl CpuChip { local.op_c_val(), local.shard, local.channel, - memory_columns.addr_word_nonce, - is_memory_instruction.clone(), - ); - - // Range check the addr_word to be a valid babybear word. - BabyBearWordRangeChecker::::range_check( - builder, - memory_columns.addr_word, - memory_columns.addr_word_range_checker, is_memory_instruction.clone(), ); @@ -98,35 +88,6 @@ impl CpuChip { memory_columns.addr_word.reduce::(), ); - // Verify that the least significant byte of addr_word - addr_offset is divisible by 4. - let offset = [ - memory_columns.offset_is_one, - memory_columns.offset_is_two, - memory_columns.offset_is_three, - ] - .iter() - .enumerate() - .fold(AB::Expr::zero(), |acc, (index, &value)| { - acc + AB::Expr::from_canonical_usize(index + 1) * value - }); - let mut recomposed_byte = AB::Expr::zero(); - memory_columns - .aa_least_sig_byte_decomp - .iter() - .enumerate() - .for_each(|(i, value)| { - builder - .when(is_memory_instruction.clone()) - .assert_bool(*value); - - recomposed_byte = - recomposed_byte.clone() + AB::Expr::from_canonical_usize(1 << (i + 2)) * *value; - }); - - builder - .when(is_memory_instruction.clone()) - .assert_eq(memory_columns.addr_word[0] - offset, recomposed_byte); - // For operations that require reading from memory (not registers), we need to read the // value into the memory columns. builder.eval_memory_access( @@ -137,14 +98,6 @@ impl CpuChip { &memory_columns.memory_access, is_memory_instruction.clone(), ); - - // On memory load instructions, make sure that the memory value is not changed. - builder - .when(self.is_load_instruction::(&local.selectors)) - .assert_word_eq( - *memory_columns.memory_access.value(), - *memory_columns.memory_access.prev_value(), - ); } /// Evaluates constraints related to loading from memory. @@ -168,11 +121,12 @@ impl CpuChip { // Assert that if `is_lb` and `is_lh` are both true, then the most significant byte // matches the value of `local.mem_value_is_neg`. - builder.assert_eq( - local.mem_value_is_neg, - (local.selectors.is_lb + local.selectors.is_lh) - * memory_columns.most_sig_byte_decomp[7], - ); + builder + .when(local.selectors.is_lb + local.selectors.is_lh) + .assert_eq( + local.mem_value_is_neg, + memory_columns.most_sig_byte_decomp[7], + ); // When the memory value is negative, use the SUB opcode to compute the signed value of // the memory value and verify that the op_a value is correct. @@ -189,7 +143,6 @@ impl CpuChip { signed_value, local.shard, local.channel, - local.unsigned_mem_val_nonce, local.mem_value_is_neg, ); @@ -242,11 +195,6 @@ impl CpuChip { .when(local.selectors.is_sh) .assert_zero(memory_columns.offset_is_one + memory_columns.offset_is_three); - // When the instruction is SW, ensure that the offset is 0. - builder - .when(local.selectors.is_sw) - .assert_one(offset_is_zero.clone()); - // Compute the expected stored value for a SH instruction. let a_is_lower_half = offset_is_zero; let a_is_upper_half = memory_columns.offset_is_two; @@ -299,12 +247,6 @@ impl CpuChip { builder .when(local.selectors.is_lh + local.selectors.is_lhu) .assert_zero(memory_columns.offset_is_one + memory_columns.offset_is_three); - - // When the instruction is LW, ensure that the offset is zero. - builder - .when(local.selectors.is_lw) - .assert_one(offset_is_zero.clone()); - let use_lower_half = offset_is_zero; let use_upper_half = memory_columns.offset_is_two; let half_value = Word([ @@ -331,12 +273,9 @@ impl CpuChip { local: &CpuCols, unsigned_mem_val: &Word, ) { - let is_mem = self.is_memory_instruction::(&local.selectors); let mut recomposed_byte = AB::Expr::zero(); for i in 0..8 { - builder - .when(is_mem.clone()) - .assert_bool(memory_columns.most_sig_byte_decomp[i]); + builder.assert_bool(memory_columns.most_sig_byte_decomp[i]); recomposed_byte += memory_columns.most_sig_byte_decomp[i] * AB::Expr::from_canonical_u8(1 << i); } diff --git a/core/src/cpu/air/mod.rs b/core/src/cpu/air/mod.rs index 4caebbbf73..11a985bb5e 100644 --- a/core/src/cpu/air/mod.rs +++ b/core/src/cpu/air/mod.rs @@ -22,16 +22,13 @@ use crate::bytes::ByteOpcode; use crate::cpu::columns::OpcodeSelectorCols; use crate::cpu::columns::{CpuCols, NUM_CPU_COLS}; use crate::cpu::CpuChip; -use crate::operations::BabyBearWordRangeChecker; use crate::runtime::Opcode; use super::columns::eval_channel_selectors; -use super::columns::OPCODE_SELECTORS_COL_MAP; impl Air for CpuChip where AB: SP1AirBuilder + AirBuilderWithPublicValues, - AB::Var: Sized, { #[inline(never)] fn eval(&self, builder: &mut AB) { @@ -87,7 +84,6 @@ where local.op_c_val(), local.shard, local.channel, - local.nonce, is_alu_instruction, ); @@ -125,27 +121,6 @@ where // Check that the is_real flag is correct. self.eval_is_real(builder, local, next); - - // Check that when `is_real=0` that all flags that send interactions are zero. - local - .selectors - .into_iter() - .enumerate() - .for_each(|(i, selector)| { - if i == OPCODE_SELECTORS_COL_MAP.imm_b { - builder - .when(AB::Expr::one() - local.is_real) - .assert_one(local.selectors.imm_b); - } else if i == OPCODE_SELECTORS_COL_MAP.imm_c { - builder - .when(AB::Expr::one() - local.is_real) - .assert_one(local.selectors.imm_c); - } else { - builder - .when(AB::Expr::one() - local.is_real) - .assert_zero(selector); - } - }); } } @@ -200,26 +175,6 @@ impl CpuChip { .when(is_jump_instruction.clone()) .assert_eq(jump_columns.next_pc.reduce::(), local.next_pc); - // Range check op_a, pc, and next_pc. - BabyBearWordRangeChecker::::range_check( - builder, - local.op_a_val(), - jump_columns.op_a_range_checker, - is_jump_instruction.clone(), - ); - BabyBearWordRangeChecker::::range_check( - builder, - jump_columns.pc, - jump_columns.pc_range_checker, - local.selectors.is_jal.into(), - ); - BabyBearWordRangeChecker::::range_check( - builder, - jump_columns.next_pc, - jump_columns.next_pc_range_checker, - is_jump_instruction.clone(), - ); - // Verify that the new pc is calculated correctly for JAL instructions. builder.send_alu( AB::Expr::from_canonical_u32(Opcode::ADD as u32), @@ -228,7 +183,6 @@ impl CpuChip { local.op_b_val(), local.shard, local.channel, - jump_columns.jal_nonce, local.selectors.is_jal, ); @@ -240,7 +194,6 @@ impl CpuChip { local.op_c_val(), local.shard, local.channel, - jump_columns.jalr_nonce, local.selectors.is_jalr, ); } @@ -255,14 +208,6 @@ impl CpuChip { .when(local.selectors.is_auipc) .assert_eq(auipc_columns.pc.reduce::(), local.pc); - // Range check the pc. - BabyBearWordRangeChecker::::range_check( - builder, - auipc_columns.pc, - auipc_columns.pc_range_checker, - local.selectors.is_auipc.into(), - ); - // Verify that op_a == pc + op_b. builder.send_alu( AB::Expr::from_canonical_u32(Opcode::ADD as u32), @@ -271,7 +216,6 @@ impl CpuChip { local.op_b_val(), local.shard, local.channel, - auipc_columns.auipc_nonce, local.selectors.is_auipc, ); } @@ -344,16 +288,17 @@ impl CpuChip { next: &CpuCols, is_branch_instruction: AB::Expr, ) { + // Verify that if is_sequential_instr is true, assert that local.is_real is true. + // This is needed for the following constraint, which is already degree 3. + builder + .when(local.is_sequential_instr) + .assert_one(local.is_real); + // When is_sequential_instr is true, assert that instruction is not branch, jump, or halt. // Note that the condition `when(local_is_real)` is implied from the previous constraint. let is_halt = self.get_is_halt_syscall::(builder, local); - builder.when(local.is_real).assert_eq( - local.is_sequential_instr, - AB::Expr::one() - - (is_branch_instruction - + local.selectors.is_jal - + local.selectors.is_jalr - + is_halt), + builder.when(local.is_sequential_instr).assert_zero( + is_branch_instruction + local.selectors.is_jal + local.selectors.is_jalr + is_halt, ); // Verify that the pc increments by 4 for all instructions except branch, jump and halt instructions. diff --git a/core/src/cpu/air/register.rs b/core/src/cpu/air/register.rs index 23b6551d16..e0b989c2bc 100644 --- a/core/src/cpu/air/register.rs +++ b/core/src/cpu/air/register.rs @@ -57,15 +57,6 @@ impl CpuChip { local.is_real, ); - // Always range check the word value in `op_a`, as JUMP instructions may witness - // an invalid word and write it to memory. - builder.slice_range_check_u8( - &local.op_a_access.access.value.0, - local.shard, - local.channel, - local.is_real, - ); - // If we are performing a branch or a store, then the value of `a` is the previous value. builder .when(is_branch_instruction.clone() + self.is_store_instruction::(&local.selectors)) diff --git a/core/src/cpu/columns/auipc.rs b/core/src/cpu/columns/auipc.rs index fa6871c211..a6eb410e7c 100644 --- a/core/src/cpu/columns/auipc.rs +++ b/core/src/cpu/columns/auipc.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::{air::Word, operations::BabyBearWordRangeChecker}; +use crate::air::Word; pub const NUM_AUIPC_COLS: usize = size_of::>(); @@ -10,6 +10,4 @@ pub const NUM_AUIPC_COLS: usize = size_of::>(); pub struct AuipcCols { /// The current program counter. pub pc: Word, - pub pc_range_checker: BabyBearWordRangeChecker, - pub auipc_nonce: T, } diff --git a/core/src/cpu/columns/branch.rs b/core/src/cpu/columns/branch.rs index c6298ef0f7..06a77ad306 100644 --- a/core/src/cpu/columns/branch.rs +++ b/core/src/cpu/columns/branch.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::{air::Word, operations::BabyBearWordRangeChecker}; +use crate::air::Word; pub const NUM_BRANCH_COLS: usize = size_of::>(); @@ -11,11 +11,9 @@ pub const NUM_BRANCH_COLS: usize = size_of::>(); pub struct BranchCols { /// The current program counter. pub pc: Word, - pub pc_range_checker: BabyBearWordRangeChecker, /// The next program counter. pub next_pc: Word, - pub next_pc_range_checker: BabyBearWordRangeChecker, /// Whether a equals b. pub a_eq_b: T, @@ -25,13 +23,4 @@ pub struct BranchCols { /// Whether a is less than b. pub a_lt_b: T, - - /// The nonce of the operation to compute `a_lt_b`. - pub a_lt_b_nonce: T, - - /// The nonce of the operation to compute `a_gt_b`. - pub a_gt_b_nonce: T, - - /// The nonce of the operation to compute `next_pc`. - pub next_pc_nonce: T, } diff --git a/core/src/cpu/columns/ecall.rs b/core/src/cpu/columns/ecall.rs index 5d91622c36..927b70614c 100644 --- a/core/src/cpu/columns/ecall.rs +++ b/core/src/cpu/columns/ecall.rs @@ -26,7 +26,4 @@ pub struct EcallCols { /// Field to store the word index passed into the COMMIT ecall. index_bitmap[word index] should /// be set to 1 and everything else set to 0. pub index_bitmap: [T; PV_DIGEST_NUM_WORDS], - - /// The nonce of the syscall operation. - pub syscall_nonce: T, } diff --git a/core/src/cpu/columns/jump.rs b/core/src/cpu/columns/jump.rs index 0e1b5701f5..ca94f3ecac 100644 --- a/core/src/cpu/columns/jump.rs +++ b/core/src/cpu/columns/jump.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::{air::Word, operations::BabyBearWordRangeChecker}; +use crate::air::Word; pub const NUM_JUMP_COLS: usize = size_of::>(); @@ -10,15 +10,7 @@ pub const NUM_JUMP_COLS: usize = size_of::>(); pub struct JumpCols { /// The current program counter. pub pc: Word, - pub pc_range_checker: BabyBearWordRangeChecker, - /// The next program counter. + /// THe next program counter. pub next_pc: Word, - pub next_pc_range_checker: BabyBearWordRangeChecker, - - // A range checker for `op_a` which may contain `pc + 4`. - pub op_a_range_checker: BabyBearWordRangeChecker, - - pub jal_nonce: T, - pub jalr_nonce: T, } diff --git a/core/src/cpu/columns/memory.rs b/core/src/cpu/columns/memory.rs index baab9e1fc0..fc54de34c4 100644 --- a/core/src/cpu/columns/memory.rs +++ b/core/src/cpu/columns/memory.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::{air::Word, memory::MemoryReadWriteCols, operations::BabyBearWordRangeChecker}; +use crate::{air::Word, memory::MemoryReadWriteCols}; pub const NUM_MEMORY_COLUMNS: usize = size_of::>(); @@ -17,11 +17,7 @@ pub struct MemoryColumns { // addr_offset = addr_word % 4 // Note that this all needs to be verified in the AIR pub addr_word: Word, - pub addr_word_range_checker: BabyBearWordRangeChecker, - pub addr_aligned: T, - /// The LE bit decomp of the least significant byte of address aligned. - pub aa_least_sig_byte_decomp: [T; 6], pub addr_offset: T, pub memory_access: MemoryReadWriteCols, @@ -32,7 +28,4 @@ pub struct MemoryColumns { // LE bit decomposition for the most significant byte of memory value. This is used to determine // the sign for that value (used for LB and LH). pub most_sig_byte_decomp: [T; 8], - - pub addr_word_nonce: T, - pub unsigned_mem_val_nonce: T, } diff --git a/core/src/cpu/columns/mod.rs b/core/src/cpu/columns/mod.rs index 968c58362f..d81bd806fc 100644 --- a/core/src/cpu/columns/mod.rs +++ b/core/src/cpu/columns/mod.rs @@ -40,8 +40,6 @@ pub struct CpuCols { /// The channel value, used for byte lookup multiplicity. pub channel: T, - pub nonce: T, - /// The clock cycle value. This should be within 24 bits. pub clk: T, /// The least significant 16 bit limb of clk. @@ -99,8 +97,6 @@ pub struct CpuCols { /// memory opcodes (i.e. LB, LH, LW, LBU, and LHU). pub unsigned_mem_val: Word, - pub unsigned_mem_val_nonce: T, - /// The result of selectors.is_ecall * the send_to_table column for the ECALL opcode. pub ecall_mul_send_to_table: T, diff --git a/core/src/cpu/columns/opcode.rs b/core/src/cpu/columns/opcode.rs index 80fd63ad3d..ac67c6934e 100644 --- a/core/src/cpu/columns/opcode.rs +++ b/core/src/cpu/columns/opcode.rs @@ -1,23 +1,11 @@ use p3_field::PrimeField; use sp1_derive::AlignedBorrow; -use std::mem::{size_of, transmute}; +use std::mem::size_of; use std::vec::IntoIter; -use crate::{ - runtime::{Instruction, Opcode}, - utils::indices_arr, -}; +use crate::runtime::{Instruction, Opcode}; pub const NUM_OPCODE_SELECTOR_COLS: usize = size_of::>(); -pub const OPCODE_SELECTORS_COL_MAP: OpcodeSelectorCols = make_selectors_col_map(); - -/// Creates the column map for the CPU. -const fn make_selectors_col_map() -> OpcodeSelectorCols { - let indices_arr = indices_arr::(); - unsafe { - transmute::<[usize; NUM_OPCODE_SELECTOR_COLS], OpcodeSelectorCols>(indices_arr) - } -} /// The column layout for opcode selectors. #[derive(AlignedBorrow, Clone, Copy, Default, Debug)] @@ -110,7 +98,7 @@ impl IntoIterator for OpcodeSelectorCols { type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { - let columns = vec![ + vec![ self.imm_b, self.imm_c, self.is_alu, @@ -133,8 +121,7 @@ impl IntoIterator for OpcodeSelectorCols { self.is_jal, self.is_auipc, self.is_unimpl, - ]; - assert_eq!(columns.len(), NUM_OPCODE_SELECTOR_COLS); - columns.into_iter() + ] + .into_iter() } } diff --git a/core/src/cpu/event.rs b/core/src/cpu/event.rs index cdd38f4765..2170d91d5d 100644 --- a/core/src/cpu/event.rs +++ b/core/src/cpu/event.rs @@ -51,15 +51,4 @@ pub struct CpuEvent { /// Exit code called with halt. pub exit_code: u32, - - pub alu_lookup_id: usize, - pub syscall_lookup_id: usize, - pub memory_add_lookup_id: usize, - pub memory_sub_lookup_id: usize, - pub branch_gt_lookup_id: usize, - pub branch_lt_lookup_id: usize, - pub branch_add_lookup_id: usize, - pub jump_jal_lookup_id: usize, - pub jump_jalr_lookup_id: usize, - pub auipc_lookup_id: usize, } diff --git a/core/src/cpu/trace.rs b/core/src/cpu/trace.rs index 893faa385e..b65c4e43ca 100644 --- a/core/src/cpu/trace.rs +++ b/core/src/cpu/trace.rs @@ -1,4 +1,3 @@ -use std::array; use std::borrow::BorrowMut; use std::collections::HashMap; @@ -12,8 +11,6 @@ use tracing::instrument; use super::columns::{CPU_COL_MAP, NUM_CPU_COLS}; use super::{CpuChip, CpuEvent}; use crate::air::MachineAir; -use crate::air::Word; -use crate::alu::create_alu_lookups; use crate::alu::{self, AluEvent}; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; @@ -45,7 +42,7 @@ impl MachineAir for CpuChip { let mut rows_with_events = input .cpu_events .par_iter() - .map(|op: &CpuEvent| self.event_to_row::(*op, &input.nonce_lookup)) + .map(|op: &CpuEvent| self.event_to_row::(*op)) .collect::>(); // No need to sort by the shard, since the cpu events are already partitioned by that. @@ -94,7 +91,7 @@ impl MachineAir for CpuChip { let mut alu = HashMap::new(); let mut blu: Vec<_> = Vec::default(); ops.iter().for_each(|op| { - let (_, alu_events, blu_events) = self.event_to_row::(*op, &HashMap::new()); + let (_, alu_events, blu_events) = self.event_to_row::(*op); alu_events.into_iter().for_each(|(key, value)| { alu.entry(key).or_insert(Vec::default()).extend(value); }); @@ -127,7 +124,6 @@ impl CpuChip { fn event_to_row( &self, event: CpuEvent, - nonce_lookup: &HashMap, ) -> ( [F; NUM_CPU_COLS], HashMap>, @@ -142,14 +138,6 @@ impl CpuChip { // Populate shard and clk columns. self.populate_shard_clk(cols, event, &mut new_blu_events); - // Populate the nonce. - cols.nonce = F::from_canonical_u32( - nonce_lookup - .get(&event.alu_lookup_id) - .copied() - .unwrap_or_default(), - ); - // Populate basic fields. cols.pc = F::from_canonical_u32(event.pc); cols.next_pc = F::from_canonical_u32(event.next_pc); @@ -162,45 +150,17 @@ impl CpuChip { // Populate memory accesses for a, b, and c. if let Some(record) = event.a_record { cols.op_a_access - .populate(event.channel, record, &mut new_blu_events); + .populate(event.channel, record, &mut new_blu_events) } if let Some(MemoryRecordEnum::Read(record)) = event.b_record { cols.op_b_access - .populate(event.channel, record, &mut new_blu_events); + .populate(event.channel, record, &mut new_blu_events) } if let Some(MemoryRecordEnum::Read(record)) = event.c_record { cols.op_c_access - .populate(event.channel, record, &mut new_blu_events); + .populate(event.channel, record, &mut new_blu_events) } - // Populate range checks for a. - let a_bytes = cols - .op_a_access - .access - .value - .0 - .iter() - .map(|x| x.as_canonical_u32()) - .collect::>(); - new_blu_events.push(ByteLookupEvent { - shard: event.shard, - channel: event.channel, - opcode: ByteOpcode::U8Range, - a1: 0, - a2: 0, - b: a_bytes[0], - c: a_bytes[1], - }); - new_blu_events.push(ByteLookupEvent { - shard: event.shard, - channel: event.channel, - opcode: ByteOpcode::U8Range, - a1: 0, - a2: 0, - b: a_bytes[2], - c: a_bytes[3], - }); - // Populate memory accesses for reading from memory. assert_eq!(event.memory_record.is_some(), event.memory.is_some()); let memory_columns = cols.opcode_specific_columns.memory_mut(); @@ -211,23 +171,19 @@ impl CpuChip { } // Populate memory, branch, jump, and auipc specific fields. - self.populate_memory( - cols, - event, - &mut new_alu_events, - &mut new_blu_events, - nonce_lookup, - ); - self.populate_branch(cols, event, &mut new_alu_events, nonce_lookup); - self.populate_jump(cols, event, &mut new_alu_events, nonce_lookup); - self.populate_auipc(cols, event, &mut new_alu_events, nonce_lookup); - let is_halt = self.populate_ecall(cols, event, nonce_lookup); - - cols.is_sequential_instr = F::from_bool( - !event.instruction.is_branch_instruction() - && !event.instruction.is_jump_instruction() - && !is_halt, - ); + self.populate_memory(cols, event, &mut new_alu_events, &mut new_blu_events); + self.populate_branch(cols, event, &mut new_alu_events); + self.populate_jump(cols, event, &mut new_alu_events); + self.populate_auipc(cols, event, &mut new_alu_events); + let is_halt = self.populate_ecall(cols, event); + + if !event.instruction.is_branch_instruction() + && !event.instruction.is_jump_instruction() + && !event.instruction.is_ecall_instruction() + && !is_halt + { + cols.is_sequential_instr = F::one(); + } // Assert that the instruction is not a no-op. cols.is_real = F::one(); @@ -287,7 +243,6 @@ impl CpuChip { event: CpuEvent, new_alu_events: &mut HashMap>, new_blu_events: &mut Vec, - nonce_lookup: &HashMap, ) { if !matches!( event.instruction.opcode, @@ -306,20 +261,12 @@ impl CpuChip { // Populate addr_word and addr_aligned columns. let memory_columns = cols.opcode_specific_columns.memory_mut(); let memory_addr = event.b.wrapping_add(event.c); - let aligned_addr = memory_addr - memory_addr % WORD_SIZE as u32; memory_columns.addr_word = memory_addr.into(); - memory_columns.addr_word_range_checker.populate(memory_addr); - memory_columns.addr_aligned = F::from_canonical_u32(aligned_addr); - - // Populate the aa_least_sig_byte_decomp columns. - assert!(aligned_addr % 4 == 0); - let aligned_addr_ls_byte = (aligned_addr & 0x000000FF) as u8; - let bits: [bool; 8] = array::from_fn(|i| aligned_addr_ls_byte & (1 << i) != 0); - memory_columns.aa_least_sig_byte_decomp = array::from_fn(|i| F::from_bool(bits[i + 2])); + memory_columns.addr_aligned = + F::from_canonical_u32(memory_addr - memory_addr % WORD_SIZE as u32); // Add event to ALU check to check that addr == b + c let add_event = AluEvent { - lookup_id: event.memory_add_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -327,18 +274,11 @@ impl CpuChip { a: memory_addr, b: event.b, c: event.c, - sub_lookups: create_alu_lookups(), }; new_alu_events .entry(Opcode::ADD) .and_modify(|op_new_events| op_new_events.push(add_event)) .or_insert(vec![add_event]); - memory_columns.addr_word_nonce = F::from_canonical_u32( - nonce_lookup - .get(&event.memory_add_lookup_id) - .copied() - .unwrap_or_default(), - ); // Populate memory offsets. let addr_offset = (memory_addr % WORD_SIZE as u32) as u8; @@ -392,7 +332,6 @@ impl CpuChip { if memory_columns.most_sig_byte_decomp[7] == F::one() { cols.mem_value_is_neg = F::one(); let sub_event = AluEvent { - lookup_id: event.memory_sub_lookup_id, channel: event.channel, shard: event.shard, clk: event.clk, @@ -400,14 +339,7 @@ impl CpuChip { a: event.a, b: cols.unsigned_mem_val.to_u32(), c: sign_value, - sub_lookups: create_alu_lookups(), }; - cols.unsigned_mem_val_nonce = F::from_canonical_u32( - nonce_lookup - .get(&event.memory_sub_lookup_id) - .copied() - .unwrap_or_default(), - ); new_alu_events .entry(Opcode::SUB) @@ -438,7 +370,6 @@ impl CpuChip { cols: &mut CpuCols, event: CpuEvent, alu_events: &mut HashMap>, - nonce_lookup: &HashMap, ) { if event.instruction.is_branch_instruction() { let branch_columns = cols.opcode_specific_columns.branch_mut(); @@ -464,10 +395,8 @@ impl CpuChip { } else { Opcode::SLTU }; - // Add the ALU events for the comparisons let lt_comp_event = AluEvent { - lookup_id: event.branch_lt_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -475,14 +404,7 @@ impl CpuChip { a: a_lt_b as u32, b: event.a, c: event.b, - sub_lookups: create_alu_lookups(), }; - branch_columns.a_lt_b_nonce = F::from_canonical_u32( - nonce_lookup - .get(&event.branch_lt_lookup_id) - .copied() - .unwrap_or_default(), - ); alu_events .entry(alu_op_code) @@ -490,7 +412,6 @@ impl CpuChip { .or_insert(vec![lt_comp_event]); let gt_comp_event = AluEvent { - lookup_id: event.branch_gt_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -498,14 +419,7 @@ impl CpuChip { a: a_gt_b as u32, b: event.b, c: event.a, - sub_lookups: create_alu_lookups(), }; - branch_columns.a_gt_b_nonce = F::from_canonical_u32( - nonce_lookup - .get(&event.branch_gt_lookup_id) - .copied() - .unwrap_or_default(), - ); alu_events .entry(alu_op_code) @@ -524,17 +438,14 @@ impl CpuChip { _ => unreachable!(), }; - let next_pc = event.pc.wrapping_add(event.c); - branch_columns.pc = Word::from(event.pc); - branch_columns.next_pc = Word::from(next_pc); - branch_columns.pc_range_checker.populate(event.pc); - branch_columns.next_pc_range_checker.populate(next_pc); - if branching { + let next_pc = event.pc.wrapping_add(event.c); + cols.branching = F::one(); + branch_columns.pc = event.pc.into(); + branch_columns.next_pc = next_pc.into(); let add_event = AluEvent { - lookup_id: event.branch_add_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -542,14 +453,7 @@ impl CpuChip { a: next_pc, b: event.pc, c: event.c, - sub_lookups: create_alu_lookups(), }; - branch_columns.next_pc_nonce = F::from_canonical_u32( - nonce_lookup - .get(&event.branch_add_lookup_id) - .copied() - .unwrap_or_default(), - ); alu_events .entry(Opcode::ADD) @@ -567,7 +471,6 @@ impl CpuChip { cols: &mut CpuCols, event: CpuEvent, alu_events: &mut HashMap>, - nonce_lookup: &HashMap, ) { if event.instruction.is_jump_instruction() { let jump_columns = cols.opcode_specific_columns.jump_mut(); @@ -575,14 +478,10 @@ impl CpuChip { match event.instruction.opcode { Opcode::JAL => { let next_pc = event.pc.wrapping_add(event.b); - jump_columns.op_a_range_checker.populate(event.a); - jump_columns.pc = Word::from(event.pc); - jump_columns.pc_range_checker.populate(event.pc); - jump_columns.next_pc = Word::from(next_pc); - jump_columns.next_pc_range_checker.populate(next_pc); + jump_columns.pc = event.pc.into(); + jump_columns.next_pc = next_pc.into(); let add_event = AluEvent { - lookup_id: event.jump_jal_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -590,14 +489,7 @@ impl CpuChip { a: next_pc, b: event.pc, c: event.b, - sub_lookups: create_alu_lookups(), }; - jump_columns.jal_nonce = F::from_canonical_u32( - nonce_lookup - .get(&event.jump_jal_lookup_id) - .copied() - .unwrap_or_default(), - ); alu_events .entry(Opcode::ADD) @@ -606,12 +498,9 @@ impl CpuChip { } Opcode::JALR => { let next_pc = event.b.wrapping_add(event.c); - jump_columns.op_a_range_checker.populate(event.a); - jump_columns.next_pc = Word::from(next_pc); - jump_columns.next_pc_range_checker.populate(next_pc); + jump_columns.next_pc = next_pc.into(); let add_event = AluEvent { - lookup_id: event.jump_jalr_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -619,14 +508,7 @@ impl CpuChip { a: next_pc, b: event.b, c: event.c, - sub_lookups: create_alu_lookups(), }; - jump_columns.jalr_nonce = F::from_canonical_u32( - nonce_lookup - .get(&event.jump_jalr_lookup_id) - .copied() - .unwrap_or_default(), - ); alu_events .entry(Opcode::ADD) @@ -644,16 +526,13 @@ impl CpuChip { cols: &mut CpuCols, event: CpuEvent, alu_events: &mut HashMap>, - nonce_lookup: &HashMap, ) { if matches!(event.instruction.opcode, Opcode::AUIPC) { let auipc_columns = cols.opcode_specific_columns.auipc_mut(); - auipc_columns.pc = Word::from(event.pc); - auipc_columns.pc_range_checker.populate(event.pc); + auipc_columns.pc = event.pc.into(); let add_event = AluEvent { - lookup_id: event.auipc_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -661,14 +540,7 @@ impl CpuChip { a: event.a, b: event.pc, c: event.b, - sub_lookups: create_alu_lookups(), }; - auipc_columns.auipc_nonce = F::from_canonical_u32( - nonce_lookup - .get(&event.auipc_lookup_id) - .copied() - .unwrap_or_default(), - ); alu_events .entry(Opcode::ADD) @@ -678,12 +550,7 @@ impl CpuChip { } /// Populate columns related to ECALL. - fn populate_ecall( - &self, - cols: &mut CpuCols, - event: CpuEvent, - nonce_lookup: &HashMap, - ) -> bool { + fn populate_ecall(&self, cols: &mut CpuCols, _: CpuEvent) -> bool { let mut is_halt = false; if cols.selectors.is_ecall == F::one() { @@ -737,14 +604,6 @@ impl CpuChip { ecall_cols.index_bitmap[digest_idx] = F::one(); } - // Write the syscall nonce. - ecall_cols.syscall_nonce = F::from_canonical_u32( - nonce_lookup - .get(&event.syscall_lookup_id) - .copied() - .unwrap_or_default(), - ); - is_halt = syscall_id == F::from_canonical_u32(SyscallCode::HALT.syscall_id()); } @@ -781,41 +640,41 @@ mod tests { use super::*; - use crate::runtime::{tests::simple_program, Runtime}; + use crate::runtime::{tests::simple_program, Instruction, Runtime}; use crate::utils::{run_test, setup_logger, SP1CoreOpts}; - // #[test] - // fn generate_trace() { - // let mut shard = ExecutionRecord::default(); - // shard.cpu_events = vec![CpuEvent { - // shard: 1, - // channel: 0, - // clk: 6, - // pc: 1, - // next_pc: 5, - // instruction: Instruction { - // opcode: Opcode::ADD, - // op_a: 0, - // op_b: 1, - // op_c: 2, - // imm_b: false, - // imm_c: false, - // }, - // a: 1, - // a_record: None, - // b: 2, - // b_record: None, - // c: 3, - // c_record: None, - // memory: None, - // memory_record: None, - // exit_code: 0, - // }]; - // let chip = CpuChip::default(); - // let trace: RowMajorMatrix = - // chip.generate_trace(&shard, &mut ExecutionRecord::default()); - // println!("{:?}", trace.values); - // } + #[test] + fn generate_trace() { + let mut shard = ExecutionRecord::default(); + shard.cpu_events = vec![CpuEvent { + shard: 1, + channel: 0, + clk: 6, + pc: 1, + next_pc: 5, + instruction: Instruction { + opcode: Opcode::ADD, + op_a: 0, + op_b: 1, + op_c: 2, + imm_b: false, + imm_c: false, + }, + a: 1, + a_record: None, + b: 2, + b_record: None, + c: 3, + c_record: None, + memory: None, + memory_record: None, + exit_code: 0, + }]; + let chip = CpuChip::default(); + let trace: RowMajorMatrix = + chip.generate_trace(&shard, &mut ExecutionRecord::default()); + println!("{:?}", trace.values); + } #[test] fn generate_trace_simple_program() { diff --git a/core/src/lookup/interaction.rs b/core/src/lookup/interaction.rs index 1c20938cc4..74b7a9fc06 100644 --- a/core/src/lookup/interaction.rs +++ b/core/src/lookup/interaction.rs @@ -74,6 +74,7 @@ impl Interaction { } } +// TODO: add debug for VirtualPairCol so that we can derive Debug for Interaction. impl Debug for Interaction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Interaction") diff --git a/core/src/memory/global.rs b/core/src/memory/global.rs index a6cf49d02a..3786dd4ca1 100644 --- a/core/src/memory/global.rs +++ b/core/src/memory/global.rs @@ -1,6 +1,5 @@ use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; -use std::array; use p3_air::BaseAir; use p3_air::{Air, AirBuilder}; @@ -11,9 +10,8 @@ use p3_matrix::Matrix; use sp1_derive::AlignedBorrow; use super::MemoryInitializeFinalizeEvent; -use crate::air::MachineAir; -use crate::air::{AirInteraction, BaseAirBuilder, SP1AirBuilder}; -use crate::operations::BabyBearBitDecomposition; +use crate::air::{AirInteraction, SP1AirBuilder, Word}; +use crate::air::{MachineAir, WordAirBuilder}; use crate::runtime::{ExecutionRecord, Program}; use crate::utils::pad_to_power_of_two; @@ -64,7 +62,7 @@ impl MachineAir for MemoryChip { MemoryChipType::Finalize => input.memory_finalize_events.clone(), }; memory_events.sort_by_key(|event| event.addr); - let rows: Vec<[F; NUM_MEMORY_INIT_COLS]> = (0..memory_events.len()) // OPT: change this to par_iter + let rows: Vec<[F; 8]> = (0..memory_events.len()) // TODO: change this back to par_iter .map(|i| { let MemoryInitializeFinalizeEvent { addr, @@ -73,37 +71,14 @@ impl MachineAir for MemoryChip { timestamp, used, } = memory_events[i]; - let mut row = [F::zero(); NUM_MEMORY_INIT_COLS]; let cols: &mut MemoryInitCols = row.as_mut_slice().borrow_mut(); cols.addr = F::from_canonical_u32(addr); - cols.addr_bits.populate(addr); cols.shard = F::from_canonical_u32(shard); cols.timestamp = F::from_canonical_u32(timestamp); - cols.value = array::from_fn(|i| F::from_canonical_u32((value >> i) & 1)); + cols.value = value.into(); cols.is_real = F::from_canonical_u32(used); - if i != memory_events.len() - 1 { - let next_addr = memory_events[i + 1].addr; - assert_ne!(next_addr, addr); - - cols.addr_bits.populate(addr); - - cols.seen_diff_bits[0] = F::zero(); - for j in 0..32 { - let rev_j = 32 - j - 1; - let next_bit = ((next_addr >> rev_j) & 1) == 1; - let local_bit = ((addr >> rev_j) & 1) == 1; - cols.match_bits[j] = - F::from_bool((local_bit && next_bit) || (!local_bit && !next_bit)); - cols.seen_diff_bits[j + 1] = cols.seen_diff_bits[j] - + (F::one() - cols.seen_diff_bits[j]) * (F::one() - cols.match_bits[j]); - cols.not_match_and_not_seen_diff_bits[j] = - (F::one() - cols.match_bits[j]) * (F::one() - cols.seen_diff_bits[j]); - } - assert_eq!(cols.seen_diff_bits[cols.seen_diff_bits.len() - 1], F::one()); - } - row }) .collect::>(); @@ -126,7 +101,7 @@ impl MachineAir for MemoryChip { } } -#[derive(AlignedBorrow, Debug, Clone, Copy)] +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] pub struct MemoryInitCols { /// The shard number of the memory access. @@ -138,20 +113,8 @@ pub struct MemoryInitCols { /// The address of the memory access. pub addr: T, - /// A bit decomposition of `addr`. - pub addr_bits: BabyBearBitDecomposition, - - // Whether the i'th bit matches the next addr's bit. - pub match_bits: [T; 32], - - // Whether we've seen a different bit in the comparison. - pub seen_diff_bits: [T; 33], - - // Whether the i'th bit doesn't match the next addr's bit and we haven't seen a diff bitn yet. - pub not_match_and_not_seen_diff_bits: [T; 32], - /// The value of the memory access. - pub value: [T; 32], + pub value: Word, /// Whether the memory access is a real access. pub is_real: T, @@ -167,29 +130,10 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &MemoryInitCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &MemoryInitCols = (*next).borrow(); - - builder.assert_bool(local.is_real); - for i in 0..32 { - builder.assert_bool(local.value[i]); - } - - let mut byte1 = AB::Expr::zero(); - let mut byte2 = AB::Expr::zero(); - let mut byte3 = AB::Expr::zero(); - let mut byte4 = AB::Expr::zero(); - for i in 0..8 { - byte1 += local.value[i].into() * AB::F::from_canonical_u8(1 << i); - byte2 += local.value[i + 8].into() * AB::F::from_canonical_u8(1 << i); - byte3 += local.value[i + 16].into() * AB::F::from_canonical_u8(1 << i); - byte4 += local.value[i + 24].into() * AB::F::from_canonical_u8(1 << i); - } - let value = [byte1, byte2, byte3, byte4]; if self.kind == MemoryChipType::Initialize { let mut values = vec![AB::Expr::zero(), AB::Expr::zero(), local.addr.into()]; - values.extend(value.map(Into::into)); + values.extend(local.value.map(Into::into)); builder.receive(AirInteraction::new( values, local.is_real.into(), @@ -201,7 +145,7 @@ where local.timestamp.into(), local.addr.into(), ]; - values.extend(value); + values.extend(local.value.map(Into::into)); builder.send(AirInteraction::new( values, local.is_real.into(), @@ -209,106 +153,16 @@ where )); } - // We want to assert addr < addr'. Assume seen_diff_0 = 0. - // - // match_i = (addr_i & addr'_i) || (!addr_i & !addr'_i) - // => - // match_i == addr_i * addr_i + (1 - addr_i) * (1 - addr'_i) - // - // when !match_i and !seen_diff_i, then enforce (addr_i == 0) and (addr'_i == 1). - // if seen_diff_i: - // seen_diff_{i+1} = 1 - // else: - // seen_diff_{i+1} = !match_i - // => - // builder.when(!match_i * !seen_diff_i).assert_zero(addr_i) - // builder.when(!match_i * !seen_diff_i).assert_one(addr'_i) - // seen_diff_bit_{i+1} == seen_diff_i + (1-seen_diff_i) * (1 - match_i) - // - // at the end of the algorithm, assert that we've seen a diff bit. - // => - // seen_diff_bit_{last} == 1 - - // Assert that we start with assuming that we haven't seen a diff bit. - builder.assert_zero(local.seen_diff_bits[0]); - - for i in 0..local.addr_bits.bits.len() { - // Compute the i'th msb bit's index. - let rev_i = local.addr_bits.bits.len() - i - 1; - - // Compute whether the i'th msb bit matches. - let match_i = local.addr_bits.bits[rev_i] * next.addr_bits.bits[rev_i] - + (AB::Expr::one() - local.addr_bits.bits[rev_i]) - * (AB::Expr::one() - next.addr_bits.bits[rev_i]); - builder - .when_transition() - .when(next.is_real) - .assert_eq(match_i.clone(), local.match_bits[i]); - - // Compute whether it's not a match and we haven't seen a diff bit. - let not_match_and_not_seen_diff_i = (AB::Expr::one() - local.match_bits[i]) - * (AB::Expr::one() - local.seen_diff_bits[i]); - builder.when_transition().when(next.is_real).assert_eq( - local.not_match_and_not_seen_diff_bits[i], - not_match_and_not_seen_diff_i, - ); - - // If the i'th msb bit doesn't match and it's the first time we've seen a diff bit, - // then enforce that the next bit is one and the current bit is zero. - builder - .when_transition() - .when(local.not_match_and_not_seen_diff_bits[i]) - .when(next.is_real) - .assert_zero(local.addr_bits.bits[rev_i]); - builder - .when_transition() - .when(local.not_match_and_not_seen_diff_bits[i]) - .when(next.is_real) - .assert_one(next.addr_bits.bits[rev_i]); - - // Update the seen diff bits. - builder.when_transition().assert_eq( - local.seen_diff_bits[i + 1], - local.seen_diff_bits[i] + local.not_match_and_not_seen_diff_bits[i], - ); - } - - // Assert that on rows where the next row is real, we've seen a diff bit. - builder - .when_transition() - .when(next.is_real) - .assert_one(local.seen_diff_bits[local.addr_bits.bits.len()]); - - // Canonically decompose the address into bits so we can do comparisons. - BabyBearBitDecomposition::::range_check( - builder, - local.addr, - local.addr_bits, - local.is_real.into(), - ); - - // Assert that the real rows are all padded to the top. - builder - .when_transition() - .when_not(local.is_real) - .assert_zero(next.is_real); - - if self.kind == MemoryChipType::Initialize { - builder - .when(local.is_real) - .assert_eq(local.timestamp, AB::F::one()); - } - // Register %x0 should always be 0. See 2.6 Load and Store Instruction on // P.18 of the RISC-V spec. To ensure that, we expect that the first row of the Initialize // and Finalize global memory chip is for register %x0 (i.e. addr = 0x0), and that those rows // have a value of 0. Additionally, in the CPU air, we ensure that whenever op_a is set to // %x0, its value is 0. - if self.kind == MemoryChipType::Initialize || self.kind == MemoryChipType::Finalize { + // + // TODO: Add a similar check for MemoryChipType::Initialize. + if self.kind == MemoryChipType::Finalize { builder.when_first_row().assert_zero(local.addr); - for i in 0..32 { - builder.when_first_row().assert_zero(local.value[i]); - } + builder.when_first_row().assert_word_zero(local.value); } } } diff --git a/core/src/memory/mod.rs b/core/src/memory/mod.rs index 4246db14c8..7acdee1fb8 100644 --- a/core/src/memory/mod.rs +++ b/core/src/memory/mod.rs @@ -27,7 +27,7 @@ impl MemoryInitializeFinalizeEvent { addr, value, shard: 0, - timestamp: 1, + timestamp: 0, used: if used { 1 } else { 0 }, } } diff --git a/core/src/memory/program.rs b/core/src/memory/program.rs index 64eeb25a23..3d922c4ae6 100644 --- a/core/src/memory/program.rs +++ b/core/src/memory/program.rs @@ -1,6 +1,6 @@ use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; -use p3_air::{Air, AirBuilderWithPublicValues, BaseAir, PairBuilder}; +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder}; use p3_field::AbstractField; use p3_field::PrimeField; use p3_matrix::dense::RowMajorMatrix; @@ -10,6 +10,7 @@ use sp1_derive::AlignedBorrow; use crate::air::{AirInteraction, PublicValues, SP1AirBuilder}; use crate::air::{MachineAir, Word}; +use crate::operations::IsZeroOperation; use crate::runtime::{ExecutionRecord, Program}; use crate::utils::pad_to_power_of_two; @@ -30,10 +31,10 @@ pub struct MemoryProgramPreprocessedCols { #[derive(AlignedBorrow, Clone, Copy, Default)] #[repr(C)] pub struct MemoryProgramMultCols { - /// The multiplicity of the event. - /// - /// This column is technically redundant with `is_real`, but it's included for clarity. + /// The multiplicity of the event, must be 1 in the first shard and 0 otherwise. pub multiplicity: T, + /// Columns to see if current shard is 1. + pub is_first_shard: IsZeroOperation, } /// Chip that initializes memory that is provided from the program. The table is preprocessed and @@ -119,6 +120,8 @@ impl MachineAir for MemoryProgramChip { let mut row = [F::zero(); NUM_MEMORY_PROGRAM_MULT_COLS]; let cols: &mut MemoryProgramMultCols = row.as_mut_slice().borrow_mut(); cols.multiplicity = mult; + IsZeroOperation::populate(&mut cols.is_first_shard, input.index - 1); + row }) .collect::>(); @@ -135,8 +138,8 @@ impl MachineAir for MemoryProgramChip { trace } - fn included(&self, record: &Self::Record) -> bool { - record.index == 1 + fn included(&self, _: &Self::Record) -> bool { + true } } @@ -168,15 +171,24 @@ where .map(|elm| (*elm).into()) .collect::>(), ); + IsZeroOperation::::eval( + builder, + public_values.shard - AB::Expr::one(), + mult_local.is_first_shard, + prep_local.is_real.into(), + ); + let is_first_shard = mult_local.is_first_shard.result; // Multiplicity must be either 0 or 1. builder.assert_bool(mult_local.multiplicity); - // If first shard and preprocessed is real, multiplicity must be one. - builder.assert_eq(mult_local.multiplicity, prep_local.is_real.into()); - - // The shard this chip is contained in must be one. - builder.assert_one(public_values.shard); + builder + .when(is_first_shard * prep_local.is_real) + .assert_one(mult_local.multiplicity); + // If not first shard or preprocessed is not real, multiplicity must be zero. + builder + .when((AB::Expr::one() - is_first_shard) + (AB::Expr::one() - prep_local.is_real)) + .assert_zero(mult_local.multiplicity); let mut values = vec![AB::Expr::zero(), AB::Expr::zero(), prep_local.addr.into()]; values.extend(prep_local.value.map(Into::into)); diff --git a/core/src/operations/baby_bear_range.rs b/core/src/operations/baby_bear_range.rs deleted file mode 100644 index 7e1ad0ef42..0000000000 --- a/core/src/operations/baby_bear_range.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::array; - -use p3_air::AirBuilder; -use p3_field::{AbstractField, Field}; -use sp1_derive::AlignedBorrow; - -use crate::stark::SP1AirBuilder; - -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] -#[repr(C)] -pub struct BabyBearBitDecomposition { - /// The bit decoposition of the`value`. - pub bits: [T; 32], - - /// The product of the the bits 3 to 5 in `most_sig_byte_decomp`. - pub and_most_sig_byte_decomp_3_to_5: T, - - /// The product of the the bits 3 to 6 in `most_sig_byte_decomp`. - pub and_most_sig_byte_decomp_3_to_6: T, - - /// The product of the the bits 3 to 7 in `most_sig_byte_decomp`. - pub and_most_sig_byte_decomp_3_to_7: T, -} - -impl BabyBearBitDecomposition { - pub fn populate(&mut self, value: u32) { - self.bits = array::from_fn(|i| F::from_canonical_u32((value >> i) & 1)); - let most_sig_byte_decomp = &self.bits[24..32]; - self.and_most_sig_byte_decomp_3_to_5 = most_sig_byte_decomp[3] * most_sig_byte_decomp[4]; - self.and_most_sig_byte_decomp_3_to_6 = - self.and_most_sig_byte_decomp_3_to_5 * most_sig_byte_decomp[5]; - self.and_most_sig_byte_decomp_3_to_7 = - self.and_most_sig_byte_decomp_3_to_6 * most_sig_byte_decomp[6]; - } - - pub fn range_check( - builder: &mut AB, - value: AB::Var, - cols: BabyBearBitDecomposition, - is_real: AB::Expr, - ) { - let mut reconstructed_value = AB::Expr::zero(); - for (i, bit) in cols.bits.iter().enumerate() { - builder.when(is_real.clone()).assert_bool(*bit); - reconstructed_value += AB::Expr::from_wrapped_u32(1 << i) * *bit; - } - - // Assert that bits2num(bits) == value. - builder - .when(is_real.clone()) - .assert_eq(reconstructed_value, value); - - // Range check that value is less than baby bear modulus. To do this, it is sufficient - // to just do comparisons for the most significant byte. BabyBear's modulus is (in big endian binary) - // 01111000_00000000_00000000_00000001. So we need to check the following conditions: - // 1) if most_sig_byte > 01111000, then fail. - // 2) if most_sig_byte == 01111000, then value's lower sig bytes must all be 0. - // 3) if most_sig_byte < 01111000, then pass. - let most_sig_byte_decomp = &cols.bits[24..32]; - builder - .when(is_real.clone()) - .assert_zero(most_sig_byte_decomp[7]); - - // Compute the product of the "top bits". - builder.when(is_real.clone()).assert_eq( - cols.and_most_sig_byte_decomp_3_to_5, - most_sig_byte_decomp[3] * most_sig_byte_decomp[4], - ); - builder.when(is_real.clone()).assert_eq( - cols.and_most_sig_byte_decomp_3_to_6, - cols.and_most_sig_byte_decomp_3_to_5 * most_sig_byte_decomp[5], - ); - builder.when(is_real.clone()).assert_eq( - cols.and_most_sig_byte_decomp_3_to_7, - cols.and_most_sig_byte_decomp_3_to_6 * most_sig_byte_decomp[6], - ); - - // If the top bits are all 0, then the lower bits must all be 0. - let mut lower_bits_sum: AB::Expr = AB::Expr::zero(); - for bit in cols.bits[0..27].iter() { - lower_bits_sum = lower_bits_sum + *bit; - } - builder - .when(is_real) - .when(cols.and_most_sig_byte_decomp_3_to_7) - .assert_zero(lower_bits_sum); - } -} diff --git a/core/src/operations/baby_bear_word.rs b/core/src/operations/baby_bear_word.rs deleted file mode 100644 index 2e773b3e6d..0000000000 --- a/core/src/operations/baby_bear_word.rs +++ /dev/null @@ -1,94 +0,0 @@ -use std::array; - -use p3_air::AirBuilder; -use p3_field::{AbstractField, Field}; -use sp1_derive::AlignedBorrow; - -use crate::{air::Word, stark::SP1AirBuilder}; - -/// A set of columns needed to compute the add of two words. -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] -#[repr(C)] -pub struct BabyBearWordRangeChecker { - /// Most sig byte LE bit decomposition. - pub most_sig_byte_decomp: [T; 8], - - /// The product of the the bits 3 to 5 in `most_sig_byte_decomp`. - pub and_most_sig_byte_decomp_3_to_5: T, - - /// The product of the the bits 3 to 6 in `most_sig_byte_decomp`. - pub and_most_sig_byte_decomp_3_to_6: T, - - /// The product of the the bits 3 to 7 in `most_sig_byte_decomp`. - pub and_most_sig_byte_decomp_3_to_7: T, -} - -impl BabyBearWordRangeChecker { - pub fn populate(&mut self, value: u32) { - self.most_sig_byte_decomp = array::from_fn(|i| F::from_bool(value & (1 << (i + 24)) != 0)); - self.and_most_sig_byte_decomp_3_to_5 = - self.most_sig_byte_decomp[3] * self.most_sig_byte_decomp[4]; - self.and_most_sig_byte_decomp_3_to_6 = - self.and_most_sig_byte_decomp_3_to_5 * self.most_sig_byte_decomp[5]; - self.and_most_sig_byte_decomp_3_to_7 = - self.and_most_sig_byte_decomp_3_to_6 * self.most_sig_byte_decomp[6]; - } - - pub fn range_check( - builder: &mut AB, - value: Word, - cols: BabyBearWordRangeChecker, - is_real: AB::Expr, - ) { - let mut recomposed_byte = AB::Expr::zero(); - cols.most_sig_byte_decomp - .iter() - .enumerate() - .for_each(|(i, value)| { - builder.when(is_real.clone()).assert_bool(*value); - recomposed_byte = - recomposed_byte.clone() + AB::Expr::from_canonical_usize(1 << i) * *value; - }); - - builder - .when(is_real.clone()) - .assert_eq(recomposed_byte, value[3]); - - // Range check that value is less than baby bear modulus. To do this, it is sufficient - // to just do comparisons for the most significant byte. BabyBear's modulus is (in big endian binary) - // 01111000_00000000_00000000_00000001. So we need to check the following conditions: - // 1) if most_sig_byte > 01111000, then fail. - // 2) if most_sig_byte == 01111000, then value's lower sig bytes must all be 0. - // 3) if most_sig_byte < 01111000, then pass. - builder - .when(is_real.clone()) - .assert_zero(cols.most_sig_byte_decomp[7]); - - // Compute the product of the "top bits". - builder.when(is_real.clone()).assert_eq( - cols.and_most_sig_byte_decomp_3_to_5, - cols.most_sig_byte_decomp[3] * cols.most_sig_byte_decomp[4], - ); - builder.when(is_real.clone()).assert_eq( - cols.and_most_sig_byte_decomp_3_to_6, - cols.and_most_sig_byte_decomp_3_to_5 * cols.most_sig_byte_decomp[5], - ); - builder.when(is_real.clone()).assert_eq( - cols.and_most_sig_byte_decomp_3_to_7, - cols.and_most_sig_byte_decomp_3_to_6 * cols.most_sig_byte_decomp[6], - ); - - let bottom_bits: AB::Expr = cols.most_sig_byte_decomp[0..3] - .iter() - .map(|bit| (*bit).into()) - .sum(); - builder - .when(is_real.clone()) - .when(cols.and_most_sig_byte_decomp_3_to_7) - .assert_zero(bottom_bits); - builder - .when(is_real) - .when(cols.and_most_sig_byte_decomp_3_to_7) - .assert_zero(value[0] + value[1] + value[2]); - } -} diff --git a/core/src/operations/field/field_op.rs b/core/src/operations/field/field_op.rs index 995142c2f2..ae04e2b9b1 100644 --- a/core/src/operations/field/field_op.rs +++ b/core/src/operations/field/field_op.rs @@ -445,6 +445,7 @@ mod tests { let mut challenger = config.challenger(); + // TODO: test with other fields let chip: FieldOpChip = FieldOpChip::new(*op); let shard = ExecutionRecord::default(); let trace: RowMajorMatrix = diff --git a/core/src/operations/field/field_sqrt.rs b/core/src/operations/field/field_sqrt.rs index e16de147bd..c0401a1d46 100644 --- a/core/src/operations/field/field_sqrt.rs +++ b/core/src/operations/field/field_sqrt.rs @@ -83,20 +83,6 @@ impl FieldSqrtCols { }; record.add_byte_lookup_event(and_event); - // Add the byte range check for `sqrt`. - record.add_u8_range_checks( - shard, - channel, - self.multiplication - .result - .0 - .as_slice() - .iter() - .map(|x| x.as_canonical_u32() as u8) - .collect::>() - .as_slice(), - ); - sqrt } } @@ -143,14 +129,6 @@ where is_real.clone(), ); - // Range check that `sqrt` limbs are bytes. - builder.slice_range_check_u8( - sqrt.0.as_slice(), - shard.clone(), - channel.clone(), - is_real.clone(), - ); - // Assert that the square root is the positive one, i.e., with least significant bit 0. // This is done by computing LSB = least_significant_byte & 1. builder.assert_bool(self.lsb); diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index e3fbcc78b1..242c9100b1 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -8,8 +8,6 @@ mod add; mod add4; mod add5; mod and; -mod baby_bear_range; -mod baby_bear_word; pub mod field; mod fixed_rotate_right; mod fixed_shift_right; @@ -24,8 +22,6 @@ pub use add::*; pub use add4::*; pub use add5::*; pub use and::*; -pub use baby_bear_range::*; -pub use baby_bear_word::*; pub use fixed_rotate_right::*; pub use fixed_shift_right::*; pub use is_equal_word::*; diff --git a/core/src/operations/or.rs b/core/src/operations/or.rs index b30821532a..8cb3f00191 100644 --- a/core/src/operations/or.rs +++ b/core/src/operations/or.rs @@ -10,6 +10,8 @@ use crate::disassembler::WORD_SIZE; use crate::runtime::ExecutionRecord; /// A set of columns needed to compute the or of two words. +/// +/// TODO: This is currently not in use, and thus not tested thoroughly yet. #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] pub struct OrOperation { diff --git a/core/src/runtime/mod.rs b/core/src/runtime/mod.rs index 3eb8c4eadc..003d35e8c5 100644 --- a/core/src/runtime/mod.rs +++ b/core/src/runtime/mod.rs @@ -30,8 +30,6 @@ use std::sync::Arc; use thiserror::Error; -use crate::alu::create_alu_lookup_id; -use crate::alu::create_alu_lookups; use crate::bytes::NUM_BYTE_LOOKUP_CHANNELS; use crate::memory::MemoryInitializeFinalizeEvent; use crate::utils::SP1CoreOpts; @@ -447,8 +445,6 @@ impl Runtime { memory_store_value: Option, record: MemoryAccessRecord, exit_code: u32, - lookup_id: usize, - syscall_lookup_id: usize, ) { let cpu_event = CpuEvent { shard, @@ -466,25 +462,14 @@ impl Runtime { memory: memory_store_value, memory_record: record.memory, exit_code, - alu_lookup_id: lookup_id, - syscall_lookup_id, - memory_add_lookup_id: create_alu_lookup_id(), - memory_sub_lookup_id: create_alu_lookup_id(), - branch_lt_lookup_id: create_alu_lookup_id(), - branch_gt_lookup_id: create_alu_lookup_id(), - branch_add_lookup_id: create_alu_lookup_id(), - jump_jal_lookup_id: create_alu_lookup_id(), - jump_jalr_lookup_id: create_alu_lookup_id(), - auipc_lookup_id: create_alu_lookup_id(), }; self.record.cpu_events.push(cpu_event); } /// Emit an ALU event. - fn emit_alu(&mut self, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32, lookup_id: usize) { + fn emit_alu(&mut self, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32) { let event = AluEvent { - lookup_id, shard: self.shard(), clk, channel: self.channel(), @@ -492,7 +477,6 @@ impl Runtime { a, b, c, - sub_lookups: create_alu_lookups(), }; match opcode { Opcode::ADD => { @@ -546,18 +530,10 @@ impl Runtime { } /// Set the destination register with the result and emit an ALU event. - fn alu_rw( - &mut self, - instruction: Instruction, - rd: Register, - a: u32, - b: u32, - c: u32, - lookup_id: usize, - ) { + fn alu_rw(&mut self, instruction: Instruction, rd: Register, a: u32, b: u32, c: u32) { self.rw(rd, a); if self.emit_events { - self.emit_alu(self.state.clk, instruction.opcode, a, b, c, lookup_id); + self.emit_alu(self.state.clk, instruction.opcode, a, b, c); } } @@ -610,9 +586,6 @@ impl Runtime { let mut memory_store_value: Option = None; self.memory_accesses = MemoryAccessRecord::default(); - let lookup_id = create_alu_lookup_id(); - let syscall_lookup_id = create_alu_lookup_id(); - if self.should_report && !self.unconstrained { self.report .instruction_counts @@ -626,52 +599,52 @@ impl Runtime { Opcode::ADD => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_add(c); - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::SUB => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_sub(c); - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::XOR => { (rd, b, c) = self.alu_rr(instruction); a = b ^ c; - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::OR => { (rd, b, c) = self.alu_rr(instruction); a = b | c; - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::AND => { (rd, b, c) = self.alu_rr(instruction); a = b & c; - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::SLL => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_shl(c); - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::SRL => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_shr(c); - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::SRA => { (rd, b, c) = self.alu_rr(instruction); a = (b as i32).wrapping_shr(c) as u32; - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::SLT => { (rd, b, c) = self.alu_rr(instruction); a = if (b as i32) < (c as i32) { 1 } else { 0 }; - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::SLTU => { (rd, b, c) = self.alu_rr(instruction); a = if b < c { 1 } else { 0 }; - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } // Load instructions. @@ -845,7 +818,6 @@ impl Runtime { let syscall_impl = self.get_syscall(syscall).cloned(); let mut precompile_rt = SyscallContext::new(self); - precompile_rt.syscall_lookup_id = syscall_lookup_id; let (precompile_next_pc, precompile_cycles, returned_exit_code) = if let Some(syscall_impl) = syscall_impl { // Executing a syscall optionally returns a value to write to the t0 register. @@ -890,22 +862,22 @@ impl Runtime { Opcode::MUL => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_mul(c); - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::MULH => { (rd, b, c) = self.alu_rr(instruction); a = (((b as i32) as i64).wrapping_mul((c as i32) as i64) >> 32) as u32; - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::MULHU => { (rd, b, c) = self.alu_rr(instruction); a = ((b as u64).wrapping_mul(c as u64) >> 32) as u32; - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::MULHSU => { (rd, b, c) = self.alu_rr(instruction); a = (((b as i32) as i64).wrapping_mul(c as i64) >> 32) as u32; - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::DIV => { (rd, b, c) = self.alu_rr(instruction); @@ -914,7 +886,7 @@ impl Runtime { } else { a = (b as i32).wrapping_div(c as i32) as u32; } - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::DIVU => { (rd, b, c) = self.alu_rr(instruction); @@ -923,7 +895,7 @@ impl Runtime { } else { a = b.wrapping_div(c); } - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::REM => { (rd, b, c) = self.alu_rr(instruction); @@ -932,7 +904,7 @@ impl Runtime { } else { a = (b as i32).wrapping_rem(c as i32) as u32; } - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } Opcode::REMU => { (rd, b, c) = self.alu_rr(instruction); @@ -941,7 +913,7 @@ impl Runtime { } else { a = b.wrapping_rem(c); } - self.alu_rw(instruction, rd, a, b, c, lookup_id); + self.alu_rw(instruction, rd, a, b, c); } // See https://github.com/riscv-non-isa/riscv-asm-manual/blob/master/riscv-asm.md#instruction-aliases @@ -978,8 +950,6 @@ impl Runtime { memory_store_value, self.memory_accesses, exit_code, - lookup_id, - syscall_lookup_id, ); }; Ok(()) @@ -1141,7 +1111,7 @@ impl Runtime { None => &MemoryRecord { value: 0, shard: 0, - timestamp: 1, + timestamp: 0, }, }; memory_finalize_events.push(MemoryInitializeFinalizeEvent::finalize_from_record( diff --git a/core/src/runtime/record.rs b/core/src/runtime/record.rs index 006c235a1a..67a2464f99 100644 --- a/core/src/runtime/record.rs +++ b/core/src/runtime/record.rs @@ -17,6 +17,7 @@ use crate::cpu::CpuEvent; use crate::runtime::MemoryInitializeFinalizeEvent; use crate::runtime::MemoryRecordEnum; use crate::stark::MachineRecord; +use crate::syscall::precompiles::blake3::Blake3CompressInnerEvent; use crate::syscall::precompiles::edwards::EdDecompressEvent; use crate::syscall::precompiles::keccak256::KeccakPermuteEvent; use crate::syscall::precompiles::sha256::{ShaCompressEvent, ShaExtendEvent}; @@ -86,6 +87,8 @@ pub struct ExecutionRecord { pub k256_decompress_events: Vec, + pub blake3_compress_inner_events: Vec, + pub bls12381_add_events: Vec, pub bls12381_double_events: Vec, @@ -100,8 +103,6 @@ pub struct ExecutionRecord { /// The public values. pub public_values: PublicValues, - - pub nonce_lookup: HashMap, } pub struct ShardingConfig { @@ -219,6 +220,10 @@ impl MachineRecord for ExecutionRecord { "k256_decompress_events".to_string(), self.k256_decompress_events.len(), ); + stats.insert( + "blake3_compress_inner_events".to_string(), + self.blake3_compress_inner_events.len(), + ); stats.insert( "bls12381_add_events".to_string(), self.bls12381_add_events.len(), @@ -267,6 +272,8 @@ impl MachineRecord for ExecutionRecord { .append(&mut other.bn254_double_events); self.k256_decompress_events .append(&mut other.k256_decompress_events); + self.blake3_compress_inner_events + .append(&mut other.blake3_compress_inner_events); self.bls12381_add_events .append(&mut other.bls12381_add_events); self.bls12381_double_events @@ -349,27 +356,14 @@ impl MachineRecord for ExecutionRecord { } } + // Shard all the other events according to the configuration. + // Shard the ADD events. for (add_chunk, shard) in take(&mut self.add_events) .chunks_mut(config.add_len) .zip(shards.iter_mut()) { shard.add_events.extend_from_slice(add_chunk); - for (i, event) in add_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } - } - - // Shard the SUB events. - for (sub_chunk, shard) in take(&mut self.sub_events) - .chunks_mut(config.sub_len) - .zip(shards.iter_mut()) - { - shard.sub_events.extend_from_slice(sub_chunk); - for (i, event) in sub_chunk.iter().enumerate() { - self.nonce_lookup - .insert(event.lookup_id, shard.add_events.len() as u32 + i as u32); - } } // Shard the MUL events. @@ -378,9 +372,14 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.mul_events.extend_from_slice(mul_chunk); - for (i, event) in mul_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } + } + + // Shard the SUB events. + for (sub_chunk, shard) in take(&mut self.sub_events) + .chunks_mut(config.sub_len) + .zip(shards.iter_mut()) + { + shard.sub_events.extend_from_slice(sub_chunk); } // Shard the bitwise events. @@ -389,9 +388,6 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.bitwise_events.extend_from_slice(bitwise_chunk); - for (i, event) in bitwise_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } } // Shard the shift left events. @@ -400,9 +396,6 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.shift_left_events.extend_from_slice(shift_left_chunk); - for (i, event) in shift_left_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } } // Shard the shift right events. @@ -413,9 +406,6 @@ impl MachineRecord for ExecutionRecord { shard .shift_right_events .extend_from_slice(shift_right_chunk); - for (i, event) in shift_right_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } } // Shard the divrem events. @@ -424,9 +414,6 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.divrem_events.extend_from_slice(divrem_chunk); - for (i, event) in divrem_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } } // Shard the LT events. @@ -435,9 +422,6 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.lt_events.extend_from_slice(lt_chunk); - for (i, event) in lt_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } } // Keccak-256 permute events. @@ -446,9 +430,6 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.keccak_permute_events.extend_from_slice(keccak_chunk); - for (i, event) in keccak_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, (i * 24) as u32); - } } // secp256k1 curve add events. @@ -459,9 +440,6 @@ impl MachineRecord for ExecutionRecord { shard .secp256k1_add_events .extend_from_slice(secp256k1_add_chunk); - for (i, event) in secp256k1_add_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } } // secp256k1 curve double events. @@ -472,9 +450,6 @@ impl MachineRecord for ExecutionRecord { shard .secp256k1_double_events .extend_from_slice(secp256k1_double_chunk); - for (i, event) in secp256k1_double_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } } // bn254 curve add events. @@ -483,9 +458,6 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.bn254_add_events.extend_from_slice(bn254_add_chunk); - for (i, event) in bn254_add_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } } // bn254 curve double events. @@ -496,9 +468,6 @@ impl MachineRecord for ExecutionRecord { shard .bn254_double_events .extend_from_slice(bn254_double_chunk); - for (i, event) in bn254_double_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } } // BLS12-381 curve add events. @@ -509,9 +478,6 @@ impl MachineRecord for ExecutionRecord { shard .bls12381_add_events .extend_from_slice(bls12381_add_chunk); - for (i, event) in bls12381_add_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } } // BLS12-381 curve double events. @@ -522,9 +488,6 @@ impl MachineRecord for ExecutionRecord { shard .bls12381_double_events .extend_from_slice(bls12381_double_chunk); - for (i, event) in bls12381_double_chunk.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } } // Put the precompile events in the first shard. @@ -532,45 +495,27 @@ impl MachineRecord for ExecutionRecord { // SHA-256 extend events. first.sha_extend_events = std::mem::take(&mut self.sha_extend_events); - for (i, event) in first.sha_extend_events.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, (i * 48) as u32); - } // SHA-256 compress events. first.sha_compress_events = std::mem::take(&mut self.sha_compress_events); - for (i, event) in first.sha_compress_events.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, (i * 80) as u32); - } // Edwards curve add events. first.ed_add_events = std::mem::take(&mut self.ed_add_events); - for (i, event) in first.ed_add_events.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } // Edwards curve decompress events. first.ed_decompress_events = std::mem::take(&mut self.ed_decompress_events); - for (i, event) in first.ed_decompress_events.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } // K256 curve decompress events. first.k256_decompress_events = std::mem::take(&mut self.k256_decompress_events); - for (i, event) in first.k256_decompress_events.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } + + // Blake3 compress events . + first.blake3_compress_inner_events = std::mem::take(&mut self.blake3_compress_inner_events); // Uint256 mul arithmetic events. first.uint256_mul_events = std::mem::take(&mut self.uint256_mul_events); - for (i, event) in first.uint256_mul_events.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } // Bls12-381 decompress events . first.bls12381_decompress_events = std::mem::take(&mut self.bls12381_decompress_events); - for (i, event) in first.bls12381_decompress_events.iter().enumerate() { - self.nonce_lookup.insert(event.lookup_id, i as u32); - } // Put the memory records in the last shard. let last_shard = shards.last_mut().unwrap(); @@ -582,11 +527,6 @@ impl MachineRecord for ExecutionRecord { .memory_finalize_events .extend_from_slice(&self.memory_finalize_events); - // Copy the nonce lookup to all shards. - for shard in shards.iter_mut() { - shard.nonce_lookup.clone_from(&self.nonce_lookup); - } - shards } diff --git a/core/src/runtime/syscall.rs b/core/src/runtime/syscall.rs index 7cb1cb3383..c320c7e28e 100644 --- a/core/src/runtime/syscall.rs +++ b/core/src/runtime/syscall.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use strum_macros::EnumIter; use crate::runtime::{Register, Runtime}; +use crate::stark::Blake3CompressInnerChip; use crate::syscall::precompiles::edwards::EdAddAssignChip; use crate::syscall::precompiles::edwards::EdDecompressChip; use crate::syscall::precompiles::keccak256::KeccakPermuteChip; @@ -67,6 +68,9 @@ pub enum SyscallCode { /// Executes the `SECP256K1_DECOMPRESS` precompile. SECP256K1_DECOMPRESS = 0x00_00_01_0C, + /// Executes the `BLAKE3_COMPRESS_INNER` precompile. + BLAKE3_COMPRESS_INNER = 0x00_38_01_0D, + /// Executes the `BN254_ADD` precompile. BN254_ADD = 0x00_01_01_0E, @@ -117,6 +121,7 @@ impl SyscallCode { 0x00_01_01_0A => SyscallCode::SECP256K1_ADD, 0x00_00_01_0B => SyscallCode::SECP256K1_DOUBLE, 0x00_00_01_0C => SyscallCode::SECP256K1_DECOMPRESS, + 0x00_38_01_0D => SyscallCode::BLAKE3_COMPRESS_INNER, 0x00_01_01_0E => SyscallCode::BN254_ADD, 0x00_00_01_0F => SyscallCode::BN254_DOUBLE, 0x00_01_01_1E => SyscallCode::BLS12381_ADD, @@ -175,7 +180,6 @@ pub struct SyscallContext<'a> { /// This is the exit_code used for the HALT syscall pub(crate) exit_code: u32, pub(crate) rt: &'a mut Runtime, - pub syscall_lookup_id: usize, } impl<'a> SyscallContext<'a> { @@ -188,7 +192,6 @@ impl<'a> SyscallContext<'a> { next_pc: runtime.state.pc.wrapping_add(4), exit_code: 0, rt: runtime, - syscall_lookup_id: 0, } } @@ -301,6 +304,10 @@ pub fn default_syscall_map() -> HashMap> { SyscallCode::BN254_DOUBLE, Arc::new(WeierstrassDoubleAssignChip::::new()), ); + syscall_map.insert( + SyscallCode::BLAKE3_COMPRESS_INNER, + Arc::new(Blake3CompressInnerChip::new()), + ); syscall_map.insert( SyscallCode::BLS12381_ADD, Arc::new(WeierstrassAddAssignChip::::new()), @@ -309,6 +316,10 @@ pub fn default_syscall_map() -> HashMap> { SyscallCode::BLS12381_DOUBLE, Arc::new(WeierstrassDoubleAssignChip::::new()), ); + syscall_map.insert( + SyscallCode::BLAKE3_COMPRESS_INNER, + Arc::new(Blake3CompressInnerChip::new()), + ); syscall_map.insert(SyscallCode::UINT256_MUL, Arc::new(Uint256MulChip::new())); syscall_map.insert( SyscallCode::ENTER_UNCONSTRAINED, @@ -348,6 +359,10 @@ mod tests { fn test_syscalls_in_default_map() { let default_syscall_map = default_syscall_map(); for code in SyscallCode::iter() { + if code == SyscallCode::BLAKE3_COMPRESS_INNER { + // Blake3 is currently disabled. + continue; + } default_syscall_map.get(&code).unwrap(); } } @@ -397,6 +412,9 @@ mod tests { SyscallCode::SECP256K1_DOUBLE => { assert_eq!(code as u32, sp1_zkvm::syscalls::SECP256K1_DOUBLE) } + SyscallCode::BLAKE3_COMPRESS_INNER => { + assert_eq!(code as u32, sp1_zkvm::syscalls::BLAKE3_COMPRESS_INNER) + } SyscallCode::BLS12381_ADD => { assert_eq!(code as u32, sp1_zkvm::syscalls::BLS12381_ADD) } diff --git a/core/src/stark/air.rs b/core/src/stark/air.rs index 558ccec5ab..dc181b18d2 100644 --- a/core/src/stark/air.rs +++ b/core/src/stark/air.rs @@ -21,6 +21,7 @@ pub(crate) mod riscv_chips { pub use crate::cpu::CpuChip; pub use crate::memory::MemoryChip; pub use crate::program::ProgramChip; + pub use crate::syscall::precompiles::blake3::Blake3CompressInnerChip; pub use crate::syscall::precompiles::edwards::EdAddAssignChip; pub use crate::syscall::precompiles::edwards::EdDecompressChip; pub use crate::syscall::precompiles::keccak256::KeccakPermuteChip; @@ -87,6 +88,8 @@ pub enum RiscvAir { Secp256k1Double(WeierstrassDoubleAssignChip>), /// A precompile for the Keccak permutation. KeccakP(KeccakPermuteChip), + /// A precompile for the Blake3 compression function. (Disabled by default.) + Blake3Compress(Blake3CompressInnerChip), /// A precompile for addition on the Elliptic curve bn254. Bn254Add(WeierstrassAddAssignChip>), /// A precompile for doubling a point on the Elliptic curve bn254. @@ -149,12 +152,12 @@ impl RiscvAir { chips.push(RiscvAir::Uint256Mul(uint256_mul)); let bls12381_decompress = WeierstrassDecompressChip::>::new(); chips.push(RiscvAir::Bls12381Decompress(bls12381_decompress)); - let div_rem = DivRemChip::default(); - chips.push(RiscvAir::DivRem(div_rem)); let add = AddSubChip::default(); chips.push(RiscvAir::Add(add)); let bitwise = BitwiseChip::default(); chips.push(RiscvAir::Bitwise(bitwise)); + let div_rem = DivRemChip::default(); + chips.push(RiscvAir::DivRem(div_rem)); let mul = MulChip::default(); chips.push(RiscvAir::Mul(mul)); let shift_right = ShiftRightChip::default(); diff --git a/core/src/stark/chip.rs b/core/src/stark/chip.rs index 7248e73a46..4a31646431 100644 --- a/core/src/stark/chip.rs +++ b/core/src/stark/chip.rs @@ -61,24 +61,12 @@ where where A: MachineAir + Air> + Air>, { + // Todo: correct values let mut builder = InteractionBuilder::new(air.preprocessed_width(), air.width()); air.eval(&mut builder); let (sends, receives) = builder.interactions(); - let nb_byte_sends = sends - .iter() - .filter(|s| s.kind == InteractionKind::Byte) - .count(); - let nb_byte_receives = receives - .iter() - .filter(|r| r.kind == InteractionKind::Byte) - .count(); - tracing::debug!( - "chip {} has {} byte interactions", - air.name(), - nb_byte_sends + nb_byte_receives - ); - + // TODO: enable different numbers of public values. let mut max_constraint_degree = get_max_constraint_degree(&air, air.preprocessed_width(), PROOF_MAX_NUM_PVS); diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index 9f61460a09..672226030b 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -473,8 +473,6 @@ pub enum MachineVerificationError { DebugInteractionsFailed, EmptyProof, InvalidPublicValues(&'static str), - TooManyShards, - InvalidChipOccurence(String), } impl Debug for MachineVerificationError { @@ -501,12 +499,6 @@ impl Debug for MachineVerificationError { MachineVerificationError::InvalidPublicValues(s) => { write!(f, "Invalid public values: {}", s) } - MachineVerificationError::TooManyShards => { - write!(f, "Too many shards") - } - MachineVerificationError::InvalidChipOccurence(s) => { - write!(f, "Invalid chip occurence: {}", s) - } } } } diff --git a/core/src/syscall/precompiles/blake3/compress/air.rs b/core/src/syscall/precompiles/blake3/compress/air.rs new file mode 100644 index 0000000000..a5876866e3 --- /dev/null +++ b/core/src/syscall/precompiles/blake3/compress/air.rs @@ -0,0 +1,235 @@ +use core::borrow::Borrow; + +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::AbstractField; +use p3_matrix::Matrix; + +use super::columns::{Blake3CompressInnerCols, NUM_BLAKE3_COMPRESS_INNER_COLS}; +use super::g::GOperation; +use super::{ + Blake3CompressInnerChip, G_INDEX, MSG_SCHEDULE, NUM_MSG_WORDS_PER_CALL, + NUM_STATE_WORDS_PER_CALL, OPERATION_COUNT, ROUND_COUNT, +}; +use crate::air::{BaseAirBuilder, SP1AirBuilder, WORD_SIZE}; +use crate::runtime::SyscallCode; + +impl BaseAir for Blake3CompressInnerChip { + fn width(&self) -> usize { + NUM_BLAKE3_COMPRESS_INNER_COLS + } +} + +impl Air for Blake3CompressInnerChip +where + AB: SP1AirBuilder, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.row_slice(0), main.row_slice(1)); + let local: &Blake3CompressInnerCols = (*local).borrow(); + let next: &Blake3CompressInnerCols = (*next).borrow(); + + self.constrain_control_flow_flags(builder, local, next); + + self.constrain_memory(builder, local); + + self.constrain_g_operation(builder, local); + + // TODO: constraint ecall_receive column. + // TODO: constraint clk column to increment by 1 within same invocation of syscall. + builder.receive_syscall( + local.shard, + local.channel, + local.clk, + AB::F::from_canonical_u32(SyscallCode::BLAKE3_COMPRESS_INNER.syscall_id()), + local.state_ptr, + local.message_ptr, + local.ecall_receive, + ); + } +} + +impl Blake3CompressInnerChip { + /// Constrains the given index is correct for the given selector. The `selector` is an + /// `n`-dimensional boolean array whose `i`-th element is true if and only if the index is `i`. + fn constrain_index_selector( + &self, + builder: &mut AB, + selector: &[AB::Var], + index: AB::Var, + is_real: AB::Var, + ) { + let mut acc: AB::Expr = AB::F::zero().into(); + for i in 0..selector.len() { + acc += selector[i].into(); + builder.assert_bool(selector[i]) + } + builder + .when(is_real) + .assert_eq(acc, AB::F::from_canonical_usize(1)); + for i in 0..selector.len() { + builder + .when(selector[i]) + .assert_eq(index, AB::F::from_canonical_usize(i)); + } + } + + /// Constrains the control flow flags such as the operation index and the round index. + fn constrain_control_flow_flags( + &self, + builder: &mut AB, + local: &Blake3CompressInnerCols, + next: &Blake3CompressInnerCols, + ) { + // If this is the i-th operation, then the next row should be the (i+1)-th operation. + for i in 0..OPERATION_COUNT { + builder.when_transition().when(next.is_real).assert_eq( + local.is_operation_index_n[i], + next.is_operation_index_n[(i + 1) % OPERATION_COUNT], + ); + } + + // If this is the last operation, the round index should be incremented. Otherwise, the + // round index should remain the same. + for i in 0..OPERATION_COUNT { + if i + 1 < OPERATION_COUNT { + builder + .when_transition() + .when(local.is_operation_index_n[i]) + .assert_eq(local.round_index, next.round_index); + } else { + builder + .when_transition() + .when(local.is_operation_index_n[i]) + .when_not(local.is_round_index_n[ROUND_COUNT - 1]) + .assert_eq( + local.round_index + AB::F::from_canonical_u16(1), + next.round_index, + ); + + builder + .when_transition() + .when(local.is_operation_index_n[i]) + .when(local.is_round_index_n[ROUND_COUNT - 1]) + .assert_zero(next.round_index); + } + } + } + + /// Constrain the memory access for the state and the message. + fn constrain_memory( + &self, + builder: &mut AB, + local: &Blake3CompressInnerCols, + ) { + // Calculate the 4 indices to read from the state. This corresponds to a, b, c, and d. + for i in 0..NUM_STATE_WORDS_PER_CALL { + let index_to_read = { + self.constrain_index_selector( + builder, + &local.is_operation_index_n, + local.operation_index, + local.is_real, + ); + + let mut acc = AB::Expr::from_canonical_usize(0); + for operation in 0..OPERATION_COUNT { + acc += AB::Expr::from_canonical_usize(G_INDEX[operation][i]) + * local.is_operation_index_n[operation]; + } + acc + }; + builder.assert_eq(local.state_index[i], index_to_read); + } + + // Read & write the state. + for i in 0..NUM_STATE_WORDS_PER_CALL { + builder.eval_memory_access( + local.shard, + local.channel, + local.clk, + local.state_ptr + local.state_index[i] * AB::F::from_canonical_usize(WORD_SIZE), + &local.state_reads_writes[i], + local.is_real, + ); + } + + // Calculate the indices to read from the message. + for i in 0..NUM_MSG_WORDS_PER_CALL { + let index_to_read = { + self.constrain_index_selector( + builder, + &local.is_round_index_n, + local.round_index, + local.is_real, + ); + + let mut acc = AB::Expr::from_canonical_usize(0); + + for round in 0..ROUND_COUNT { + for operation in 0..OPERATION_COUNT { + acc += + AB::Expr::from_canonical_usize(MSG_SCHEDULE[round][2 * operation + i]) + * local.is_operation_index_n[operation] + * local.is_round_index_n[round]; + } + } + acc + }; + builder.assert_eq(local.msg_schedule[i], index_to_read); + } + + // Read the message. + for i in 0..NUM_MSG_WORDS_PER_CALL { + builder.eval_memory_access( + local.shard, + local.channel, + local.clk, + local.message_ptr + local.msg_schedule[i] * AB::F::from_canonical_usize(WORD_SIZE), + &local.message_reads[i], + local.is_real, + ); + } + } + + /// Constrains the input and the output of the `g` operation. + fn constrain_g_operation( + &self, + builder: &mut AB, + local: &Blake3CompressInnerCols, + ) { + builder.assert_bool(local.is_real); + + // Call g and write the result to the state. + { + let input = [ + local.state_reads_writes[0].prev_value, + local.state_reads_writes[1].prev_value, + local.state_reads_writes[2].prev_value, + local.state_reads_writes[3].prev_value, + local.message_reads[0].access.value, + local.message_reads[1].access.value, + ]; + + // Call the g function. + GOperation::::eval( + builder, + input, + local.g, + local.shard, + local.channel, + local.is_real, + ); + + // Finally, the results of the g function should be written to the memory. + for i in 0..NUM_STATE_WORDS_PER_CALL { + for j in 0..WORD_SIZE { + builder.when(local.is_real).assert_eq( + local.state_reads_writes[i].access.value[j], + local.g.result[i][j], + ); + } + } + } + } +} diff --git a/core/src/syscall/precompiles/blake3/compress/columns.rs b/core/src/syscall/precompiles/blake3/compress/columns.rs new file mode 100644 index 0000000000..bf7bbe4e1e --- /dev/null +++ b/core/src/syscall/precompiles/blake3/compress/columns.rs @@ -0,0 +1,55 @@ +use std::mem::size_of; + +use sp1_derive::AlignedBorrow; + +use crate::memory::MemoryReadCols; +use crate::memory::MemoryReadWriteCols; + +use super::g::GOperation; +use super::NUM_MSG_WORDS_PER_CALL; +use super::NUM_STATE_WORDS_PER_CALL; +use super::OPERATION_COUNT; +use super::ROUND_COUNT; + +pub const NUM_BLAKE3_COMPRESS_INNER_COLS: usize = size_of::>(); + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct Blake3CompressInnerCols { + pub shard: T, + pub channel: T, + pub clk: T, + pub ecall_receive: T, + + /// The pointer to the state. + pub state_ptr: T, + + /// The pointer to the message. + pub message_ptr: T, + + /// Reads and writes a part of the state. + pub state_reads_writes: [MemoryReadWriteCols; NUM_STATE_WORDS_PER_CALL], + + /// Reads a part of the message. + pub message_reads: [MemoryReadCols; NUM_MSG_WORDS_PER_CALL], + + /// Indicates which call of `g` is being performed. + pub operation_index: T, + pub is_operation_index_n: [T; OPERATION_COUNT], + + /// Indicates which call of `round` is being performed. + pub round_index: T, + pub is_round_index_n: [T; ROUND_COUNT], + + /// The indices to pass to `g`. + pub state_index: [T; NUM_STATE_WORDS_PER_CALL], + + /// The two values from `MSG_SCHEDULE` to pass to `g`. + pub msg_schedule: [T; NUM_MSG_WORDS_PER_CALL], + + /// The `g` operation to perform. + pub g: GOperation, + + /// Indicates if the current call is real or not. + pub is_real: T, +} diff --git a/core/src/syscall/precompiles/blake3/compress/execute.rs b/core/src/syscall/precompiles/blake3/compress/execute.rs new file mode 100644 index 0000000000..35298b0415 --- /dev/null +++ b/core/src/syscall/precompiles/blake3/compress/execute.rs @@ -0,0 +1,76 @@ +use crate::runtime::Syscall; +use crate::runtime::{MemoryReadRecord, MemoryWriteRecord}; +use crate::syscall::precompiles::blake3::{ + g_func, Blake3CompressInnerChip, Blake3CompressInnerEvent, G_INDEX, MSG_SCHEDULE, + NUM_MSG_WORDS_PER_CALL, NUM_STATE_WORDS_PER_CALL, OPERATION_COUNT, ROUND_COUNT, +}; +use crate::syscall::precompiles::SyscallContext; + +impl Syscall for Blake3CompressInnerChip { + fn num_extra_cycles(&self) -> u32 { + (ROUND_COUNT * OPERATION_COUNT) as u32 + } + + fn execute(&self, rt: &mut SyscallContext, arg1: u32, arg2: u32) -> Option { + let state_ptr = arg1; + let message_ptr = arg2; + + let start_clk = rt.clk; + let mut message_reads = + [[[MemoryReadRecord::default(); NUM_MSG_WORDS_PER_CALL]; OPERATION_COUNT]; ROUND_COUNT]; + let mut state_writes = [[[MemoryWriteRecord::default(); NUM_STATE_WORDS_PER_CALL]; + OPERATION_COUNT]; ROUND_COUNT]; + + for round in 0..ROUND_COUNT { + for operation in 0..OPERATION_COUNT { + let state_index = G_INDEX[operation]; + let message_index: [usize; NUM_MSG_WORDS_PER_CALL] = [ + MSG_SCHEDULE[round][2 * operation], + MSG_SCHEDULE[round][2 * operation + 1], + ]; + + let mut input = vec![]; + // Read the input to g. + { + for index in state_index.iter() { + input.push(rt.word_unsafe(state_ptr + (*index as u32) * 4)); + } + for i in 0..NUM_MSG_WORDS_PER_CALL { + let (record, value) = rt.mr(message_ptr + (message_index[i] as u32) * 4); + message_reads[round][operation][i] = record; + input.push(value); + } + } + + // Call g. + let results = g_func(input.try_into().unwrap()); + + // Write the state. + for i in 0..NUM_STATE_WORDS_PER_CALL { + state_writes[round][operation][i] = + rt.mw(state_ptr + (state_index[i] as u32) * 4, results[i]); + } + + // Increment the clock for the next call of g. + rt.clk += 1; + } + } + + let shard = rt.current_shard(); + let channel = rt.current_channel(); + + rt.record_mut() + .blake3_compress_inner_events + .push(Blake3CompressInnerEvent { + shard, + channel, + clk: start_clk, + state_ptr, + message_reads, + state_writes, + message_ptr, + }); + + None + } +} diff --git a/core/src/syscall/precompiles/blake3/compress/g.rs b/core/src/syscall/precompiles/blake3/compress/g.rs new file mode 100644 index 0000000000..06e8c30348 --- /dev/null +++ b/core/src/syscall/precompiles/blake3/compress/g.rs @@ -0,0 +1,277 @@ +use p3_field::Field; +use sp1_derive::AlignedBorrow; + +use crate::air::SP1AirBuilder; +use crate::air::Word; +use crate::air::WORD_SIZE; +use crate::operations::AddOperation; +use crate::operations::FixedRotateRightOperation; +use crate::operations::XorOperation; +use crate::runtime::ExecutionRecord; + +use super::g_func; +/// A set of columns needed to compute the `g` of the input state. +/// ``` ignore +/// fn g(state: &mut BlockWords, a: usize, b: usize, c: usize, d: usize, x: u32, y: u32) { +/// state[a] = state[a].wrapping_add(state[b]).wrapping_add(x); +/// state[d] = (state[d] ^ state[a]).rotate_right(16); +/// state[c] = state[c].wrapping_add(state[d]); +/// state[b] = (state[b] ^ state[c]).rotate_right(12); +/// state[a] = state[a].wrapping_add(state[b]).wrapping_add(y); +/// state[d] = (state[d] ^ state[a]).rotate_right(8); +/// state[c] = state[c].wrapping_add(state[d]); +/// state[b] = (state[b] ^ state[c]).rotate_right(7); +/// } +/// ``` +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct GOperation { + pub a_plus_b: AddOperation, + pub a_plus_b_plus_x: AddOperation, + pub d_xor_a: XorOperation, + // Rotate right by 16 bits by just shifting bytes. + pub c_plus_d: AddOperation, + pub b_xor_c: XorOperation, + pub b_xor_c_rotate_right_12: FixedRotateRightOperation, + pub a_plus_b_2: AddOperation, + pub a_plus_b_2_add_y: AddOperation, + // Rotate right by 8 bits by just shifting bytes. + pub d_xor_a_2: XorOperation, + pub c_plus_d_2: AddOperation, + pub b_xor_c_2: XorOperation, + pub b_xor_c_2_rotate_right_7: FixedRotateRightOperation, + /// `state[a]`, `state[b]`, `state[c]`, `state[d]` after all the steps. + pub result: [Word; 4], +} + +impl GOperation { + pub fn populate( + &mut self, + record: &mut ExecutionRecord, + shard: u32, + channel: u32, + input: [u32; 6], + ) -> [u32; 4] { + let mut a = input[0]; + let mut b = input[1]; + let mut c = input[2]; + let mut d = input[3]; + let x = input[4]; + let y = input[5]; + + // First 4 steps. + { + // a = a + b + x. + a = self.a_plus_b.populate(record, shard, channel, a, b); + a = self.a_plus_b_plus_x.populate(record, shard, channel, a, x); + + // d = (d ^ a).rotate_right(16). + d = self.d_xor_a.populate(record, shard, channel, d, a); + d = d.rotate_right(16); + + // c = c + d. + c = self.c_plus_d.populate(record, shard, channel, c, d); + + // b = (b ^ c).rotate_right(12). + b = self.b_xor_c.populate(record, shard, channel, b, c); + b = self + .b_xor_c_rotate_right_12 + .populate(record, shard, channel, b, 12); + } + + // Second 4 steps. + { + // a = a + b + y. + a = self.a_plus_b_2.populate(record, shard, channel, a, b); + a = self.a_plus_b_2_add_y.populate(record, shard, channel, a, y); + + // d = (d ^ a).rotate_right(8). + d = self.d_xor_a_2.populate(record, shard, channel, d, a); + d = d.rotate_right(8); + + // c = c + d. + c = self.c_plus_d_2.populate(record, shard, channel, c, d); + + // b = (b ^ c).rotate_right(7). + b = self.b_xor_c_2.populate(record, shard, channel, b, c); + b = self + .b_xor_c_2_rotate_right_7 + .populate(record, shard, channel, b, 7); + } + + let result = [a, b, c, d]; + assert_eq!(result, g_func(input)); + self.result = result.map(Word::from); + result + } + + pub fn eval( + builder: &mut AB, + input: [Word; 6], + cols: GOperation, + shard: AB::Var, + channel: impl Into + Clone, + is_real: AB::Var, + ) { + builder.assert_bool(is_real); + let mut a = input[0]; + let mut b = input[1]; + let mut c = input[2]; + let mut d = input[3]; + let x = input[4]; + let y = input[5]; + + // First 4 steps. + { + // a = a + b + x. + AddOperation::::eval( + builder, + a, + b, + cols.a_plus_b, + shard, + channel.clone(), + is_real.into(), + ); + a = cols.a_plus_b.value; + AddOperation::::eval( + builder, + a, + x, + cols.a_plus_b_plus_x, + shard, + channel.clone(), + is_real.into(), + ); + a = cols.a_plus_b_plus_x.value; + + // d = (d ^ a).rotate_right(16). + XorOperation::::eval( + builder, + d, + a, + cols.d_xor_a, + shard, + channel.clone(), + is_real, + ); + d = cols.d_xor_a.value; + // Rotate right by 16 bits. + d = Word([d[2], d[3], d[0], d[1]]); + + // c = c + d. + AddOperation::::eval( + builder, + c, + d, + cols.c_plus_d, + shard, + channel.clone(), + is_real.into(), + ); + c = cols.c_plus_d.value; + + // b = (b ^ c).rotate_right(12). + XorOperation::::eval( + builder, + b, + c, + cols.b_xor_c, + shard, + channel.clone(), + is_real, + ); + b = cols.b_xor_c.value; + FixedRotateRightOperation::::eval( + builder, + b, + 12, + cols.b_xor_c_rotate_right_12, + shard, + channel.clone(), + is_real, + ); + b = cols.b_xor_c_rotate_right_12.value; + } + + // Second 4 steps. + { + // a = a + b + y. + AddOperation::::eval( + builder, + a, + b, + cols.a_plus_b_2, + shard, + channel.clone(), + is_real.into(), + ); + a = cols.a_plus_b_2.value; + AddOperation::::eval( + builder, + a, + y, + cols.a_plus_b_2_add_y, + shard, + channel.clone(), + is_real.into(), + ); + a = cols.a_plus_b_2_add_y.value; + + // d = (d ^ a).rotate_right(8). + XorOperation::::eval( + builder, + d, + a, + cols.d_xor_a_2, + shard, + channel.clone(), + is_real, + ); + d = cols.d_xor_a_2.value; + // Rotate right by 8 bits. + d = Word([d[1], d[2], d[3], d[0]]); + + // c = c + d. + AddOperation::::eval( + builder, + c, + d, + cols.c_plus_d_2, + shard, + channel.clone(), + is_real.into(), + ); + c = cols.c_plus_d_2.value; + + // b = (b ^ c).rotate_right(7). + XorOperation::::eval( + builder, + b, + c, + cols.b_xor_c_2, + shard, + channel.clone(), + is_real, + ); + b = cols.b_xor_c_2.value; + FixedRotateRightOperation::::eval( + builder, + b, + 7, + cols.b_xor_c_2_rotate_right_7, + shard, + channel.clone(), + is_real, + ); + b = cols.b_xor_c_2_rotate_right_7.value; + } + + let results = [a, b, c, d]; + for i in 0..4 { + for j in 0..WORD_SIZE { + builder.assert_eq(cols.result[i][j], results[i][j]); + } + } + } +} diff --git a/core/src/syscall/precompiles/blake3/compress/mod.rs b/core/src/syscall/precompiles/blake3/compress/mod.rs new file mode 100644 index 0000000000..a89b9bcc34 --- /dev/null +++ b/core/src/syscall/precompiles/blake3/compress/mod.rs @@ -0,0 +1,179 @@ +//! This module contains the implementation of the `blake3_compress_inner` precompile based on the +//! implementation of the `blake3` hash function in BLAKE3. +//! +//! Pseudo-code. +//! +//! state = [0u32; 16] +//! message = [0u32; 16] +//! +//! for round in 0..7 { +//! for operation in 0..8 { +//! // * Pick 4 indices a, b, c, d for the state, based on the operation index. +//! // * Pick 2 indices x, y for the message, based on both the round and the operation index. +//! // +//! // g takes those 6 values, and updates the 4 state values, at indices a, b, c, d. +//! // +//! // Each call of g becomes one row in the trace. +//! g(&mut state[a], &mut state[b], &mut state[c], &mut state[d], message[x], message[y]); +//! } +//! } +//! +//! Note that this precompile is only the blake3 compress inner function. The Blake3 compress +//! function has a series of 8 XOR operations after the compress inner function. +mod air; +mod columns; +mod execute; +mod g; +mod trace; +use crate::runtime::{MemoryReadRecord, MemoryWriteRecord}; + +use serde::{Deserialize, Serialize}; + +/// The number of `Word`s in the message of the compress inner operation. +pub(crate) const MSG_SIZE: usize = 16; + +/// The number of times we call `round` in the compress inner operation. +pub(crate) const ROUND_COUNT: usize = 7; + +/// The number of times we call `g` in the compress inner operation. +pub(crate) const OPERATION_COUNT: usize = 8; + +/// The number of `Word`s in the state that we pass to `g`. +pub(crate) const NUM_STATE_WORDS_PER_CALL: usize = 4; + +/// The number of `Word`s in the message that we pass to `g`. +pub(crate) const NUM_MSG_WORDS_PER_CALL: usize = 2; + +/// The number of `Word`s in the input of `g`. +pub(crate) const G_INPUT_SIZE: usize = NUM_MSG_WORDS_PER_CALL + NUM_STATE_WORDS_PER_CALL; + +/// 2-dimensional array specifying which message values `g` should access. Values at `(i, 2 * j)` +/// and `(i, 2 * j + 1)` are the indices of the message values that `g` should access in the `j`-th +/// call of the `i`-th round. +pub(crate) const MSG_SCHEDULE: [[usize; MSG_SIZE]; ROUND_COUNT] = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8], + [3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1], + [10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6], + [12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4], + [9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7], + [11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13], +]; + +/// The `i`-th row of `G_INDEX` is the indices used for the `i`-th call to `g`. +pub(crate) const G_INDEX: [[usize; NUM_STATE_WORDS_PER_CALL]; OPERATION_COUNT] = [ + [0, 4, 8, 12], + [1, 5, 9, 13], + [2, 6, 10, 14], + [3, 7, 11, 15], + [0, 5, 10, 15], + [1, 6, 11, 12], + [2, 7, 8, 13], + [3, 4, 9, 14], +]; + +pub(crate) const fn g_func(input: [u32; 6]) -> [u32; 4] { + let mut a = input[0]; + let mut b = input[1]; + let mut c = input[2]; + let mut d = input[3]; + let x = input[4]; + let y = input[5]; + a = a.wrapping_add(b).wrapping_add(x); + d = (d ^ a).rotate_right(16); + c = c.wrapping_add(d); + b = (b ^ c).rotate_right(12); + a = a.wrapping_add(b).wrapping_add(y); + d = (d ^ a).rotate_right(8); + c = c.wrapping_add(d); + b = (b ^ c).rotate_right(7); + [a, b, c, d] +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Blake3CompressInnerEvent { + pub clk: u32, + pub shard: u32, + pub channel: u32, + pub state_ptr: u32, + pub message_ptr: u32, + pub message_reads: [[[MemoryReadRecord; NUM_MSG_WORDS_PER_CALL]; OPERATION_COUNT]; ROUND_COUNT], + pub state_writes: + [[[MemoryWriteRecord; NUM_STATE_WORDS_PER_CALL]; OPERATION_COUNT]; ROUND_COUNT], +} + +pub struct Blake3CompressInnerChip {} + +impl Blake3CompressInnerChip { + pub const fn new() -> Self { + Self {} + } +} + +#[cfg(test)] +pub mod compress_tests { + use crate::runtime::Instruction; + use crate::runtime::Opcode; + use crate::runtime::Register; + use crate::runtime::SyscallCode; + use crate::Program; + + use super::MSG_SIZE; + + /// The number of `Word`s in the state of the compress inner operation. + const STATE_SIZE: usize = 16; + + pub fn blake3_compress_internal_program() -> Program { + let state_ptr = 100; + let msg_ptr = 500; + let mut instructions = vec![]; + + for i in 0..STATE_SIZE { + // Store 1000 + i in memory for the i-th word of the state. 1000 + i is an arbitrary + // number that is easy to spot while debugging. + instructions.extend(vec![ + Instruction::new(Opcode::ADD, 29, 0, 1000 + i as u32, false, true), + Instruction::new(Opcode::ADD, 30, 0, state_ptr + i as u32 * 4, false, true), + Instruction::new(Opcode::SW, 29, 30, 0, false, true), + ]); + } + for i in 0..MSG_SIZE { + // Store 2000 + i in memory for the i-th word of the message. 2000 + i is an arbitrary + // number that is easy to spot while debugging. + instructions.extend(vec![ + Instruction::new(Opcode::ADD, 29, 0, 2000 + i as u32, false, true), + Instruction::new(Opcode::ADD, 30, 0, msg_ptr + i as u32 * 4, false, true), + Instruction::new(Opcode::SW, 29, 30, 0, false, true), + ]); + } + instructions.extend(vec![ + Instruction::new( + Opcode::ADD, + 5, + 0, + SyscallCode::BLAKE3_COMPRESS_INNER as u32, + false, + true, + ), + Instruction::new(Opcode::ADD, Register::X10 as u32, 0, state_ptr, false, true), + Instruction::new(Opcode::ADD, Register::X11 as u32, 0, msg_ptr, false, true), + Instruction::new(Opcode::ECALL, 5, 10, 11, false, false), + ]); + Program::new(instructions, 0, 0) + } + + // Tests disabled because syscall is not enabled in default runtime/chip configs. + // #[test] + // fn prove_babybear() { + // setup_logger(); + // let program = blake3_compress_internal_program(); + // run_test(program).unwrap(); + // } + + // #[test] + // fn test_blake3_compress_inner_elf() { + // setup_logger(); + // let program = Program::from(BLAKE3_COMPRESS_ELF); + // run_test(program).unwrap(); + // } +} diff --git a/core/src/syscall/precompiles/blake3/compress/trace.rs b/core/src/syscall/precompiles/blake3/compress/trace.rs new file mode 100644 index 0000000000..14994cb031 --- /dev/null +++ b/core/src/syscall/precompiles/blake3/compress/trace.rs @@ -0,0 +1,131 @@ +use std::borrow::BorrowMut; + +use p3_field::PrimeField32; +use p3_matrix::dense::RowMajorMatrix; + +use super::columns::Blake3CompressInnerCols; +use super::{ + G_INDEX, G_INPUT_SIZE, MSG_SCHEDULE, NUM_MSG_WORDS_PER_CALL, NUM_STATE_WORDS_PER_CALL, + OPERATION_COUNT, +}; +use crate::air::MachineAir; +use crate::bytes::event::ByteRecord; +use crate::runtime::ExecutionRecord; +use crate::runtime::MemoryRecordEnum; +use crate::runtime::Program; +use crate::syscall::precompiles::blake3::compress::columns::NUM_BLAKE3_COMPRESS_INNER_COLS; +use crate::syscall::precompiles::blake3::{Blake3CompressInnerChip, ROUND_COUNT}; +use crate::utils::pad_rows; + +impl MachineAir for Blake3CompressInnerChip { + type Record = ExecutionRecord; + type Program = Program; + + fn name(&self) -> String { + "Blake3CompressInner".to_string() + } + + fn generate_trace( + &self, + input: &ExecutionRecord, + output: &mut ExecutionRecord, + ) -> RowMajorMatrix { + let mut rows = Vec::new(); + + let mut new_byte_lookup_events = Vec::new(); + + for i in 0..input.blake3_compress_inner_events.len() { + let event = input.blake3_compress_inner_events[i].clone(); + let shard = event.shard; + let channel = event.channel; + let mut clk = event.clk; + for round in 0..ROUND_COUNT { + for operation in 0..OPERATION_COUNT { + let mut row = [F::zero(); NUM_BLAKE3_COMPRESS_INNER_COLS]; + let cols: &mut Blake3CompressInnerCols = row.as_mut_slice().borrow_mut(); + + // Assign basic values to the columns. + { + cols.shard = F::from_canonical_u32(event.shard); + cols.channel = F::from_canonical_u32(event.channel); + cols.clk = F::from_canonical_u32(clk); + + cols.round_index = F::from_canonical_u32(round as u32); + cols.is_round_index_n[round] = F::one(); + + cols.operation_index = F::from_canonical_u32(operation as u32); + cols.is_operation_index_n[operation] = F::one(); + + for i in 0..NUM_STATE_WORDS_PER_CALL { + cols.state_index[i] = F::from_canonical_usize(G_INDEX[operation][i]); + } + + for i in 0..NUM_MSG_WORDS_PER_CALL { + cols.msg_schedule[i] = + F::from_canonical_usize(MSG_SCHEDULE[round][2 * operation + i]); + } + + if round == 0 && operation == 0 { + cols.ecall_receive = F::one(); + } + } + + // Memory columns. + { + cols.message_ptr = F::from_canonical_u32(event.message_ptr); + for i in 0..NUM_MSG_WORDS_PER_CALL { + cols.message_reads[i].populate( + channel, + event.message_reads[round][operation][i], + &mut new_byte_lookup_events, + ); + } + + cols.state_ptr = F::from_canonical_u32(event.state_ptr); + for i in 0..NUM_STATE_WORDS_PER_CALL { + cols.state_reads_writes[i].populate( + channel, + MemoryRecordEnum::Write(event.state_writes[round][operation][i]), + &mut new_byte_lookup_events, + ); + } + } + + // Apply the `g` operation. + { + let input: [u32; G_INPUT_SIZE] = [ + event.state_writes[round][operation][0].prev_value, + event.state_writes[round][operation][1].prev_value, + event.state_writes[round][operation][2].prev_value, + event.state_writes[round][operation][3].prev_value, + event.message_reads[round][operation][0].value, + event.message_reads[round][operation][1].value, + ]; + + cols.g.populate(output, shard, channel, input); + } + + clk += 1; + + cols.is_real = F::one(); + + rows.push(row); + } + } + } + + output.add_byte_lookup_events(new_byte_lookup_events); + + pad_rows(&mut rows, || [F::zero(); NUM_BLAKE3_COMPRESS_INNER_COLS]); + + // Convert the trace to a row major matrix. + RowMajorMatrix::new( + rows.into_iter().flatten().collect::>(), + NUM_BLAKE3_COMPRESS_INNER_COLS, + ) + } + + fn included(&self, shard: &Self::Record) -> bool { + !shard.blake3_compress_inner_events.is_empty() + } +} diff --git a/core/src/syscall/precompiles/blake3/mod.rs b/core/src/syscall/precompiles/blake3/mod.rs new file mode 100644 index 0000000000..8b286ad176 --- /dev/null +++ b/core/src/syscall/precompiles/blake3/mod.rs @@ -0,0 +1,3 @@ +mod compress; + +pub use compress::*; diff --git a/core/src/syscall/precompiles/edwards/ed_add.rs b/core/src/syscall/precompiles/edwards/ed_add.rs index 44d29edb83..f4e15423a7 100644 --- a/core/src/syscall/precompiles/edwards/ed_add.rs +++ b/core/src/syscall/precompiles/edwards/ed_add.rs @@ -6,7 +6,6 @@ use std::marker::PhantomData; use num::BigUint; use num::Zero; -use p3_air::AirBuilder; use p3_air::{Air, BaseAir}; use p3_field::AbstractField; use p3_field::PrimeField32; @@ -55,7 +54,6 @@ pub struct EdAddAssignCols { pub shard: T, pub channel: T, pub clk: T, - pub nonce: T, pub p_ptr: T, pub q_ptr: T, pub p_access: [MemoryWriteCols; WORDS_CURVE_POINT], @@ -240,19 +238,10 @@ impl MachineAir for Ed }); // Convert the trace to a row major matrix. - let mut trace = RowMajorMatrix::new( + RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), NUM_ED_ADD_COLS, - ); - - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut EdAddAssignCols = - trace.values[i * NUM_ED_ADD_COLS..(i + 1) * NUM_ED_ADD_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - - trace + ) } fn included(&self, shard: &Self::Record) -> bool { @@ -272,150 +261,141 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let local = main.row_slice(0); - let local: &EdAddAssignCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &EdAddAssignCols = (*next).borrow(); - - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + let row = main.row_slice(0); + let row: &EdAddAssignCols = (*row).borrow(); - let x1 = limbs_from_prev_access(&local.p_access[0..8]); - let x2 = limbs_from_prev_access(&local.q_access[0..8]); - let y1 = limbs_from_prev_access(&local.p_access[8..16]); - let y2 = limbs_from_prev_access(&local.q_access[8..16]); + let x1 = limbs_from_prev_access(&row.p_access[0..8]); + let x2 = limbs_from_prev_access(&row.q_access[0..8]); + let y1 = limbs_from_prev_access(&row.p_access[8..16]); + let y2 = limbs_from_prev_access(&row.q_access[8..16]); // x3_numerator = x1 * y2 + x2 * y1. - local.x3_numerator.eval( + row.x3_numerator.eval( builder, &[x1, x2], &[y2, y1], - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); // y3_numerator = y1 * y2 + x1 * x2. - local.y3_numerator.eval( + row.y3_numerator.eval( builder, &[y1, x1], &[y2, x2], - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); // f = x1 * x2 * y1 * y2. - local.x1_mul_y1.eval( + row.x1_mul_y1.eval( builder, &x1, &y1, FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.x2_mul_y2.eval( + row.x2_mul_y2.eval( builder, &x2, &y2, FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - let x1_mul_y1 = local.x1_mul_y1.result; - let x2_mul_y2 = local.x2_mul_y2.result; - local.f.eval( + let x1_mul_y1 = row.x1_mul_y1.result; + let x2_mul_y2 = row.x2_mul_y2.result; + row.f.eval( builder, &x1_mul_y1, &x2_mul_y2, FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); // d * f. - let f = local.f.result; + let f = row.f.result; let d_biguint = E::d_biguint(); let d_const = E::BaseField::to_limbs_field::(&d_biguint); - local.d_mul_f.eval( + row.d_mul_f.eval( builder, &f, &d_const, FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - let d_mul_f = local.d_mul_f.result; + let d_mul_f = row.d_mul_f.result; // x3 = x3_numerator / (1 + d * f). - local.x3_ins.eval( + row.x3_ins.eval( builder, - &local.x3_numerator.result, + &row.x3_numerator.result, &d_mul_f, true, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); // y3 = y3_numerator / (1 - d * f). - local.y3_ins.eval( + row.y3_ins.eval( builder, - &local.y3_numerator.result, + &row.y3_numerator.result, &d_mul_f, false, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); // Constraint self.p_access.value = [self.x3_ins.result, self.y3_ins.result] // This is to ensure that p_access is updated with the new value. - let p_access_vec = value_as_limbs(&local.p_access); + let p_access_vec = value_as_limbs(&row.p_access); builder - .when(local.is_real) - .assert_all_eq(local.x3_ins.result, p_access_vec[0..NUM_LIMBS].to_vec()); - builder.when(local.is_real).assert_all_eq( - local.y3_ins.result, + .when(row.is_real) + .assert_all_eq(row.x3_ins.result, p_access_vec[0..NUM_LIMBS].to_vec()); + builder.when(row.is_real).assert_all_eq( + row.y3_ins.result, p_access_vec[NUM_LIMBS..NUM_LIMBS * 2].to_vec(), ); builder.eval_memory_access_slice( - local.shard, - local.channel, - local.clk.into(), - local.q_ptr, - &local.q_access, - local.is_real, + row.shard, + row.channel, + row.clk.into(), + row.q_ptr, + &row.q_access, + row.is_real, ); builder.eval_memory_access_slice( - local.shard, - local.channel, - local.clk + AB::F::from_canonical_u32(1), - local.p_ptr, - &local.p_access, - local.is_real, + row.shard, + row.channel, + row.clk + AB::F::from_canonical_u32(1), + row.p_ptr, + &row.p_access, + row.is_real, ); builder.receive_syscall( - local.shard, - local.channel, - local.clk, - local.nonce, + row.shard, + row.channel, + row.clk, AB::F::from_canonical_u32(SyscallCode::ED_ADD.syscall_id()), - local.p_ptr, - local.q_ptr, - local.is_real, + row.p_ptr, + row.q_ptr, + row.is_real, ); } } diff --git a/core/src/syscall/precompiles/edwards/ed_decompress.rs b/core/src/syscall/precompiles/edwards/ed_decompress.rs index a0618137ce..be62467c00 100644 --- a/core/src/syscall/precompiles/edwards/ed_decompress.rs +++ b/core/src/syscall/precompiles/edwards/ed_decompress.rs @@ -53,7 +53,6 @@ use super::{WordsFieldElement, WORDS_FIELD_ELEMENT}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct EdDecompressEvent { - pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -79,7 +78,6 @@ pub struct EdDecompressCols { pub shard: T, pub channel: T, pub clk: T, - pub nonce: T, pub ptr: T, pub sign: T, pub x_access: GenericArray, WordsFieldElement>, @@ -106,13 +104,6 @@ impl EdDecompressCols { self.channel = F::from_canonical_u32(event.channel); self.clk = F::from_canonical_u32(event.clk); self.ptr = F::from_canonical_u32(event.ptr); - self.nonce = F::from_canonical_u32( - record - .nonce_lookup - .get(&event.lookup_id) - .copied() - .unwrap_or_default(), - ); self.sign = F::from_bool(event.sign); for i in 0..8 { self.x_access[i].populate( @@ -285,7 +276,6 @@ impl EdDecompressCols { self.shard, self.channel, self.clk, - self.nonce, AB::F::from_canonical_u32(SyscallCode::ED_DECOMPRESS.syscall_id()), self.ptr, self.sign, @@ -336,13 +326,11 @@ impl Syscall for EdDecompressChip { let x_memory_records_vec = rt.mw_slice(slice_ptr, &decompressed_x_words); let x_memory_records: [MemoryWriteRecord; 8] = x_memory_records_vec.try_into().unwrap(); - let lookup_id = rt.syscall_lookup_id; let shard = rt.current_shard(); let channel = rt.current_channel(); rt.record_mut() .ed_decompress_events .push(EdDecompressEvent { - lookup_id, shard, channel, clk: start_clk, @@ -402,20 +390,10 @@ impl MachineAir for EdDecompressChip>(), NUM_ED_DECOMPRESS_COLS, - ); - - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut EdDecompressCols = trace.values - [i * NUM_ED_DECOMPRESS_COLS..(i + 1) * NUM_ED_DECOMPRESS_COLS] - .borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - - trace + ) } fn included(&self, shard: &Self::Record) -> bool { @@ -435,18 +413,9 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let local = main.row_slice(0); - let local: &EdDecompressCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &EdDecompressCols = (*next).borrow(); - - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); - - local.eval::(builder); + let row = main.row_slice(0); + let row: &EdDecompressCols = (*row).borrow(); + row.eval::(builder); } } diff --git a/core/src/syscall/precompiles/keccak256/air.rs b/core/src/syscall/precompiles/keccak256/air.rs index 1647616798..9e67c12490 100644 --- a/core/src/syscall/precompiles/keccak256/air.rs +++ b/core/src/syscall/precompiles/keccak256/air.rs @@ -32,12 +32,6 @@ where let local: &KeccakMemCols = (*local).borrow(); let next: &KeccakMemCols = (*next).borrow(); - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); - let first_step = local.keccak.step_flags[0]; let final_step = local.keccak.step_flags[NUM_ROUNDS - 1]; let not_final_step = AB::Expr::one() - final_step; @@ -74,7 +68,6 @@ where local.shard, local.channel, local.clk, - local.nonce, AB::F::from_canonical_u32(SyscallCode::KECCAK_PERMUTE.syscall_id()), local.state_addr, AB::Expr::zero(), @@ -86,7 +79,6 @@ where let mut transition_not_final_builder = transition_builder.when(not_final_step); transition_not_final_builder.assert_eq(local.shard, next.shard); transition_not_final_builder.assert_eq(local.clk, next.clk); - transition_not_final_builder.assert_eq(local.channel, next.channel); transition_not_final_builder.assert_eq(local.state_addr, next.state_addr); transition_not_final_builder.assert_eq(local.is_real, next.is_real); @@ -131,16 +123,6 @@ where } } - // Range check all the values in `state_mem` to be bytes. - for i in 0..STATE_NUM_WORDS { - builder.slice_range_check_u8( - &local.state_mem[i].value().0, - local.shard, - local.channel, - local.do_memory_check, - ); - } - let mut sub_builder = SubAirBuilder::::new(builder, 0..NUM_KECCAK_COLS); diff --git a/core/src/syscall/precompiles/keccak256/columns.rs b/core/src/syscall/precompiles/keccak256/columns.rs index ad3aa5f099..a3e2dd3044 100644 --- a/core/src/syscall/precompiles/keccak256/columns.rs +++ b/core/src/syscall/precompiles/keccak256/columns.rs @@ -20,7 +20,6 @@ pub(crate) struct KeccakMemCols { pub shard: T, pub channel: T, pub clk: T, - pub nonce: T, pub state_addr: T, /// Memory columns for the state. diff --git a/core/src/syscall/precompiles/keccak256/execute.rs b/core/src/syscall/precompiles/keccak256/execute.rs index eecc747bed..d6c306c45f 100644 --- a/core/src/syscall/precompiles/keccak256/execute.rs +++ b/core/src/syscall/precompiles/keccak256/execute.rs @@ -99,11 +99,9 @@ impl Syscall for KeccakPermuteChip { // Push the Keccak permute event. let shard = rt.current_shard(); let channel = rt.current_channel(); - let lookup_id = rt.syscall_lookup_id; rt.record_mut() .keccak_permute_events .push(KeccakPermuteEvent { - lookup_id, shard, channel, clk: start_clk, diff --git a/core/src/syscall/precompiles/keccak256/mod.rs b/core/src/syscall/precompiles/keccak256/mod.rs index 2b95b8b400..4110707a83 100644 --- a/core/src/syscall/precompiles/keccak256/mod.rs +++ b/core/src/syscall/precompiles/keccak256/mod.rs @@ -15,7 +15,6 @@ const STATE_NUM_WORDS: usize = STATE_SIZE * 2; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KeccakPermuteEvent { - pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, diff --git a/core/src/syscall/precompiles/keccak256/trace.rs b/core/src/syscall/precompiles/keccak256/trace.rs index e4700fe97a..01b07fb743 100644 --- a/core/src/syscall/precompiles/keccak256/trace.rs +++ b/core/src/syscall/precompiles/keccak256/trace.rs @@ -83,12 +83,8 @@ impl MachineAir for KeccakPermuteChip { *read_record, &mut new_byte_lookup_events, ); - new_byte_lookup_events.add_u8_range_checks( - shard, - channel, - &read_record.value.to_le_bytes(), - ); } + cols.do_memory_check = F::one(); cols.receive_ecall = F::one(); } @@ -103,12 +99,8 @@ impl MachineAir for KeccakPermuteChip { *write_record, &mut new_byte_lookup_events, ); - new_byte_lookup_events.add_u8_range_checks( - shard, - channel, - &write_record.value.to_le_bytes(), - ); } + cols.do_memory_check = F::one(); } @@ -155,19 +147,10 @@ impl MachineAir for KeccakPermuteChip { } // Convert the trace to a row major matrix. - let mut trace = RowMajorMatrix::new( + RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), NUM_KECCAK_MEM_COLS, - ); - - // Write the nonce to the trace. - for i in 0..trace.height() { - let cols: &mut KeccakMemCols = - trace.values[i * NUM_KECCAK_MEM_COLS..(i + 1) * NUM_KECCAK_MEM_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - - trace + ) } fn included(&self, shard: &Self::Record) -> bool { diff --git a/core/src/syscall/precompiles/mod.rs b/core/src/syscall/precompiles/mod.rs index bc08c6856a..7b107e2d5b 100644 --- a/core/src/syscall/precompiles/mod.rs +++ b/core/src/syscall/precompiles/mod.rs @@ -1,3 +1,4 @@ +pub mod blake3; pub mod edwards; pub mod keccak256; pub mod sha256; @@ -19,7 +20,6 @@ use serde::{Deserialize, Serialize}; /// Elliptic curve add event. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ECAddEvent { - pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -67,9 +67,7 @@ pub fn create_ec_add_event( let p_memory_records = rt.mw_slice(p_ptr, &result_words); - println!("ec-add lookup id {:?}", rt.syscall_lookup_id); ECAddEvent { - lookup_id: rt.syscall_lookup_id, shard: rt.current_shard(), channel: rt.current_channel(), clk: start_clk, @@ -85,7 +83,6 @@ pub fn create_ec_add_event( /// Elliptic curve double event. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ECDoubleEvent { - pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -122,7 +119,6 @@ pub fn create_ec_double_event( let p_memory_records = rt.mw_slice(p_ptr, &result_words); ECDoubleEvent { - lookup_id: rt.syscall_lookup_id, shard: rt.current_shard(), channel: rt.current_channel(), clk: start_clk, @@ -135,7 +131,6 @@ pub fn create_ec_double_event( /// Elliptic curve point decompress event. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ECDecompressEvent { - pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -181,7 +176,6 @@ pub fn create_ec_decompress_event( let y_memory_records = rt.mw_slice(slice_ptr, &y_words); ECDecompressEvent { - lookup_id: rt.syscall_lookup_id, shard: rt.current_shard(), channel: rt.current_channel(), clk: start_clk, diff --git a/core/src/syscall/precompiles/sha256/compress/air.rs b/core/src/syscall/precompiles/sha256/compress/air.rs index 7a28a456b6..2f4bd5000a 100644 --- a/core/src/syscall/precompiles/sha256/compress/air.rs +++ b/core/src/syscall/precompiles/sha256/compress/air.rs @@ -30,12 +30,6 @@ where let local: &ShaCompressCols = (*local).borrow(); let next: &ShaCompressCols = (*next).borrow(); - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); - self.eval_control_flow_flags(builder, local, next); self.eval_memory(builder, local); @@ -52,7 +46,6 @@ where local.shard, local.channel, local.clk, - local.nonce, AB::F::from_canonical_u32(SyscallCode::SHA_COMPRESS.syscall_id()), local.w_ptr, local.h_ptr, @@ -78,15 +71,19 @@ impl ShaCompressChip { for i in 0..8 { octet_sum += local.octet[i].into(); } - builder.assert_one(octet_sum); + builder.when(local.is_real).assert_one(octet_sum); // Verify that the first row's octet value is correct. - builder.when_first_row().assert_one(local.octet[0]); + builder + .when_first_row() + .when(local.is_real) + .assert_one(local.octet[0]); // Verify correct transition for octet column. for i in 0..8 { builder .when_transition() + .when(next.is_real) .when(local.octet[i]) .assert_one(next.octet[(i + 1) % 8]) } @@ -101,15 +98,19 @@ impl ShaCompressChip { for i in 0..10 { octet_num_sum += local.octet_num[i].into(); } - builder.assert_one(octet_num_sum); + builder.when(local.is_real).assert_one(octet_num_sum); // The first row should have octet_num[0] = 1 if it's real. - builder.when_first_row().assert_one(local.octet_num[0]); + builder + .when_first_row() + .when(local.is_real) + .assert_one(local.octet_num[0]); // If current row is not last of an octet and next row is real, octet_num should be the same. for i in 0..10 { builder .when_transition() + .when(next.is_real) .when_not(local.octet[7]) .assert_eq(local.octet_num[i], next.octet_num[i]); } @@ -118,6 +119,7 @@ impl ShaCompressChip { for i in 0..10 { builder .when_transition() + .when(next.is_real) .when(local.octet[7]) .assert_eq(local.octet_num[i], next.octet_num[(i + 1) % 10]); } @@ -144,26 +146,19 @@ impl ShaCompressChip { .assert_word_eq(*var, *local.mem.value()); } - // Assert that the is_initialize flag is correct. - builder.assert_eq(local.is_initialize, local.octet_num[0] * local.is_real); - // Assert that the is_compression flag is correct. builder.assert_eq( local.is_compression, - (local.octet_num[1] + local.octet_num[1] + local.octet_num[2] + local.octet_num[3] + local.octet_num[4] + local.octet_num[5] + local.octet_num[6] + local.octet_num[7] - + local.octet_num[8]) - * local.is_real, + + local.octet_num[8], ); - // Assert that the is_finalize flag is correct. - builder.assert_eq(local.is_finalize, local.octet_num[9] * local.is_real); - builder.assert_eq( local.is_last_row.into(), local.octet[7] * local.octet_num[9], @@ -180,10 +175,6 @@ impl ShaCompressChip { .when(local.is_real) .when_not(local.is_last_row) .assert_eq(local.clk, next.clk); - builder - .when_transition() - .when_not(local.is_last_row) - .assert_eq(local.channel, next.channel); builder .when_transition() .when(local.is_real) @@ -195,9 +186,6 @@ impl ShaCompressChip { .when_not(local.is_last_row) .assert_eq(local.h_ptr, next.h_ptr); - // Assert that is_real is a bool. - builder.assert_bool(local.is_real); - // If this row is real and not the last cycle, then next row should also be real. builder .when_transition() @@ -205,12 +193,6 @@ impl ShaCompressChip { .when_not(local.is_last_row) .assert_one(next.is_real); - // Once the is_real flag is changed to false, it should not be changed back. - builder - .when_transition() - .when_not(local.is_real) - .assert_zero(next.is_real); - // Assert that the table ends in nonreal columns. Since each compress ecall is 80 cycles and // the table is padded to a power of 2, the last row of the table should always be padding. builder.when_last_row().assert_zero(local.is_real); @@ -218,13 +200,15 @@ impl ShaCompressChip { /// Constrains that memory address is correct and that memory is correctly written/read. fn eval_memory(&self, builder: &mut AB, local: &ShaCompressCols) { + let is_initialize = local.octet_num[0]; + let is_finalize = local.octet_num[9]; builder.eval_memory_access( local.shard, local.channel, - local.clk + local.is_finalize, + local.clk + is_finalize, local.mem_addr, &local.mem, - local.is_initialize + local.is_compression + local.is_finalize, + is_initialize + local.is_compression + is_finalize, ); // Calculate the current cycle_num. @@ -240,7 +224,7 @@ impl ShaCompressChip { } // Verify correct mem address for initialize phase - builder.when(local.is_initialize).assert_eq( + builder.when(is_initialize).assert_eq( local.mem_addr, local.h_ptr + cycle_step.clone() * AB::Expr::from_canonical_u32(4), ); @@ -255,7 +239,7 @@ impl ShaCompressChip { ); // Verify correct mem address for finalize phase - builder.when(local.is_finalize).assert_eq( + builder.when(is_finalize).assert_eq( local.mem_addr, local.h_ptr + cycle_step.clone() * AB::Expr::from_canonical_u32(4), ); @@ -267,11 +251,11 @@ impl ShaCompressChip { ]; for (i, var) in vars.iter().enumerate() { builder - .when(local.is_initialize) + .when(is_initialize) .when(local.octet[i]) .assert_word_eq(*var, *local.mem.prev_value()); builder - .when(local.is_initialize) + .when(is_initialize) .when(local.octet[i]) .assert_word_eq(*var, *local.mem.value()); } @@ -283,7 +267,7 @@ impl ShaCompressChip { // In the finalize phase, verify that the correct value is written to memory. builder - .when(local.is_finalize) + .when(is_finalize) .assert_word_eq(*local.mem.value(), local.finalize_add.value); } @@ -595,6 +579,7 @@ impl ShaCompressChip { builder: &mut AB, local: &ShaCompressCols, ) { + let is_finalize = local.octet_num[9]; // In the finalize phase, need to execute h[0] + a, h[1] + b, ..., h[7] + h, for each of the // phase's 8 rows. // We can get the needed operand (a,b,c,...,h) by doing an inner product between octet and @@ -611,7 +596,7 @@ impl ShaCompressChip { } builder - .when(local.is_finalize) + .when(is_finalize) .assert_word_eq(filtered_operand, local.finalized_operand.map(|x| x.into())); // finalize_add.result = h[i] + finalized_operand @@ -622,7 +607,7 @@ impl ShaCompressChip { local.finalize_add, local.shard, local.channel, - local.is_finalize.into(), + is_finalize.into(), ); // Memory write is constrained in constrain_memory. diff --git a/core/src/syscall/precompiles/sha256/compress/columns.rs b/core/src/syscall/precompiles/sha256/compress/columns.rs index 0fd7a7fbf4..94a200aedd 100644 --- a/core/src/syscall/precompiles/sha256/compress/columns.rs +++ b/core/src/syscall/precompiles/sha256/compress/columns.rs @@ -26,7 +26,6 @@ pub struct ShaCompressCols { /// Inputs. pub shard: T, pub channel: T, - pub nonce: T, pub clk: T, pub w_ptr: T, pub h_ptr: T, @@ -103,9 +102,7 @@ pub struct ShaCompressCols { pub finalized_operand: Word, pub finalize_add: AddOperation, - pub is_initialize: T, pub is_compression: T, - pub is_finalize: T, pub is_last_row: T, pub is_real: T, diff --git a/core/src/syscall/precompiles/sha256/compress/execute.rs b/core/src/syscall/precompiles/sha256/compress/execute.rs index 5ed33dd2b7..a019abbd4c 100644 --- a/core/src/syscall/precompiles/sha256/compress/execute.rs +++ b/core/src/syscall/precompiles/sha256/compress/execute.rs @@ -76,11 +76,9 @@ impl Syscall for ShaCompressChip { } // Push the SHA extend event. - let lookup_id = rt.syscall_lookup_id; let shard = rt.current_shard(); let channel = rt.current_channel(); rt.record_mut().sha_compress_events.push(ShaCompressEvent { - lookup_id, shard, channel, clk: start_clk, diff --git a/core/src/syscall/precompiles/sha256/compress/mod.rs b/core/src/syscall/precompiles/sha256/compress/mod.rs index 47401a25bc..fd6c50f0fc 100644 --- a/core/src/syscall/precompiles/sha256/compress/mod.rs +++ b/core/src/syscall/precompiles/sha256/compress/mod.rs @@ -20,7 +20,6 @@ pub const SHA_COMPRESS_K: [u32; 64] = [ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ShaCompressEvent { - pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, diff --git a/core/src/syscall/precompiles/sha256/compress/trace.rs b/core/src/syscall/precompiles/sha256/compress/trace.rs index 6cd524fbd6..bd0b8f8177 100644 --- a/core/src/syscall/precompiles/sha256/compress/trace.rs +++ b/core/src/syscall/precompiles/sha256/compress/trace.rs @@ -2,7 +2,6 @@ use std::borrow::BorrowMut; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; -use p3_matrix::Matrix; use super::{ columns::{ShaCompressCols, NUM_SHA_COMPRESS_COLS}, @@ -54,7 +53,6 @@ impl MachineAir for ShaCompressChip { cols.octet[j] = F::one(); cols.octet_num[octet_num_idx] = F::one(); - cols.is_initialize = F::one(); cols.mem.populate_read( channel, @@ -209,7 +207,6 @@ impl MachineAir for ShaCompressChip { cols.octet[j] = F::one(); cols.octet_num[octet_num_idx] = F::one(); - cols.is_finalize = F::one(); cols.finalize_add .populate(output, shard, channel, og_h[j], event.h[j]); @@ -252,48 +249,13 @@ impl MachineAir for ShaCompressChip { output.add_byte_lookup_events(new_byte_lookup_events); - let num_real_rows = rows.len(); - pad_rows(&mut rows, || [F::zero(); NUM_SHA_COMPRESS_COLS]); - // Set the octet_num and octect columns for the padded rows. - let mut octet_num = 0; - let mut octet = 0; - for row in rows[num_real_rows..].iter_mut() { - let cols: &mut ShaCompressCols = row.as_mut_slice().borrow_mut(); - cols.octet_num[octet_num] = F::one(); - cols.octet[octet] = F::one(); - - // If in the compression phase, set the k value. - if octet_num != 0 && octet_num != 9 { - let compression_idx = octet_num - 1; - let k_idx = compression_idx * 8 + octet; - cols.k = Word::from(SHA_COMPRESS_K[k_idx]); - } - - octet = (octet + 1) % 8; - if octet == 0 { - octet_num = (octet_num + 1) % 10; - } - - cols.is_last_row = cols.octet[7] * cols.octet_num[9]; - } - // Convert the trace to a row major matrix. - let mut trace = RowMajorMatrix::new( + RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), NUM_SHA_COMPRESS_COLS, - ); - - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut ShaCompressCols = trace.values - [i * NUM_SHA_COMPRESS_COLS..(i + 1) * NUM_SHA_COMPRESS_COLS] - .borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - - trace + ) } fn included(&self, shard: &Self::Record) -> bool { diff --git a/core/src/syscall/precompiles/sha256/extend/air.rs b/core/src/syscall/precompiles/sha256/extend/air.rs index 69f38c3557..9da6048048 100644 --- a/core/src/syscall/precompiles/sha256/extend/air.rs +++ b/core/src/syscall/precompiles/sha256/extend/air.rs @@ -27,13 +27,6 @@ where let (local, next) = (main.row_slice(0), main.row_slice(1)); let local: &ShaExtendCols = (*local).borrow(); let next: &ShaExtendCols = (*next).borrow(); - - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); - let i_start = AB::F::from_canonical_u32(16); let nb_bytes_in_word = AB::F::from_canonical_u32(4); @@ -49,10 +42,6 @@ where .when_transition() .when_not(local.cycle_16_end.result * local.cycle_48[2]) .assert_eq(local.clk, next.clk); - builder - .when_transition() - .when_not(local.cycle_16_end.result * local.cycle_48[2]) - .assert_eq(local.channel, next.channel); builder .when_transition() .when_not(local.cycle_16_end.result * local.cycle_48[2]) @@ -225,28 +214,22 @@ where local.is_real, ); - builder.assert_word_eq(*local.w_i.value(), local.s2.value); - // Receive syscall event in first row of 48-cycle. builder.receive_syscall( local.shard, local.channel, local.clk, - local.nonce, AB::F::from_canonical_u32(SyscallCode::SHA_EXTEND.syscall_id()), local.w_ptr, AB::Expr::zero(), local.cycle_48_start, ); - // Assert that is_real is a bool. - builder.assert_bool(local.is_real); - - // Ensure that all rows in a 48 row cycle has the same `is_real` values. + // If this row is real and not the last cycle, then next row should also be real. builder .when_transition() - .when_not(local.cycle_48_end) - .assert_eq(local.is_real, next.is_real); + .when(local.is_real - local.cycle_48_end) + .assert_one(next.is_real); // Assert that the table ends in nonreal columns. Since each extend ecall is 48 cycles and // the table is padded to a power of 2, the last row of the table should always be padding. diff --git a/core/src/syscall/precompiles/sha256/extend/columns.rs b/core/src/syscall/precompiles/sha256/extend/columns.rs index 0855b44139..5eb99e1f4d 100644 --- a/core/src/syscall/precompiles/sha256/extend/columns.rs +++ b/core/src/syscall/precompiles/sha256/extend/columns.rs @@ -18,7 +18,6 @@ pub struct ShaExtendCols { /// Inputs. pub shard: T, pub channel: T, - pub nonce: T, pub clk: T, pub w_ptr: T, @@ -37,9 +36,8 @@ pub struct ShaExtendCols { /// Flags for when in the first, second, or third 16-row cycle. pub cycle_48: [T; 3], - /// Whether the current row is the first of a 48-row cycle and is real. + /// Whether the current row is the first of a 48-row cycle. pub cycle_48_start: T, - /// Whether the current row is the end of a 48-row cycle and is real. pub cycle_48_end: T, /// Inputs to `s0`. diff --git a/core/src/syscall/precompiles/sha256/extend/execute.rs b/core/src/syscall/precompiles/sha256/extend/execute.rs index d9b1a70e09..bd163c26c9 100644 --- a/core/src/syscall/precompiles/sha256/extend/execute.rs +++ b/core/src/syscall/precompiles/sha256/extend/execute.rs @@ -60,11 +60,9 @@ impl Syscall for ShaExtendChip { } // Push the SHA extend event. - let lookup_id = rt.syscall_lookup_id; let shard = rt.current_shard(); let channel = rt.current_channel(); rt.record_mut().sha_extend_events.push(ShaExtendEvent { - lookup_id, shard, channel, clk: clk_init, diff --git a/core/src/syscall/precompiles/sha256/extend/flags.rs b/core/src/syscall/precompiles/sha256/extend/flags.rs index a06f117e3d..2f97dc92fc 100644 --- a/core/src/syscall/precompiles/sha256/extend/flags.rs +++ b/core/src/syscall/precompiles/sha256/extend/flags.rs @@ -7,7 +7,6 @@ use p3_field::PrimeField32; use p3_field::TwoAdicField; use p3_matrix::Matrix; -use crate::air::BaseAirBuilder; use crate::air::SP1AirBuilder; use crate::operations::IsZeroOperation; @@ -71,7 +70,7 @@ impl ShaExtendChip { builder, local.cycle_16 - AB::Expr::from(g), local.cycle_16_start, - one.clone(), + local.is_real.into(), ); // Constrain `cycle_16_end.result` to be `cycle_16 - 1 == 0`. Intuitively g^16 is 1. @@ -79,7 +78,7 @@ impl ShaExtendChip { builder, local.cycle_16 - AB::Expr::one(), local.cycle_16_end, - one.clone(), + local.is_real.into(), ); // Constrain `cycle_48` to be [1, 0, 0] in the first row. @@ -124,10 +123,10 @@ impl ShaExtendChip { .when(local.cycle_16_end.result * local.cycle_48[2]) .assert_eq(next.i, AB::F::from_canonical_u32(16)); - // When it's not the end of a 48-cycle, the next `i` must be the current plus one. + // When it's not the end of a 16-cycle, the next `i` must be the current plus one. builder .when_transition() - .when_not(local.cycle_16_end.result * local.cycle_48[2]) + .when(one.clone() - local.cycle_16_end.result) .assert_eq(local.i + one.clone(), next.i); } } diff --git a/core/src/syscall/precompiles/sha256/extend/mod.rs b/core/src/syscall/precompiles/sha256/extend/mod.rs index 7868cabd88..4caff508b9 100644 --- a/core/src/syscall/precompiles/sha256/extend/mod.rs +++ b/core/src/syscall/precompiles/sha256/extend/mod.rs @@ -11,7 +11,6 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ShaExtendEvent { - pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, diff --git a/core/src/syscall/precompiles/sha256/extend/trace.rs b/core/src/syscall/precompiles/sha256/extend/trace.rs index 2dcf882260..2a976ef0d6 100644 --- a/core/src/syscall/precompiles/sha256/extend/trace.rs +++ b/core/src/syscall/precompiles/sha256/extend/trace.rs @@ -1,7 +1,7 @@ +use std::borrow::BorrowMut; + use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; -use p3_matrix::Matrix; -use std::borrow::BorrowMut; use crate::{ air::MachineAir, @@ -156,19 +156,10 @@ impl MachineAir for ShaExtendChip { } // Convert the trace to a row major matrix. - let mut trace = RowMajorMatrix::new( + RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), NUM_SHA_EXTEND_COLS, - ); - - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut ShaExtendCols = - trace.values[i * NUM_SHA_EXTEND_COLS..(i + 1) * NUM_SHA_EXTEND_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - - trace + ) } fn included(&self, shard: &Self::Record) -> bool { diff --git a/core/src/syscall/precompiles/uint256/air.rs b/core/src/syscall/precompiles/uint256/air.rs index dd8cce29e4..498ac78c6c 100644 --- a/core/src/syscall/precompiles/uint256/air.rs +++ b/core/src/syscall/precompiles/uint256/air.rs @@ -17,7 +17,6 @@ use crate::utils::{ use generic_array::GenericArray; use num::Zero; use num::{BigUint, One}; -use p3_air::AirBuilder; use p3_air::{Air, BaseAir}; use p3_field::AbstractField; use p3_field::PrimeField32; @@ -34,7 +33,6 @@ const NUM_COLS: usize = size_of::>(); #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Uint256MulEvent { - pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -73,9 +71,6 @@ pub struct Uint256MulCols { /// The clock cycle of the syscall. pub clk: T, - /// The none of the operation. - pub nonce: T, - /// The pointer to the first input. pub x_ptr: T, @@ -206,17 +201,7 @@ impl MachineAir for Uint256MulChip { }); // Convert the trace to a row major matrix. - let mut trace = - RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_COLS); - - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut Uint256MulCols = - trace.values[i * NUM_COLS..(i + 1) * NUM_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - - trace + RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_COLS) } fn included(&self, shard: &Self::Record) -> bool { @@ -272,12 +257,10 @@ impl Syscall for Uint256MulChip { // Write the result to x and keep track of the memory records. let x_memory_records = rt.mw_slice(x_ptr, &result); - let lookup_id = rt.syscall_lookup_id; let shard = rt.current_shard(); let channel = rt.current_channel(); let clk = rt.clk; rt.record_mut().uint256_mul_events.push(Uint256MulEvent { - lookup_id, shard, channel, clk, @@ -310,14 +293,6 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &Uint256MulCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &Uint256MulCols = (*next).borrow(); - - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); // We are computing (x * y) % modulus. The value of x is stored in the "prev_value" of // the x_memory, since we write to it later. @@ -393,7 +368,6 @@ where local.shard, local.channel, local.clk, - local.nonce, AB::F::from_canonical_u32(SyscallCode::UINT256_MUL.syscall_id()), local.x_ptr, local.y_ptr, diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs index 2eef29c6c7..adbab629f0 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs @@ -52,7 +52,6 @@ pub struct WeierstrassAddAssignCols { pub is_real: T, pub shard: T, pub channel: T, - pub nonce: T, pub clk: T, pub p_ptr: T, pub q_ptr: T, @@ -303,21 +302,10 @@ impl MachineAir }); // Convert the trace to a row major matrix. - let mut trace = RowMajorMatrix::new( + RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), num_weierstrass_add_cols::(), - ); - - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut WeierstrassAddAssignCols = trace.values[i - * num_weierstrass_add_cols::() - ..(i + 1) * num_weierstrass_add_cols::()] - .borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - - trace + ) } fn included(&self, shard: &Self::Record) -> bool { @@ -343,125 +331,117 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let local = main.row_slice(0); - let local: &WeierstrassAddAssignCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &WeierstrassAddAssignCols = (*next).borrow(); - - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + let row = main.row_slice(0); + let row: &WeierstrassAddAssignCols = (*row).borrow(); let num_words_field_element = ::Limbs::USIZE / 4; - let p_x = limbs_from_prev_access(&local.p_access[0..num_words_field_element]); - let p_y = limbs_from_prev_access(&local.p_access[num_words_field_element..]); + let p_x = limbs_from_prev_access(&row.p_access[0..num_words_field_element]); + let p_y = limbs_from_prev_access(&row.p_access[num_words_field_element..]); - let q_x = limbs_from_prev_access(&local.q_access[0..num_words_field_element]); - let q_y = limbs_from_prev_access(&local.q_access[num_words_field_element..]); + let q_x = limbs_from_prev_access(&row.q_access[0..num_words_field_element]); + let q_y = limbs_from_prev_access(&row.q_access[num_words_field_element..]); // slope = (q.y - p.y) / (q.x - p.x). let slope = { - local.slope_numerator.eval( + row.slope_numerator.eval( builder, &q_y, &p_y, FieldOperation::Sub, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.slope_denominator.eval( + row.slope_denominator.eval( builder, &q_x, &p_x, FieldOperation::Sub, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.slope.eval( + row.slope.eval( builder, - &local.slope_numerator.result, - &local.slope_denominator.result, + &row.slope_numerator.result, + &row.slope_denominator.result, FieldOperation::Div, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - &local.slope.result + &row.slope.result }; // x = slope * slope - self.x - other.x. let x = { - local.slope_squared.eval( + row.slope_squared.eval( builder, slope, slope, FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.p_x_plus_q_x.eval( + row.p_x_plus_q_x.eval( builder, &p_x, &q_x, FieldOperation::Add, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.x3_ins.eval( + row.x3_ins.eval( builder, - &local.slope_squared.result, - &local.p_x_plus_q_x.result, + &row.slope_squared.result, + &row.p_x_plus_q_x.result, FieldOperation::Sub, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - &local.x3_ins.result + &row.x3_ins.result }; // y = slope * (p.x - x_3n) - q.y. { - local.p_x_minus_x.eval( + row.p_x_minus_x.eval( builder, &p_x, x, FieldOperation::Sub, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.slope_times_p_x_minus_x.eval( + row.slope_times_p_x_minus_x.eval( builder, slope, - &local.p_x_minus_x.result, + &row.p_x_minus_x.result, FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.y3_ins.eval( + row.y3_ins.eval( builder, - &local.slope_times_p_x_minus_x.result, + &row.slope_times_p_x_minus_x.result, &p_y, FieldOperation::Sub, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); } @@ -469,29 +449,29 @@ where // ensure that p_access is updated with the new value. for i in 0..E::BaseField::NB_LIMBS { builder - .when(local.is_real) - .assert_eq(local.x3_ins.result[i], local.p_access[i / 4].value()[i % 4]); - builder.when(local.is_real).assert_eq( - local.y3_ins.result[i], - local.p_access[num_words_field_element + i / 4].value()[i % 4], + .when(row.is_real) + .assert_eq(row.x3_ins.result[i], row.p_access[i / 4].value()[i % 4]); + builder.when(row.is_real).assert_eq( + row.y3_ins.result[i], + row.p_access[num_words_field_element + i / 4].value()[i % 4], ); } builder.eval_memory_access_slice( - local.shard, - local.channel, - local.clk.into(), - local.q_ptr, - &local.q_access, - local.is_real, + row.shard, + row.channel, + row.clk.into(), + row.q_ptr, + &row.q_access, + row.is_real, ); builder.eval_memory_access_slice( - local.shard, - local.channel, - local.clk + AB::F::from_canonical_u32(1), // We read p at +1 since p, q could be the same. - local.p_ptr, - &local.p_access, - local.is_real, + row.shard, + row.channel, + row.clk + AB::F::from_canonical_u32(1), // We read p at +1 since p, q could be the same. + row.p_ptr, + &row.p_access, + row.is_real, ); // Fetch the syscall id for the curve type. @@ -507,14 +487,13 @@ where }; builder.receive_syscall( - local.shard, - local.channel, - local.clk, - local.nonce, + row.shard, + row.channel, + row.clk, syscall_id_felt, - local.p_ptr, - local.q_ptr, - local.is_real, + row.p_ptr, + row.q_ptr, + row.is_real, ); } } diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs index 62958e86ca..bd38edea8e 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs @@ -54,7 +54,6 @@ pub struct WeierstrassDecompressCols { pub shard: T, pub channel: T, pub clk: T, - pub nonce: T, pub ptr: T, pub is_odd: T, pub x_access: GenericArray, P::WordsFieldElement>, @@ -223,21 +222,10 @@ impl MachineAir row }); - let mut trace = RowMajorMatrix::new( + RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), num_weierstrass_decompress_cols::(), - ); - - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut WeierstrassDecompressCols = trace.values[i - * num_weierstrass_decompress_cols::() - ..(i + 1) * num_weierstrass_decompress_cols::()] - .borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - - trace + ) } fn included(&self, shard: &Self::Record) -> bool { @@ -262,108 +250,99 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let local = main.row_slice(0); - let local: &WeierstrassDecompressCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &WeierstrassDecompressCols = (*next).borrow(); - - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + let row = main.row_slice(0); + let row: &WeierstrassDecompressCols = (*row).borrow(); let num_limbs = ::Limbs::USIZE; let num_words_field_element = num_limbs / 4; - builder.assert_bool(local.is_odd); + builder.assert_bool(row.is_odd); let x: Limbs::Limbs> = - limbs_from_prev_access(&local.x_access); - local - .range_x - .eval(builder, &x, local.shard, local.channel, local.is_real); - local.x_2.eval( + limbs_from_prev_access(&row.x_access); + row.range_x + .eval(builder, &x, row.shard, row.channel, row.is_real); + row.x_2.eval( builder, &x, &x, FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.x_3.eval( + row.x_3.eval( builder, - &local.x_2.result, + &row.x_2.result, &x, FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); let b = E::b_int(); let b_const = E::BaseField::to_limbs_field::(&b); - local.x_3_plus_b.eval( + row.x_3_plus_b.eval( builder, - &local.x_3.result, + &row.x_3.result, &b_const, FieldOperation::Add, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.neg_y.eval( + row.neg_y.eval( builder, &[AB::Expr::zero()].iter(), - &local.y.multiplication.result, + &row.y.multiplication.result, FieldOperation::Sub, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); // Interpret the lowest bit of Y as whether it is odd or not. - let y_is_odd = local.y.lsb; + let y_is_odd = row.y.lsb; - local.y.eval( + row.y.eval( builder, - &local.x_3_plus_b.result, - local.y.lsb, - local.shard, - local.channel, - local.is_real, + &row.x_3_plus_b.result, + row.y.lsb, + row.shard, + row.channel, + row.is_real, ); let y_limbs: Limbs::Limbs> = - limbs_from_access(&local.y_access); + limbs_from_access(&row.y_access); builder - .when(local.is_real) - .when_ne(y_is_odd, AB::Expr::one() - local.is_odd) - .assert_all_eq(local.y.multiplication.result, y_limbs); + .when(row.is_real) + .when_ne(y_is_odd, AB::Expr::one() - row.is_odd) + .assert_all_eq(row.y.multiplication.result, y_limbs); builder - .when(local.is_real) - .when_ne(y_is_odd, local.is_odd) - .assert_all_eq(local.neg_y.result, y_limbs); + .when(row.is_real) + .when_ne(y_is_odd, row.is_odd) + .assert_all_eq(row.neg_y.result, y_limbs); for i in 0..num_words_field_element { builder.eval_memory_access( - local.shard, - local.channel, - local.clk, - local.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4 + num_limbs as u32), - &local.x_access[i], - local.is_real, + row.shard, + row.channel, + row.clk, + row.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4 + num_limbs as u32), + &row.x_access[i], + row.is_real, ); } for i in 0..num_words_field_element { builder.eval_memory_access( - local.shard, - local.channel, - local.clk, - local.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4), - &local.y_access[i], - local.is_real, + row.shard, + row.channel, + row.clk, + row.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4), + &row.y_access[i], + row.is_real, ); } @@ -378,14 +357,13 @@ where }; builder.receive_syscall( - local.shard, - local.channel, - local.clk, - local.nonce, + row.shard, + row.channel, + row.clk, syscall_id, - local.ptr, - local.is_odd, - local.is_real, + row.ptr, + row.is_odd, + row.is_real, ); } } diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs index 9221d680f1..50bb0a4332 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs @@ -53,7 +53,6 @@ pub struct WeierstrassDoubleAssignCols { pub is_real: T, pub shard: T, pub channel: T, - pub nonce: T, pub clk: T, pub p_ptr: T, pub p_access: GenericArray, P::WordsCurvePoint>, @@ -318,21 +317,10 @@ impl MachineAir }); // Convert the trace to a row major matrix. - let mut trace = RowMajorMatrix::new( + RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), num_weierstrass_double_cols::(), - ); - - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut WeierstrassDoubleAssignCols = trace.values[i - * num_weierstrass_double_cols::() - ..(i + 1) * num_weierstrass_double_cols::()] - .borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - - trace + ) } fn included(&self, shard: &Self::Record) -> bool { @@ -358,143 +346,136 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let local = main.row_slice(0); - let local: &WeierstrassDoubleAssignCols = (*local).borrow(); - let next = main.row_slice(1); - let next: &WeierstrassDoubleAssignCols = (*next).borrow(); - - // Constrain the incrementing nonce. - builder.when_first_row().assert_zero(local.nonce); - builder - .when_transition() - .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + let row = main.row_slice(0); + let row: &WeierstrassDoubleAssignCols = (*row).borrow(); let num_words_field_element = E::BaseField::NB_LIMBS / 4; - let p_x = limbs_from_prev_access(&local.p_access[0..num_words_field_element]); - let p_y = limbs_from_prev_access(&local.p_access[num_words_field_element..]); + let p_x = limbs_from_prev_access(&row.p_access[0..num_words_field_element]); + let p_y = limbs_from_prev_access(&row.p_access[num_words_field_element..]); - // `a` in the Weierstrass form: y^2 = x^3 + a * x + b. + // a in the Weierstrass form: y^2 = x^3 + a * x + b. + // TODO: U32 can't be hardcoded here? let a = E::BaseField::to_limbs_field::(&E::a_int()); // slope = slope_numerator / slope_denominator. let slope = { // slope_numerator = a + (p.x * p.x) * 3. { - local.p_x_squared.eval( + row.p_x_squared.eval( builder, &p_x, &p_x, FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.p_x_squared_times_3.eval( + row.p_x_squared_times_3.eval( builder, - &local.p_x_squared.result, + &row.p_x_squared.result, &E::BaseField::to_limbs_field::(&BigUint::from(3u32)), FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.slope_numerator.eval( + row.slope_numerator.eval( builder, &a, - &local.p_x_squared_times_3.result, + &row.p_x_squared_times_3.result, FieldOperation::Add, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); }; // slope_denominator = 2 * y. - local.slope_denominator.eval( + row.slope_denominator.eval( builder, &E::BaseField::to_limbs_field::(&BigUint::from(2u32)), &p_y, FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.slope.eval( + row.slope.eval( builder, - &local.slope_numerator.result, - &local.slope_denominator.result, + &row.slope_numerator.result, + &row.slope_denominator.result, FieldOperation::Div, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - &local.slope.result + &row.slope.result }; // x = slope * slope - (p.x + p.x). let x = { - local.slope_squared.eval( + row.slope_squared.eval( builder, slope, slope, FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.p_x_plus_p_x.eval( + row.p_x_plus_p_x.eval( builder, &p_x, &p_x, FieldOperation::Add, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.x3_ins.eval( + row.x3_ins.eval( builder, - &local.slope_squared.result, - &local.p_x_plus_p_x.result, + &row.slope_squared.result, + &row.p_x_plus_p_x.result, FieldOperation::Sub, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - &local.x3_ins.result + &row.x3_ins.result }; // y = slope * (p.x - x) - p.y. { - local.p_x_minus_x.eval( + row.p_x_minus_x.eval( builder, &p_x, x, FieldOperation::Sub, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.slope_times_p_x_minus_x.eval( + row.slope_times_p_x_minus_x.eval( builder, slope, - &local.p_x_minus_x.result, + &row.p_x_minus_x.result, FieldOperation::Mul, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); - local.y3_ins.eval( + row.y3_ins.eval( builder, - &local.slope_times_p_x_minus_x.result, + &row.slope_times_p_x_minus_x.result, &p_y, FieldOperation::Sub, - local.shard, - local.channel, - local.is_real, + row.shard, + row.channel, + row.is_real, ); } @@ -502,21 +483,21 @@ where // ensure that p_access is updated with the new value. for i in 0..E::BaseField::NB_LIMBS { builder - .when(local.is_real) - .assert_eq(local.x3_ins.result[i], local.p_access[i / 4].value()[i % 4]); - builder.when(local.is_real).assert_eq( - local.y3_ins.result[i], - local.p_access[num_words_field_element + i / 4].value()[i % 4], + .when(row.is_real) + .assert_eq(row.x3_ins.result[i], row.p_access[i / 4].value()[i % 4]); + builder.when(row.is_real).assert_eq( + row.y3_ins.result[i], + row.p_access[num_words_field_element + i / 4].value()[i % 4], ); } builder.eval_memory_access_slice( - local.shard, - local.channel, - local.clk.into(), - local.p_ptr, - &local.p_access, - local.is_real, + row.shard, + row.channel, + row.clk.into(), + row.p_ptr, + &row.p_access, + row.is_real, ); // Fetch the syscall id for the curve type. @@ -532,14 +513,13 @@ where }; builder.receive_syscall( - local.shard, - local.channel, - local.clk, - local.nonce, + row.shard, + row.channel, + row.clk, syscall_id_felt, - local.p_ptr, + row.p_ptr, AB::Expr::zero(), - local.is_real, + row.is_real, ); } } diff --git a/core/src/syscall/verify.rs b/core/src/syscall/verify.rs index 11b0430103..e40639aeba 100644 --- a/core/src/syscall/verify.rs +++ b/core/src/syscall/verify.rs @@ -1,6 +1,6 @@ use crate::{ runtime::{Syscall, SyscallContext}, - stark::StarkGenericConfig, + stark::{RiscvAir, StarkGenericConfig}, utils::BabyBearPoseidon2Inner, }; @@ -38,6 +38,17 @@ impl Syscall for SyscallVerifySP1Proof { let config = BabyBearPoseidon2Inner::new(); let mut challenger = config.challenger(); + // TODO: need to use RecursionAir here + let machine = RiscvAir::machine(config); + + // TODO: Need to import PublicValues from recursion. + // Assert the commit in vkey from runtime inputs matches the one from syscall. + // Assert that the public values digest from runtime inputs matches the one from syscall. + + // TODO: Verify proof + // machine + // .verify(proof_vk, proof, &mut challenger) + // .expect("proof verification failed"); None } diff --git a/core/src/utils/programs.rs b/core/src/utils/programs.rs index a71a7dfef3..58af5a08c4 100644 --- a/core/src/utils/programs.rs +++ b/core/src/utils/programs.rs @@ -34,6 +34,9 @@ pub mod tests { pub const ED25519_ELF: &[u8] = include_bytes!("../../../tests/ed25519/elf/riscv32im-succinct-zkvm-elf"); + pub const BLAKE3_COMPRESS_ELF: &[u8] = + include_bytes!("../../../tests/blake3-compress/elf/riscv32im-succinct-zkvm-elf"); + pub const CYCLE_TRACKER_ELF: &[u8] = include_bytes!("../../../tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf"); diff --git a/prover/src/build.rs b/prover/src/build.rs index 24f152d49b..129ec454a8 100644 --- a/prover/src/build.rs +++ b/prover/src/build.rs @@ -37,6 +37,9 @@ pub fn try_install_plonk_bn254_artifacts() -> PathBuf { } /// Tries to build the PLONK artifacts inside the development directory. +/// +/// TODO: Maybe add some additional logic here to handle rebuilding the artifacts if they are +/// already built. pub fn try_build_plonk_bn254_artifacts_dev( template_vk: &StarkVerifyingKey, template_proof: &ShardProof, diff --git a/prover/src/install.rs b/prover/src/install.rs index a469a40b1c..873c50f685 100644 --- a/prover/src/install.rs +++ b/prover/src/install.rs @@ -10,7 +10,7 @@ use crate::utils::block_on; pub const PLONK_BN254_ARTIFACTS_URL_BASE: &str = "https://sp1-circuits.s3-us-east-2.amazonaws.com"; /// The current version of the plonk bn254 artifacts. -pub const PLONK_BN254_ARTIFACTS_COMMIT: &str = "4a525e9f"; +pub const PLONK_BN254_ARTIFACTS_COMMIT: &str = "e48c01ec"; /// Install the latest plonk bn254 artifacts. /// diff --git a/prover/src/verify.rs b/prover/src/verify.rs index f829f4d6a9..fedd467cfc 100644 --- a/prover/src/verify.rs +++ b/prover/src/verify.rs @@ -4,7 +4,6 @@ use anyhow::Result; use num_bigint::BigUint; use p3_baby_bear::BabyBear; use p3_field::{AbstractField, PrimeField}; -use sp1_core::air::MachineAir; use sp1_core::{ air::PublicValues, io::SP1PublicValues, @@ -46,7 +45,7 @@ impl SP1Prover { self.core_machine .verify(&vk.vk, &machine_proof, &mut challenger)?; - // Verify shard transitions. + // Verify shard transitions for (i, shard_proof) in proof.0.iter().enumerate() { let public_values = PublicValues::from_vec(shard_proof.public_values.clone()); // Verify shard transitions @@ -101,58 +100,6 @@ impl SP1Prover { } } - // Verify that the number of shards is not too large. - if proof.0.len() > 1 << 16 { - return Err(MachineVerificationError::TooManyShards); - } - - // Verify that the `MemoryInit` and `MemoryFinalize` chips are the last chips in the proof. - for (i, shard_proof) in proof.0.iter().enumerate() { - let chips = self - .core_machine - .shard_chips_ordered(&shard_proof.chip_ordering) - .collect::>(); - let program_memory_init_count = chips - .clone() - .into_iter() - .filter(|chip| chip.name() == "MemoryProgram") - .count(); - let memory_init_count = chips - .clone() - .into_iter() - .filter(|chip| chip.name() == "MemoryInit") - .count(); - let memory_final_count = chips - .into_iter() - .filter(|chip| chip.name() == "MemoryFinalize") - .count(); - - // Assert that the `MemoryProgram` chip only exists in the first shard. - if i == 0 && program_memory_init_count != 1 { - return Err(MachineVerificationError::InvalidChipOccurence( - "memory should exist in the first chip".to_string(), - )); - } - if i != 0 && program_memory_init_count > 0 { - return Err(MachineVerificationError::InvalidChipOccurence( - "memory program should not exist in the first chip".to_string(), - )); - } - - // Assert that the `MemoryInit` and `MemoryFinalize` chips only exist in the last shard. - if i != proof.0.len() - 1 && (memory_final_count > 0 || memory_init_count > 0) { - return Err(MachineVerificationError::InvalidChipOccurence( - "memory init and finalize should not eixst anywhere but the last chip" - .to_string(), - )); - } - if i == proof.0.len() - 1 && (memory_init_count != 1 || memory_final_count != 1) { - return Err(MachineVerificationError::InvalidChipOccurence( - "memory init and finalize should exist the last chip".to_string(), - )); - } - } - Ok(()) } diff --git a/recursion/circuit/Cargo.toml b/recursion/circuit/Cargo.toml index 1b5076d5cc..8843d11859 100644 --- a/recursion/circuit/Cargo.toml +++ b/recursion/circuit/Cargo.toml @@ -31,6 +31,3 @@ p3-poseidon2 = { workspace = true } zkhash = { git = "https://github.com/HorizenLabs/poseidon2" } rand = "0.8.5" sp1-recursion-gnark-ffi = { path = "../gnark-ffi" } - -[features] -plonk = ["sp1-recursion-gnark-ffi/plonk"] diff --git a/recursion/circuit/src/poseidon2.rs b/recursion/circuit/src/poseidon2.rs index 792754014d..a5a8cc1136 100644 --- a/recursion/circuit/src/poseidon2.rs +++ b/recursion/circuit/src/poseidon2.rs @@ -1,7 +1,5 @@ //! An implementation of Poseidon2 over BN254. -use std::array; - use itertools::Itertools; use p3_field::AbstractField; use p3_field::Field; @@ -18,8 +16,6 @@ pub trait Poseidon2CircuitBuilder { fn p2_permute_mut(&mut self, state: [Var; SPONGE_SIZE]); fn p2_hash(&mut self, input: &[Felt]) -> OuterDigestVariable; fn p2_compress(&mut self, input: [OuterDigestVariable; 2]) -> OuterDigestVariable; - fn p2_babybear_permute_mut(&mut self, state: [Felt; 16]); - fn p2_babybear_hash(&mut self, input: &[Felt]) -> [Felt; 8]; } impl Poseidon2CircuitBuilder for Builder { @@ -56,24 +52,6 @@ impl Poseidon2CircuitBuilder for Builder { self.p2_permute_mut(state); [state[0]; DIGEST_SIZE] } - - fn p2_babybear_permute_mut(&mut self, state: [Felt; 16]) { - self.push(DslIr::CircuitPoseidon2PermuteBabyBear(state)); - } - - fn p2_babybear_hash(&mut self, input: &[Felt]) -> [Felt; 8] { - let mut state: [Felt; 16] = array::from_fn(|_| self.eval(C::F::zero())); - - for block_chunk in &input.iter().chunks(8) { - state - .iter_mut() - .zip(block_chunk) - .for_each(|(s, i)| *s = self.eval(*i)); - self.p2_babybear_permute_mut(state); - } - - array::from_fn(|i| state[i]) - } } #[cfg(test)] @@ -82,9 +60,6 @@ pub mod tests { use p3_bn254_fr::Bn254Fr; use p3_field::AbstractField; use p3_symmetric::{CryptographicHasher, Permutation, PseudoCompressionFunction}; - use rand::thread_rng; - use rand::Rng; - use sp1_core::utils::{inner_perm, InnerHash}; use sp1_recursion_compiler::config::OuterConfig; use sp1_recursion_compiler::constraints::ConstraintCompiler; use sp1_recursion_compiler::ir::{Builder, Felt, Var, Witness}; @@ -120,25 +95,6 @@ pub mod tests { PlonkBn254Prover::test::(constraints.clone(), Witness::default()); } - #[test] - fn test_p2_babybear_permute_mut() { - let mut rng = thread_rng(); - let mut builder = Builder::::default(); - let input: [BabyBear; 16] = [rng.gen(); 16]; - let input_vars: [Felt<_>; 16] = input.map(|x| builder.eval(x)); - builder.p2_babybear_permute_mut(input_vars); - - let perm = inner_perm(); - let result = perm.permute(input); - for i in 0..16 { - builder.assert_felt_eq(input_vars[i], result[i]); - } - - let mut backend = ConstraintCompiler::::default(); - let constraints = backend.emit(builder.operations); - PlonkBn254Prover::test::(constraints.clone(), Witness::default()); - } - #[test] fn test_p2_hash() { let perm = outer_perm(); @@ -191,53 +147,4 @@ pub mod tests { let constraints = backend.emit(builder.operations); PlonkBn254Prover::test::(constraints.clone(), Witness::default()); } - - #[test] - fn test_p2_babybear_hash() { - let perm = inner_perm(); - let hasher = InnerHash::new(perm.clone()); - - let input: [BabyBear; 26] = [ - BabyBear::from_canonical_u32(0), - BabyBear::from_canonical_u32(1), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(2), - BabyBear::from_canonical_u32(3), - BabyBear::from_canonical_u32(3), - BabyBear::from_canonical_u32(3), - BabyBear::from_canonical_u32(3), - BabyBear::from_canonical_u32(3), - BabyBear::from_canonical_u32(3), - BabyBear::from_canonical_u32(3), - BabyBear::from_canonical_u32(3), - BabyBear::from_canonical_u32(3), - BabyBear::from_canonical_u32(3), - BabyBear::from_canonical_u32(3), - ]; - let output = hasher.hash_iter(input); - println!("{:?}", output); - - let mut builder = Builder::::default(); - let input_felts: [Felt<_>; 26] = input.map(|x| builder.eval(x)); - let result = builder.p2_babybear_hash(input_felts.as_slice()); - - for i in 0..8 { - builder.assert_felt_eq(result[i], output[i]); - } - - let mut backend = ConstraintCompiler::::default(); - let constraints = backend.emit(builder.operations); - PlonkBn254Prover::test::(constraints.clone(), Witness::default()); - } } diff --git a/recursion/circuit/src/stark.rs b/recursion/circuit/src/stark.rs index 574a801411..48c800aa13 100644 --- a/recursion/circuit/src/stark.rs +++ b/recursion/circuit/src/stark.rs @@ -2,7 +2,6 @@ use std::borrow::Borrow; use std::marker::PhantomData; use crate::fri::verify_two_adic_pcs; -use crate::poseidon2::Poseidon2CircuitBuilder; use crate::types::OuterDigestVariable; use crate::utils::{babybear_bytes_to_bn254, babybears_to_bn254, words_to_bytes}; use crate::witness::Witnessable; @@ -21,7 +20,7 @@ use sp1_recursion_compiler::constraints::{Constraint, ConstraintCompiler}; use sp1_recursion_compiler::ir::{Builder, Config, Ext, Felt, Var}; use sp1_recursion_compiler::ir::{Usize, Witness}; use sp1_recursion_compiler::prelude::SymbolicVar; -use sp1_recursion_core::air::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH}; +use sp1_recursion_core::air::RecursionPublicValues; use sp1_recursion_core::stark::config::{outer_fri_config, BabyBearPoseidon2Outer}; use sp1_recursion_core::stark::RecursionAirSkinnyDeg9; use sp1_recursion_program::commit::PolynomialSpaceVariable; @@ -271,9 +270,7 @@ pub fn build_wrap_circuit( let element = builder.get(&proof.public_values, i); pv_elements.push(element); } - let pv: &RecursionPublicValues<_> = pv_elements.as_slice().borrow(); - let one_felt: Felt<_> = builder.constant(BabyBear::one()); // Proof must be complete. In the reduce program, this will ensure that the SP1 proof has been // fully accumulated. @@ -350,13 +347,6 @@ pub fn build_wrap_circuit( } builder.assert_ext_eq(cumulative_sum, zero_ext); - // Verify the public values digest. - let calculated_digest = builder.p2_babybear_hash(&pv_elements[0..NUM_PV_ELMS_TO_HASH]); - let expected_digest = pv.digest; - for (calculated_elm, expected_elm) in calculated_digest.iter().zip(expected_digest.iter()) { - builder.assert_felt_eq(*expected_elm, *calculated_elm); - } - let mut backend = ConstraintCompiler::::default(); backend.emit(builder.operations) } diff --git a/recursion/compiler/src/constraints/mod.rs b/recursion/compiler/src/constraints/mod.rs index eb43951358..c5c67647a9 100644 --- a/recursion/compiler/src/constraints/mod.rs +++ b/recursion/compiler/src/constraints/mod.rs @@ -249,10 +249,6 @@ impl ConstraintCompiler { opcode: ConstraintOpcode::Permute, args: state.iter().map(|x| vec![x.id()]).collect(), }), - DslIr::CircuitPoseidon2PermuteBabyBear(state) => constraints.push(Constraint { - opcode: ConstraintOpcode::PermuteBabyBear, - args: state.iter().map(|x| vec![x.id()]).collect(), - }), DslIr::CircuitSelectV(cond, a, b, out) => { constraints.push(Constraint { opcode: ConstraintOpcode::SelectV, diff --git a/recursion/compiler/src/constraints/opcodes.rs b/recursion/compiler/src/constraints/opcodes.rs index 4911e0f108..581b4558da 100644 --- a/recursion/compiler/src/constraints/opcodes.rs +++ b/recursion/compiler/src/constraints/opcodes.rs @@ -46,5 +46,4 @@ pub enum ConstraintOpcode { CommitVkeyHash, CommitCommitedValuesDigest, CircuitFelts2Ext, - PermuteBabyBear, } diff --git a/recursion/compiler/src/ir/bits.rs b/recursion/compiler/src/ir/bits.rs index 396fb92be4..f69c8cee1d 100644 --- a/recursion/compiler/src/ir/bits.rs +++ b/recursion/compiler/src/ir/bits.rs @@ -26,15 +26,6 @@ impl Builder { output } - /// Range checks a variable to a certain number of bits. - pub fn range_check_v(&mut self, num: Var, num_bits: usize) { - let bits = self.num2bits_v(num); - self.range(num_bits, bits.len()).for_each(|i, builder| { - let bit = builder.get(&bits, i); - builder.assert_var_eq(bit, C::N::zero()); - }); - } - /// Converts a variable to bits inside a circuit. pub fn num2bits_v_circuit(&mut self, num: Var, bits: usize) -> Vec> { let mut output = Vec::new(); diff --git a/recursion/compiler/src/ir/instructions.rs b/recursion/compiler/src/ir/instructions.rs index 3826fc3270..f7a5cee3e0 100644 --- a/recursion/compiler/src/ir/instructions.rs +++ b/recursion/compiler/src/ir/instructions.rs @@ -201,8 +201,6 @@ pub enum DslIr { /// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should only /// be used when target is a gnark circuit. CircuitPoseidon2Permute([Var; 3]), - /// Permutates an array of BabyBear elements in the circuit. - CircuitPoseidon2PermuteBabyBear([Felt; 16]), // Miscellaneous instructions. /// Decompose hint operation of a usize into an array. (output = num2bits(usize)). diff --git a/recursion/core/src/air/builder.rs b/recursion/core/src/air/builder.rs index ab6e6c1017..6bcf20d408 100644 --- a/recursion/core/src/air/builder.rs +++ b/recursion/core/src/air/builder.rs @@ -30,8 +30,6 @@ pub trait RecursionMemoryAirBuilder: RecursionInteractionAirBuilder { is_real: impl Into, ) { let is_real: Self::Expr = is_real.into(); - self.assert_bool(is_real.clone()); - let timestamp: Self::Expr = timestamp.into(); let mem_access = memory_access.access(); @@ -68,8 +66,6 @@ pub trait RecursionMemoryAirBuilder: RecursionInteractionAirBuilder { is_real: impl Into, ) { let is_real: Self::Expr = is_real.into(); - self.assert_bool(is_real.clone()); - let timestamp: Self::Expr = timestamp.into(); let mem_access = memory_access.access(); diff --git a/recursion/core/src/cpu/air/branch.rs b/recursion/core/src/cpu/air/branch.rs index 105dd771da..91bebfa65e 100644 --- a/recursion/core/src/cpu/air/branch.rs +++ b/recursion/core/src/cpu/air/branch.rs @@ -3,9 +3,7 @@ use p3_field::{AbstractField, Field}; use sp1_core::air::{BinomialExtension, ExtensionAirBuilder}; use crate::{ - air::{ - BinomialExtensionUtils, Block, BlockBuilder, IsExtZeroOperation, SP1RecursionAirBuilder, - }, + air::{BinomialExtensionUtils, IsExtZeroOperation, SP1RecursionAirBuilder}, cpu::{CpuChip, CpuCols}, memory::MemoryCols, }; @@ -24,24 +22,18 @@ impl CpuChip { let is_branch_instruction = self.is_branch_instruction::(local); let one = AB::Expr::one(); + // If the instruction is a BNEINC, verify that the a value is incremented by one. + builder + .when(local.is_real) + .when(local.selectors.is_bneinc) + .assert_eq(local.a.value()[0], local.a.prev_value()[0] + one.clone()); + // Convert operand values from Block to BinomialExtension. Note that it gets the // previous value of the `a` and `b` operands, since BNENIC will modify `a`. - let a_prev_ext: BinomialExtension = - BinomialExtensionUtils::from_block(local.a.prev_value().map(|x| x.into())); let a_ext: BinomialExtension = BinomialExtensionUtils::from_block(local.a.value().map(|x| x.into())); let b_ext: BinomialExtension = BinomialExtensionUtils::from_block(local.b.value().map(|x| x.into())); - let one_ext: BinomialExtension = - BinomialExtensionUtils::from_block(Block::from(one.clone())); - - let expected_a_ext = a_prev_ext + one_ext; - - // If the instruction is a BNEINC, verify that the a value is incremented by one. - builder - .when(local.is_real) - .when(local.selectors.is_bneinc) - .assert_block_eq(a_ext.as_block(), expected_a_ext.as_block()); let comparison_diff = a_ext - b_ext; diff --git a/recursion/core/src/cpu/air/jump.rs b/recursion/core/src/cpu/air/jump.rs index dd5e9b8bba..bf86a70cce 100644 --- a/recursion/core/src/cpu/air/jump.rs +++ b/recursion/core/src/cpu/air/jump.rs @@ -2,7 +2,7 @@ use p3_air::AirBuilder; use p3_field::{AbstractField, Field}; use crate::{ - air::{Block, BlockBuilder, SP1RecursionAirBuilder}, + air::SP1RecursionAirBuilder, cpu::{CpuChip, CpuCols}, memory::MemoryCols, runtime::STACK_SIZE, @@ -21,29 +21,19 @@ impl CpuChip { ) where AB: SP1RecursionAirBuilder, { - let is_jump_instr = self.is_jump_instruction::(local); - // Verify the next row's fp. builder .when_first_row() .assert_eq(local.fp, F::from_canonical_usize(STACK_SIZE)); - let not_jump_instruction = AB::Expr::one() - is_jump_instr.clone(); + let not_jump_instruction = AB::Expr::one() - self.is_jump_instruction::(local); let expected_next_fp = local.selectors.is_jal * (local.fp + local.c.value()[0]) - + local.selectors.is_jalr * local.c.value()[0] + + local.selectors.is_jalr * local.a.value()[0] + not_jump_instruction * local.fp; builder .when_transition() .when(next.is_real) .assert_eq(next.fp, expected_next_fp); - // Verify the a operand values. - let expected_a_val = local.selectors.is_jal * local.pc - + local.selectors.is_jalr * (local.pc + AB::Expr::one()); - let expected_a_val_block = Block::from(expected_a_val); - builder - .when(is_jump_instr) - .assert_block_eq(*local.a.value(), expected_a_val_block); - // Add to the `next_pc` expression. *next_pc += local.selectors.is_jal * (local.pc + local.b.value()[0]); *next_pc += local.selectors.is_jalr * local.b.value()[0]; diff --git a/recursion/core/src/cpu/air/memory.rs b/recursion/core/src/cpu/air/memory.rs index d1b024130f..c0a3a2b639 100644 --- a/recursion/core/src/cpu/air/memory.rs +++ b/recursion/core/src/cpu/air/memory.rs @@ -30,7 +30,7 @@ impl CpuChip { local.clk + AB::F::from_canonical_u32(MemoryAccessPosition::Memory as u32), memory_cols.memory_addr, &memory_cols.memory, - is_memory_instr.clone(), + is_memory_instr, ); // Constraints on the memory column depending on load or store. @@ -41,7 +41,7 @@ impl CpuChip { ); // When there is a store, we ensure that we are writing the value of the a operand to the memory. builder - .when(is_memory_instr) + .when(local.selectors.is_store) .assert_block_eq(*local.a.value(), *memory_cols.memory.value()); } } diff --git a/recursion/core/src/multi/mod.rs b/recursion/core/src/multi/mod.rs index 10a9c3db0d..23173f7de9 100644 --- a/recursion/core/src/multi/mod.rs +++ b/recursion/core/src/multi/mod.rs @@ -190,7 +190,8 @@ where local.poseidon2_receive_table, ); sub_builder.assert_eq( - local.is_poseidon2 * Poseidon2Chip::do_memory_access::(poseidon2_columns), + local.is_poseidon2 + * Poseidon2Chip::do_memory_access::(poseidon2_columns), local.poseidon2_memory_access, ); @@ -200,7 +201,7 @@ where local.poseidon2(), next.poseidon2(), local.poseidon2_receive_table, - local.poseidon2_memory_access, + local.poseidon2_memory_access.into(), ); } } diff --git a/recursion/core/src/poseidon2/columns.rs b/recursion/core/src/poseidon2/columns.rs index fa12a655f2..12fa730477 100644 --- a/recursion/core/src/poseidon2/columns.rs +++ b/recursion/core/src/poseidon2/columns.rs @@ -11,10 +11,7 @@ pub struct Poseidon2Cols { pub left_input: T, pub right_input: T, pub rounds: [T; 24], // 1 round for memory input; 1 round for initialize; 8 rounds for external; 13 rounds for internal; 1 round for memory output - pub do_receive: T, - pub do_memory: T, pub round_specific_cols: RoundSpecificCols, - pub is_real: T, } #[derive(AlignedBorrow, Clone, Copy)] @@ -48,7 +45,6 @@ impl RoundSpecificCols { pub struct ComputationCols { pub input: [T; WIDTH], pub add_rc: [T; WIDTH], - pub sbox_deg_3: [T; WIDTH], pub sbox_deg_7: [T; WIDTH], pub output: [T; WIDTH], } diff --git a/recursion/core/src/poseidon2/external.rs b/recursion/core/src/poseidon2/external.rs index d340ba2b41..c871bd873d 100644 --- a/recursion/core/src/poseidon2/external.rs +++ b/recursion/core/src/poseidon2/external.rs @@ -6,6 +6,7 @@ use p3_field::AbstractField; use p3_matrix::Matrix; use sp1_core::air::{BaseAirBuilder, ExtensionAirBuilder, SP1AirBuilder}; use sp1_primitives::RC_16_30_U32; +use std::ops::Add; use crate::air::{RecursionInteractionAirBuilder, RecursionMemoryAirBuilder}; use crate::memory::MemoryCols; @@ -39,7 +40,7 @@ impl Poseidon2Chip { local: &Poseidon2Cols, next: &Poseidon2Cols, receive_table: AB::Var, - memory_access: AB::Var, + memory_access: AB::Expr, ) { const NUM_ROUNDS_F: usize = 8; const NUM_ROUNDS_P: usize = 13; @@ -65,10 +66,6 @@ impl Poseidon2Chip { .sum::(); let is_memory_write = local.rounds[local.rounds.len() - 1]; - self.eval_control_flow_and_inputs(builder, local, next); - - self.eval_syscall(builder, local, receive_table); - self.eval_mem( builder, local, @@ -87,71 +84,16 @@ impl Poseidon2Chip { is_internal_layer.clone(), NUM_ROUNDS_F + NUM_ROUNDS_P + 1, ); - } - - fn eval_control_flow_and_inputs( - &self, - builder: &mut AB, - local: &Poseidon2Cols, - next: &Poseidon2Cols, - ) { - let num_total_rounds = local.rounds.len(); - for i in 0..num_total_rounds { - // Verify that the round flags are correct. - builder.assert_bool(local.rounds[i]); - // Assert that the next round is correct. - builder - .when_transition() - .assert_eq(local.rounds[i], next.rounds[(i + 1) % num_total_rounds]); + self.eval_syscall(builder, local, receive_table); - if i != num_total_rounds - 1 { - builder - .when_transition() - .when(local.rounds[i]) - .assert_eq(local.clk, next.clk); - builder - .when_transition() - .when(local.rounds[i]) - .assert_eq(local.dst_input, next.dst_input); - builder - .when_transition() - .when(local.rounds[i]) - .assert_eq(local.left_input, next.left_input); - builder - .when_transition() - .when(local.rounds[i]) - .assert_eq(local.right_input, next.right_input); - } + // Range check all flags. + for i in 0..local.rounds.len() { + builder.assert_bool(local.rounds[i]); } - - // Ensure that at most one of the round flags is set. - let round_acc = local - .rounds - .iter() - .fold(AB::Expr::zero(), |acc, round_flag| acc + *round_flag); - builder.assert_bool(round_acc); - - // Verify the do_memory flag. - builder.assert_eq( - local.do_memory, - local.is_real * (local.rounds[0] + local.rounds[23]), + builder.assert_bool( + is_memory_read + is_initial + is_external_layer + is_internal_layer + is_memory_write, ); - - // Verify the do_receive flag. - builder.assert_eq(local.do_receive, local.is_real * local.rounds[0]); - - // Verify the first row starts at round 0. - builder.when_first_row().assert_one(local.rounds[0]); - // The round count is not a power of 2, so the last row should not be real. - builder.when_last_row().assert_zero(local.is_real); - - // Verify that all is_real flags within a round are equal. - let is_last_round = local.rounds[23]; - builder - .when_transition() - .when_not(is_last_round) - .assert_eq(local.is_real, next.is_real); } fn eval_mem( @@ -161,23 +103,20 @@ impl Poseidon2Chip { next: &Poseidon2Cols, is_memory_read: AB::Var, is_memory_write: AB::Var, - memory_access: AB::Var, + memory_access: AB::Expr, ) { let memory_access_cols = local.round_specific_cols.memory_access(); builder - .when(local.is_real) .when(is_memory_read) .assert_eq(local.left_input, memory_access_cols.addr_first_half); builder - .when(local.is_real) .when(is_memory_read) .assert_eq(local.right_input, memory_access_cols.addr_second_half); builder - .when(local.is_real) .when(is_memory_write) .assert_eq(local.dst_input, memory_access_cols.addr_first_half); - builder.when(local.is_real).when(is_memory_write).assert_eq( + builder.when(is_memory_write).assert_eq( local.dst_input + AB::F::from_canonical_usize(WIDTH / 2), memory_access_cols.addr_second_half, ); @@ -192,11 +131,7 @@ impl Poseidon2Chip { local.clk + AB::Expr::one() * is_memory_write, addr, &memory_access_cols.mem_access[i], - memory_access, - ); - builder.when(local.is_real).when(is_memory_read).assert_eq( - *memory_access_cols.mem_access[i].value(), - *memory_access_cols.mem_access[i].prev_value(), + memory_access.clone(), ); } @@ -204,14 +139,10 @@ impl Poseidon2Chip { // computation round. let next_computation_col = next.round_specific_cols.computation(); for i in 0..WIDTH { - builder - .when_transition() - .when(local.is_real) - .when(is_memory_read) - .assert_eq( - *memory_access_cols.mem_access[i].value(), - next_computation_col.input[i], - ); + builder.when_transition().when(is_memory_read).assert_eq( + *memory_access_cols.mem_access[i].value(), + next_computation_col.input[i], + ); } } @@ -253,7 +184,6 @@ impl Poseidon2Chip { } } builder - .when(local.is_real) .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) .assert_eq(result, computation_cols.add_rc[i]); } @@ -266,15 +196,8 @@ impl Poseidon2Chip { let sbox_deg_3 = computation_cols.add_rc[i] * computation_cols.add_rc[i] * computation_cols.add_rc[i]; + let sbox_deg_7 = sbox_deg_3.clone() * sbox_deg_3.clone() * computation_cols.add_rc[i]; builder - .when(local.is_real) - .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) - .assert_eq(computation_cols.sbox_deg_3[i], sbox_deg_3); - let sbox_deg_7 = computation_cols.sbox_deg_3[i] - * computation_cols.sbox_deg_3[i] - * computation_cols.add_rc[i]; - builder - .when(local.is_real) .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) .assert_eq(sbox_deg_7, computation_cols.sbox_deg_7[i]); } @@ -330,7 +253,6 @@ impl Poseidon2Chip { for i in 0..WIDTH { state[i] += sums[i % 4].clone(); builder - .when(local.is_real) .when(is_external_layer.clone() + is_initial.clone()) .assert_eq(state[i].clone(), computation_cols.output[i]); } @@ -342,7 +264,6 @@ impl Poseidon2Chip { let mut state: [AB::Expr; WIDTH] = sbox_result.clone(); internal_linear_layer(&mut state); builder - .when(local.is_real) .when(is_internal_layer.clone()) .assert_all_eq(state.clone(), computation_cols.output); } @@ -360,7 +281,6 @@ impl Poseidon2Chip { builder .when_transition() - .when(local.is_real) .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) .assert_eq(computation_cols.output[i], next_round_value); } @@ -387,11 +307,13 @@ impl Poseidon2Chip { } pub const fn do_receive_table(local: &Poseidon2Cols) -> T { - local.do_receive + local.rounds[0] } - pub fn do_memory_access(local: &Poseidon2Cols) -> T { - local.do_memory + pub fn do_memory_access, Output>( + local: &Poseidon2Cols, + ) -> Output { + local.rounds[0] + local.rounds[23] } } @@ -411,7 +333,7 @@ where local, next, Self::do_receive_table::(local), - Self::do_memory_access::(local), + Self::do_memory_access::(local), ); } } diff --git a/recursion/core/src/poseidon2/trace.rs b/recursion/core/src/poseidon2/trace.rs index 2d5639edd2..cc6a41d94f 100644 --- a/recursion/core/src/poseidon2/trace.rs +++ b/recursion/core/src/poseidon2/trace.rs @@ -49,9 +49,7 @@ impl MachineAir for Poseidon2Chip { for r in 0..rounds { let mut row = [F::zero(); NUM_POSEIDON2_COLS]; let cols: &mut Poseidon2Cols = row.as_mut_slice().borrow_mut(); - cols.is_real = F::one(); - let is_receive = r == 0; let is_memory_read = r == 0; let is_initial_layer = r == 1; let is_external_layer = @@ -80,10 +78,6 @@ impl MachineAir for Poseidon2Chip { cols.right_input = poseidon2_event.right; cols.rounds[r] = F::one(); - if is_receive { - cols.do_receive = F::one(); - } - if is_memory_read || is_memory_write { let memory_access_cols = cols.round_specific_cols.memory_access_mut(); @@ -103,7 +97,6 @@ impl MachineAir for Poseidon2Chip { .populate(&poseidon2_event.result_records[i]); } } - cols.do_memory = F::one(); } else { let computation_cols = cols.round_specific_cols.computation_mut(); @@ -138,7 +131,6 @@ impl MachineAir for Poseidon2Chip { let sbox_deg_3 = computation_cols.add_rc[j] * computation_cols.add_rc[j] * computation_cols.add_rc[j]; - computation_cols.sbox_deg_3[j] = sbox_deg_3; computation_cols.sbox_deg_7[j] = sbox_deg_3 * sbox_deg_3 * computation_cols.add_rc[j]; } @@ -171,8 +163,6 @@ impl MachineAir for Poseidon2Chip { } } - let num_real_rows = rows.len(); - // Pad the trace to a power of two. pad_rows_fixed( &mut rows, @@ -180,14 +170,6 @@ impl MachineAir for Poseidon2Chip { self.fixed_log2_rows, ); - let mut round_num = 0; - for row in rows[num_real_rows..].iter_mut() { - let cols: &mut Poseidon2Cols = row.as_mut_slice().borrow_mut(); - cols.rounds[round_num] = F::one(); - - round_num = (round_num + 1) % rounds; - } - // Convert the trace to a row major matrix. RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), diff --git a/recursion/gnark-ffi/Cargo.toml b/recursion/gnark-ffi/Cargo.toml index 467ba5bc08..d10ed2fd31 100644 --- a/recursion/gnark-ffi/Cargo.toml +++ b/recursion/gnark-ffi/Cargo.toml @@ -5,10 +5,8 @@ edition = "2021" [dependencies] p3-field = { workspace = true } -p3-symmetric = { workspace = true } p3-baby-bear = { workspace = true } sp1-recursion-compiler = { path = "../compiler" } -sp1-core = { path = "../../core" } serde = "1.0.201" serde_json = "1.0.117" tempfile = "3.10.1" diff --git a/recursion/gnark-ffi/go/main.go b/recursion/gnark-ffi/go/main.go index ed782400f2..89bba4a7e8 100644 --- a/recursion/gnark-ffi/go/main.go +++ b/recursion/gnark-ffi/go/main.go @@ -17,15 +17,11 @@ import ( "sync" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/plonk" "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test/unsafekzg" "github.com/succinctlabs/sp1-recursion-gnark/sp1" - "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" - "github.com/succinctlabs/sp1-recursion-gnark/sp1/poseidon2" ) func main() {} @@ -145,73 +141,3 @@ func TestMain() error { return nil } - -//export TestPoseidonBabyBear2 -func TestPoseidonBabyBear2() *C.char { - input := [poseidon2.BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - babybear.NewF("0"), - } - - expectedOutput := [poseidon2.BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("348670919"), - babybear.NewF("1568590631"), - babybear.NewF("1535107508"), - babybear.NewF("186917780"), - babybear.NewF("587749971"), - babybear.NewF("1827585060"), - babybear.NewF("1218809104"), - babybear.NewF("691692291"), - babybear.NewF("1480664293"), - babybear.NewF("1491566329"), - babybear.NewF("366224457"), - babybear.NewF("490018300"), - babybear.NewF("732772134"), - babybear.NewF("560796067"), - babybear.NewF("484676252"), - babybear.NewF("405025962"), - } - - circuit := sp1.TestPoseidon2BabyBearCircuit{Input: input, ExpectedOutput: expectedOutput} - assignment := sp1.TestPoseidon2BabyBearCircuit{Input: input, ExpectedOutput: expectedOutput} - - builder := r1cs.NewBuilder - r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), builder, &circuit) - if err != nil { - return C.CString(err.Error()) - } - - var pk groth16.ProvingKey - pk, err = groth16.DummySetup(r1cs) - if err != nil { - return C.CString(err.Error()) - } - - // Generate witness. - witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) - if err != nil { - return C.CString(err.Error()) - } - - // Generate the proof. - _, err = groth16.Prove(r1cs, pk, witness) - if err != nil { - return C.CString(err.Error()) - } - - return nil -} diff --git a/recursion/gnark-ffi/go/sp1/poseidon2/constants.go b/recursion/gnark-ffi/go/sp1/poseidon2/constants.go index edb5a5e4a7..fb350f180e 100644 --- a/recursion/gnark-ffi/go/sp1/poseidon2/constants.go +++ b/recursion/gnark-ffi/go/sp1/poseidon2/constants.go @@ -3,22 +3,11 @@ package poseidon2 import ( "github.com/consensys/gnark/frontend" - "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" ) -// Poseidon2 round constants for a state consisting of three BN254 field elements. var RC3 [NUM_EXTERNAL_ROUNDS + NUM_INTERNAL_ROUNDS][WIDTH]frontend.Variable -// Poseidon2 round constaints for a state consisting of 16 BabyBear field elements. - -var RC16 [30][BABYBEAR_WIDTH]babybear.Variable - func init() { - init_rc3() - init_rc16() -} - -func init_rc3() { round := 0 RC3[round] = [WIDTH]frontend.Variable{ @@ -468,580 +457,3 @@ func init_rc3() { frontend.Variable("0x0fc1bbceba0590f5abbdffa6d3b35e3297c021a3a409926d0e2d54dc1c84fda6"), } } - -func init_rc16() { - round := 0 - - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("2110014213"), - babybear.NewF("3964964605"), - babybear.NewF("2190662774"), - babybear.NewF("2732996483"), - babybear.NewF("640767983"), - babybear.NewF("3403899136"), - babybear.NewF("1716033721"), - babybear.NewF("1606702601"), - babybear.NewF("3759873288"), - babybear.NewF("1466015491"), - babybear.NewF("1498308946"), - babybear.NewF("2844375094"), - babybear.NewF("3042463841"), - babybear.NewF("1969905919"), - babybear.NewF("4109944726"), - babybear.NewF("3925048366"), - } - - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("3706859504"), - babybear.NewF("759122502"), - babybear.NewF("3167665446"), - babybear.NewF("1131812921"), - babybear.NewF("1080754908"), - babybear.NewF("4080114493"), - babybear.NewF("893583089"), - babybear.NewF("2019677373"), - babybear.NewF("3128604556"), - babybear.NewF("580640471"), - babybear.NewF("3277620260"), - babybear.NewF("842931656"), - babybear.NewF("548879852"), - babybear.NewF("3608554714"), - babybear.NewF("3575647916"), - babybear.NewF("81826002"), - } - - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("4289086263"), - babybear.NewF("1563933798"), - babybear.NewF("1440025885"), - babybear.NewF("184445025"), - babybear.NewF("2598651360"), - babybear.NewF("1396647410"), - babybear.NewF("1575877922"), - babybear.NewF("3303853401"), - babybear.NewF("137125468"), - babybear.NewF("765010148"), - babybear.NewF("633675867"), - babybear.NewF("2037803363"), - babybear.NewF("2573389828"), - babybear.NewF("1895729703"), - babybear.NewF("541515871"), - babybear.NewF("1783382863"), - } - - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("2641856484"), - babybear.NewF("3035743342"), - babybear.NewF("3672796326"), - babybear.NewF("245668751"), - babybear.NewF("2025460432"), - babybear.NewF("201609705"), - babybear.NewF("286217151"), - babybear.NewF("4093475563"), - babybear.NewF("2519572182"), - babybear.NewF("3080699870"), - babybear.NewF("2762001832"), - babybear.NewF("1244250808"), - babybear.NewF("606038199"), - babybear.NewF("3182740831"), - babybear.NewF("73007766"), - babybear.NewF("2572204153"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("1196780786"), - babybear.NewF("3447394443"), - babybear.NewF("747167305"), - babybear.NewF("2968073607"), - babybear.NewF("1053214930"), - babybear.NewF("1074411832"), - babybear.NewF("4016794508"), - babybear.NewF("1570312929"), - babybear.NewF("113576933"), - babybear.NewF("4042581186"), - babybear.NewF("3634515733"), - babybear.NewF("1032701597"), - babybear.NewF("2364839308"), - babybear.NewF("3840286918"), - babybear.NewF("888378655"), - babybear.NewF("2520191583"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("36046858"), - babybear.NewF("2927525953"), - babybear.NewF("3912129105"), - babybear.NewF("4004832531"), - babybear.NewF("193772436"), - babybear.NewF("1590247392"), - babybear.NewF("4125818172"), - babybear.NewF("2516251696"), - babybear.NewF("4050945750"), - babybear.NewF("269498914"), - babybear.NewF("1973292656"), - babybear.NewF("891403491"), - babybear.NewF("1845429189"), - babybear.NewF("2611996363"), - babybear.NewF("2310542653"), - babybear.NewF("4071195740"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("3505307391"), - babybear.NewF("786445290"), - babybear.NewF("3815313971"), - babybear.NewF("1111591756"), - babybear.NewF("4233279834"), - babybear.NewF("2775453034"), - babybear.NewF("1991257625"), - babybear.NewF("2940505809"), - babybear.NewF("2751316206"), - babybear.NewF("1028870679"), - babybear.NewF("1282466273"), - babybear.NewF("1059053371"), - babybear.NewF("834521354"), - babybear.NewF("138721483"), - babybear.NewF("3100410803"), - babybear.NewF("3843128331"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("3878220780"), - babybear.NewF("4058162439"), - babybear.NewF("1478942487"), - babybear.NewF("799012923"), - babybear.NewF("496734827"), - babybear.NewF("3521261236"), - babybear.NewF("755421082"), - babybear.NewF("1361409515"), - babybear.NewF("392099473"), - babybear.NewF("3178453393"), - babybear.NewF("4068463721"), - babybear.NewF("7935614"), - babybear.NewF("4140885645"), - babybear.NewF("2150748066"), - babybear.NewF("1685210312"), - babybear.NewF("3852983224"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("2896943075"), - babybear.NewF("3087590927"), - babybear.NewF("992175959"), - babybear.NewF("970216228"), - babybear.NewF("3473630090"), - babybear.NewF("3899670400"), - babybear.NewF("3603388822"), - babybear.NewF("2633488197"), - babybear.NewF("2479406964"), - babybear.NewF("2420952999"), - babybear.NewF("1852516800"), - babybear.NewF("4253075697"), - babybear.NewF("979699862"), - babybear.NewF("1163403191"), - babybear.NewF("1608599874"), - babybear.NewF("3056104448"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("3779109343"), - babybear.NewF("536205958"), - babybear.NewF("4183458361"), - babybear.NewF("1649720295"), - babybear.NewF("1444912244"), - babybear.NewF("3122230878"), - babybear.NewF("384301396"), - babybear.NewF("4228198516"), - babybear.NewF("1662916865"), - babybear.NewF("4082161114"), - babybear.NewF("2121897314"), - babybear.NewF("1706239958"), - babybear.NewF("4166959388"), - babybear.NewF("1626054781"), - babybear.NewF("3005858978"), - babybear.NewF("1431907253"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("1418914503"), - babybear.NewF("1365856753"), - babybear.NewF("3942715745"), - babybear.NewF("1429155552"), - babybear.NewF("3545642795"), - babybear.NewF("3772474257"), - babybear.NewF("1621094396"), - babybear.NewF("2154399145"), - babybear.NewF("826697382"), - babybear.NewF("1700781391"), - babybear.NewF("3539164324"), - babybear.NewF("652815039"), - babybear.NewF("442484755"), - babybear.NewF("2055299391"), - babybear.NewF("1064289978"), - babybear.NewF("1152335780"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("3417648695"), - babybear.NewF("186040114"), - babybear.NewF("3475580573"), - babybear.NewF("2113941250"), - babybear.NewF("1779573826"), - babybear.NewF("1573808590"), - babybear.NewF("3235694804"), - babybear.NewF("2922195281"), - babybear.NewF("1119462702"), - babybear.NewF("3688305521"), - babybear.NewF("1849567013"), - babybear.NewF("667446787"), - babybear.NewF("753897224"), - babybear.NewF("1896396780"), - babybear.NewF("3143026334"), - babybear.NewF("3829603876"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("859661334"), - babybear.NewF("3898844357"), - babybear.NewF("180258337"), - babybear.NewF("2321867017"), - babybear.NewF("3599002504"), - babybear.NewF("2886782421"), - babybear.NewF("3038299378"), - babybear.NewF("1035366250"), - babybear.NewF("2038912197"), - babybear.NewF("2920174523"), - babybear.NewF("1277696101"), - babybear.NewF("2785700290"), - babybear.NewF("3806504335"), - babybear.NewF("3518858933"), - babybear.NewF("654843672"), - babybear.NewF("2127120275"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("1548195514"), - babybear.NewF("2378056027"), - babybear.NewF("390914568"), - babybear.NewF("1472049779"), - babybear.NewF("1552596765"), - babybear.NewF("1905886441"), - babybear.NewF("1611959354"), - babybear.NewF("3653263304"), - babybear.NewF("3423946386"), - babybear.NewF("340857935"), - babybear.NewF("2208879480"), - babybear.NewF("139364268"), - babybear.NewF("3447281773"), - babybear.NewF("3777813707"), - babybear.NewF("55640413"), - babybear.NewF("4101901741"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("104929687"), - babybear.NewF("1459980974"), - babybear.NewF("1831234737"), - babybear.NewF("457139004"), - babybear.NewF("2581487628"), - babybear.NewF("2112044563"), - babybear.NewF("3567013861"), - babybear.NewF("2792004347"), - babybear.NewF("576325418"), - babybear.NewF("41126132"), - babybear.NewF("2713562324"), - babybear.NewF("151213722"), - babybear.NewF("2891185935"), - babybear.NewF("546846420"), - babybear.NewF("2939794919"), - babybear.NewF("2543469905"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("2191909784"), - babybear.NewF("3315138460"), - babybear.NewF("530414574"), - babybear.NewF("1242280418"), - babybear.NewF("1211740715"), - babybear.NewF("3993672165"), - babybear.NewF("2505083323"), - babybear.NewF("3845798801"), - babybear.NewF("538768466"), - babybear.NewF("2063567560"), - babybear.NewF("3366148274"), - babybear.NewF("1449831887"), - babybear.NewF("2408012466"), - babybear.NewF("294726285"), - babybear.NewF("3943435493"), - babybear.NewF("924016661"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("3633138367"), - babybear.NewF("3222789372"), - babybear.NewF("809116305"), - babybear.NewF("30100013"), - babybear.NewF("2655172876"), - babybear.NewF("2564247117"), - babybear.NewF("2478649732"), - babybear.NewF("4113689151"), - babybear.NewF("4120146082"), - babybear.NewF("2512308515"), - babybear.NewF("650406041"), - babybear.NewF("4240012393"), - babybear.NewF("2683508708"), - babybear.NewF("951073977"), - babybear.NewF("3460081988"), - babybear.NewF("339124269"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("130182653"), - babybear.NewF("2755946749"), - babybear.NewF("542600513"), - babybear.NewF("2816103022"), - babybear.NewF("1931786340"), - babybear.NewF("2044470840"), - babybear.NewF("1709908013"), - babybear.NewF("2938369043"), - babybear.NewF("3640399693"), - babybear.NewF("1374470239"), - babybear.NewF("2191149676"), - babybear.NewF("2637495682"), - babybear.NewF("4236394040"), - babybear.NewF("2289358846"), - babybear.NewF("3833368530"), - babybear.NewF("974546524"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("3306659113"), - babybear.NewF("2234814261"), - babybear.NewF("1188782305"), - babybear.NewF("223782844"), - babybear.NewF("2248980567"), - babybear.NewF("2309786141"), - babybear.NewF("2023401627"), - babybear.NewF("3278877413"), - babybear.NewF("2022138149"), - babybear.NewF("575851471"), - babybear.NewF("1612560780"), - babybear.NewF("3926656936"), - babybear.NewF("3318548977"), - babybear.NewF("2591863678"), - babybear.NewF("188109355"), - babybear.NewF("4217723909"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("1564209905"), - babybear.NewF("2154197895"), - babybear.NewF("2459687029"), - babybear.NewF("2870634489"), - babybear.NewF("1375012945"), - babybear.NewF("1529454825"), - babybear.NewF("306140690"), - babybear.NewF("2855578299"), - babybear.NewF("1246997295"), - babybear.NewF("3024298763"), - babybear.NewF("1915270363"), - babybear.NewF("1218245412"), - babybear.NewF("2479314020"), - babybear.NewF("2989827755"), - babybear.NewF("814378556"), - babybear.NewF("4039775921"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("1165280628"), - babybear.NewF("1203983801"), - babybear.NewF("3814740033"), - babybear.NewF("1919627044"), - babybear.NewF("600240215"), - babybear.NewF("773269071"), - babybear.NewF("486685186"), - babybear.NewF("4254048810"), - babybear.NewF("1415023565"), - babybear.NewF("502840102"), - babybear.NewF("4225648358"), - babybear.NewF("510217063"), - babybear.NewF("166444818"), - babybear.NewF("1430745893"), - babybear.NewF("1376516190"), - babybear.NewF("1775891321"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("1170945922"), - babybear.NewF("1105391877"), - babybear.NewF("261536467"), - babybear.NewF("1401687994"), - babybear.NewF("1022529847"), - babybear.NewF("2476446456"), - babybear.NewF("2603844878"), - babybear.NewF("3706336043"), - babybear.NewF("3463053714"), - babybear.NewF("1509644517"), - babybear.NewF("588552318"), - babybear.NewF("65252581"), - babybear.NewF("3696502656"), - babybear.NewF("2183330763"), - babybear.NewF("3664021233"), - babybear.NewF("1643809916"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("2922875898"), - babybear.NewF("3740690643"), - babybear.NewF("3932461140"), - babybear.NewF("161156271"), - babybear.NewF("2619943483"), - babybear.NewF("4077039509"), - babybear.NewF("2921201703"), - babybear.NewF("2085619718"), - babybear.NewF("2065264646"), - babybear.NewF("2615693812"), - babybear.NewF("3116555433"), - babybear.NewF("246100007"), - babybear.NewF("4281387154"), - babybear.NewF("4046141001"), - babybear.NewF("4027749321"), - babybear.NewF("111611860"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("2066954820"), - babybear.NewF("2502099969"), - babybear.NewF("2915053115"), - babybear.NewF("2362518586"), - babybear.NewF("366091708"), - babybear.NewF("2083204932"), - babybear.NewF("4138385632"), - babybear.NewF("3195157567"), - babybear.NewF("1318086382"), - babybear.NewF("521723799"), - babybear.NewF("702443405"), - babybear.NewF("2507670985"), - babybear.NewF("1760347557"), - babybear.NewF("2631999893"), - babybear.NewF("1672737554"), - babybear.NewF("1060867760"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("2359801781"), - babybear.NewF("2800231467"), - babybear.NewF("3010357035"), - babybear.NewF("1035997899"), - babybear.NewF("1210110952"), - babybear.NewF("1018506770"), - babybear.NewF("2799468177"), - babybear.NewF("1479380761"), - babybear.NewF("1536021911"), - babybear.NewF("358993854"), - babybear.NewF("579904113"), - babybear.NewF("3432144800"), - babybear.NewF("3625515809"), - babybear.NewF("199241497"), - babybear.NewF("4058304109"), - babybear.NewF("2590164234"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("1688530738"), - babybear.NewF("1580733335"), - babybear.NewF("2443981517"), - babybear.NewF("2206270565"), - babybear.NewF("2780074229"), - babybear.NewF("2628739677"), - babybear.NewF("2940123659"), - babybear.NewF("4145206827"), - babybear.NewF("3572278009"), - babybear.NewF("2779607509"), - babybear.NewF("1098718697"), - babybear.NewF("1424913749"), - babybear.NewF("2224415875"), - babybear.NewF("1108922178"), - babybear.NewF("3646272562"), - babybear.NewF("3935186184"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("820046587"), - babybear.NewF("1393386250"), - babybear.NewF("2665818575"), - babybear.NewF("2231782019"), - babybear.NewF("672377010"), - babybear.NewF("1920315467"), - babybear.NewF("1913164407"), - babybear.NewF("2029526876"), - babybear.NewF("2629271820"), - babybear.NewF("384320012"), - babybear.NewF("4112320585"), - babybear.NewF("3131824773"), - babybear.NewF("2347818197"), - babybear.NewF("2220997386"), - babybear.NewF("1772368609"), - babybear.NewF("2579960095"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("3544930873"), - babybear.NewF("225847443"), - babybear.NewF("3070082278"), - babybear.NewF("95643305"), - babybear.NewF("3438572042"), - babybear.NewF("3312856509"), - babybear.NewF("615850007"), - babybear.NewF("1863868773"), - babybear.NewF("803582265"), - babybear.NewF("3461976859"), - babybear.NewF("2903025799"), - babybear.NewF("1482092434"), - babybear.NewF("3902972499"), - babybear.NewF("3872341868"), - babybear.NewF("1530411808"), - babybear.NewF("2214923584"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("3118792481"), - babybear.NewF("2241076515"), - babybear.NewF("3983669831"), - babybear.NewF("3180915147"), - babybear.NewF("3838626501"), - babybear.NewF("1921630011"), - babybear.NewF("3415351771"), - babybear.NewF("2249953859"), - babybear.NewF("3755081630"), - babybear.NewF("486327260"), - babybear.NewF("1227575720"), - babybear.NewF("3643869379"), - babybear.NewF("2982026073"), - babybear.NewF("2466043731"), - babybear.NewF("1982634375"), - babybear.NewF("3769609014"), - } - round += 1 - RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("2195455495"), - babybear.NewF("2596863283"), - babybear.NewF("4244994973"), - babybear.NewF("1983609348"), - babybear.NewF("4019674395"), - babybear.NewF("3469982031"), - babybear.NewF("1458697570"), - babybear.NewF("1593516217"), - babybear.NewF("1963896497"), - babybear.NewF("3115309118"), - babybear.NewF("1659132465"), - babybear.NewF("2536770756"), - babybear.NewF("3059294171"), - babybear.NewF("2618031334"), - babybear.NewF("2040903247"), - babybear.NewF("3799795076"), - } -} diff --git a/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go deleted file mode 100644 index 9f83956234..0000000000 --- a/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go +++ /dev/null @@ -1,157 +0,0 @@ -package poseidon2 - -import ( - "github.com/consensys/gnark/frontend" - "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" -) - -const BABYBEAR_WIDTH = 16 -const BABYBEAR_NUM_EXTERNAL_ROUNDS = 8 -const BABYBEAR_NUM_INTERNAL_ROUNDS = 13 -const BABYBEAR_DEGREE = 7 - -type Poseidon2BabyBearChip struct { - api frontend.API - fieldApi *babybear.Chip -} - -func NewBabyBearChip(api frontend.API) *Poseidon2BabyBearChip { - return &Poseidon2BabyBearChip{ - api: api, - fieldApi: babybear.NewChip(api), - } -} - -func (p *Poseidon2BabyBearChip) PermuteMut(state *[BABYBEAR_WIDTH]babybear.Variable) { - // The initial linear layer. - p.externalLinearLayer(state) - - // The first half of the external rounds. - rounds := BABYBEAR_NUM_EXTERNAL_ROUNDS + BABYBEAR_NUM_INTERNAL_ROUNDS - roundsFBeggining := BABYBEAR_NUM_EXTERNAL_ROUNDS / 2 - for r := 0; r < roundsFBeggining; r++ { - p.addRc(state, RC16[r]) - p.sbox(state) - p.externalLinearLayer(state) - } - - // The internal rounds. - p_end := roundsFBeggining + BABYBEAR_NUM_INTERNAL_ROUNDS - for r := roundsFBeggining; r < p_end; r++ { - state[0] = p.fieldApi.AddF(state[0], RC16[r][0]) - state[0] = p.sboxP(state[0]) - p.diffusionPermuteMut(state) - } - - // The second half of the external rounds. - for r := p_end; r < rounds; r++ { - p.addRc(state, RC16[r]) - p.sbox(state) - p.externalLinearLayer(state) - } -} - -func (p *Poseidon2BabyBearChip) addRc(state *[BABYBEAR_WIDTH]babybear.Variable, rc [BABYBEAR_WIDTH]babybear.Variable) { - for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.AddF(state[i], rc[i]) - } -} - -func (p *Poseidon2BabyBearChip) sboxP(input babybear.Variable) babybear.Variable { - zero := babybear.NewF("0") - inputCpy := p.fieldApi.AddF(input, zero) - inputCpy = p.fieldApi.ReduceSlow(inputCpy) - inputValue := inputCpy.Value - i2 := p.api.Mul(inputValue, inputValue) - i4 := p.api.Mul(i2, i2) - i6 := p.api.Mul(i4, i2) - i7 := p.api.Mul(i6, inputValue) - i7bb := p.fieldApi.ReduceSlow(babybear.Variable{ - Value: i7, - NbBits: 31 * 7, - }) - return i7bb -} - -func (p *Poseidon2BabyBearChip) sbox(state *[BABYBEAR_WIDTH]babybear.Variable) { - for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.sboxP(state[i]) - } -} - -func (p *Poseidon2BabyBearChip) mdsLightPermutation4x4(state []babybear.Variable) { - t01 := p.fieldApi.AddF(state[0], state[1]) - t23 := p.fieldApi.AddF(state[2], state[3]) - t0123 := p.fieldApi.AddF(t01, t23) - t01123 := p.fieldApi.AddF(t0123, state[1]) - t01233 := p.fieldApi.AddF(t0123, state[3]) - state[3] = p.fieldApi.AddF(t01233, p.fieldApi.MulFConst(state[0], 2)) - state[1] = p.fieldApi.AddF(t01123, p.fieldApi.MulFConst(state[2], 2)) - state[0] = p.fieldApi.AddF(t01123, t01) - state[2] = p.fieldApi.AddF(t01233, t23) -} - -func (p *Poseidon2BabyBearChip) externalLinearLayer(state *[BABYBEAR_WIDTH]babybear.Variable) { - for i := 0; i < BABYBEAR_WIDTH; i += 4 { - p.mdsLightPermutation4x4(state[i : i+4]) - } - - sums := [4]babybear.Variable{ - state[0], - state[1], - state[2], - state[3], - } - for i := 4; i < BABYBEAR_WIDTH; i += 4 { - sums[0] = p.fieldApi.AddF(sums[0], state[i]) - sums[1] = p.fieldApi.AddF(sums[1], state[i+1]) - sums[2] = p.fieldApi.AddF(sums[2], state[i+2]) - sums[3] = p.fieldApi.AddF(sums[3], state[i+3]) - } - - for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.AddF(state[i], sums[i%4]) - } -} - -func (p *Poseidon2BabyBearChip) diffusionPermuteMut(state *[BABYBEAR_WIDTH]babybear.Variable) { - matInternalDiagM1 := [BABYBEAR_WIDTH]babybear.Variable{ - babybear.NewF("2013265919"), - babybear.NewF("1"), - babybear.NewF("2"), - babybear.NewF("4"), - babybear.NewF("8"), - babybear.NewF("16"), - babybear.NewF("32"), - babybear.NewF("64"), - babybear.NewF("128"), - babybear.NewF("256"), - babybear.NewF("512"), - babybear.NewF("1024"), - babybear.NewF("2048"), - babybear.NewF("4096"), - babybear.NewF("8192"), - babybear.NewF("32768"), - } - montyInverse := babybear.NewF("943718400") - p.matmulInternal(state, &matInternalDiagM1) - for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.MulF(state[i], montyInverse) - } - -} - -func (p *Poseidon2BabyBearChip) matmulInternal( - state *[BABYBEAR_WIDTH]babybear.Variable, - matInternalDiagM1 *[BABYBEAR_WIDTH]babybear.Variable, -) { - sum := babybear.NewF("0") - for i := 0; i < BABYBEAR_WIDTH; i++ { - sum = p.fieldApi.AddF(sum, state[i]) - } - - for i := 0; i < BABYBEAR_WIDTH; i++ { - state[i] = p.fieldApi.MulF(state[i], matInternalDiagM1[i]) - state[i] = p.fieldApi.AddF(state[i], sum) - } -} diff --git a/recursion/gnark-ffi/go/sp1/sp1.go b/recursion/gnark-ffi/go/sp1/sp1.go index ccde520953..f3f3b24a51 100644 --- a/recursion/gnark-ffi/go/sp1/sp1.go +++ b/recursion/gnark-ffi/go/sp1/sp1.go @@ -68,7 +68,6 @@ func (circuit *Circuit) Define(api frontend.API) error { } hashAPI := poseidon2.NewChip(api) - hashBabyBearAPI := poseidon2.NewBabyBearChip(api) fieldAPI := babybear.NewChip(api) vars := make(map[string]frontend.Variable) felts := make(map[string]babybear.Variable) @@ -133,15 +132,6 @@ func (circuit *Circuit) Define(api frontend.API) error { vars[cs.Args[0][0]] = state[0] vars[cs.Args[1][0]] = state[1] vars[cs.Args[2][0]] = state[2] - case "PermuteBabyBear": - var state [16]babybear.Variable - for i := 0; i < 16; i++ { - state[i] = felts[cs.Args[i][0]] - } - hashBabyBearAPI.PermuteMut(&state) - for i := 0; i < 16; i++ { - felts[cs.Args[i][0]] = state[i] - } case "SelectV": vars[cs.Args[0][0]] = api.Select(vars[cs.Args[1][0]], vars[cs.Args[2][0]], vars[cs.Args[3][0]]) case "SelectF": diff --git a/recursion/gnark-ffi/go/sp1/test.go b/recursion/gnark-ffi/go/sp1/test.go deleted file mode 100644 index 8d2aa8f0ae..0000000000 --- a/recursion/gnark-ffi/go/sp1/test.go +++ /dev/null @@ -1,31 +0,0 @@ -package sp1 - -import ( - "github.com/consensys/gnark/frontend" - "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" - "github.com/succinctlabs/sp1-recursion-gnark/sp1/poseidon2" -) - -type TestPoseidon2BabyBearCircuit struct { - Input [poseidon2.BABYBEAR_WIDTH]babybear.Variable `gnark:",public"` - ExpectedOutput [poseidon2.BABYBEAR_WIDTH]babybear.Variable `gnark:",public"` -} - -func (circuit *TestPoseidon2BabyBearCircuit) Define(api frontend.API) error { - poseidon2BabyBearChip := poseidon2.NewBabyBearChip(api) - fieldApi := babybear.NewChip(api) - - zero := babybear.NewF("0") - input := [poseidon2.BABYBEAR_WIDTH]babybear.Variable{} - for i := 0; i < poseidon2.BABYBEAR_WIDTH; i++ { - input[i] = fieldApi.AddF(circuit.Input[i], zero) - } - - poseidon2BabyBearChip.PermuteMut(&input) - - for i := 0; i < poseidon2.BABYBEAR_WIDTH; i++ { - fieldApi.AssertIsEqualF(circuit.ExpectedOutput[i], input[i]) - } - - return nil -} diff --git a/recursion/gnark-ffi/src/ffi.rs b/recursion/gnark-ffi/src/ffi.rs index 35a279ff27..d7ecf9d612 100644 --- a/recursion/gnark-ffi/src/ffi.rs +++ b/recursion/gnark-ffi/src/ffi.rs @@ -110,23 +110,6 @@ pub fn test_plonk_bn254(witness_json: &str, constraints_json: &str) { } } -pub fn test_babybear_poseidon2() { - cfg_if! { - if #[cfg(feature = "plonk")] { - unsafe { - let err_ptr = bind::TestPoseidonBabyBear2(); - if !err_ptr.is_null() { - // Safety: The error message is returned from the go code and is guaranteed to be valid. - let err = CString::from_raw(err_ptr); - panic!("TestPlonkBn254 failed: {}", err.into_string().unwrap()); - } - } - } else { - panic!("plonk feature not enabled"); - } - } -} - /// Converts a C string into a Rust String. /// /// # Safety @@ -157,20 +140,3 @@ impl C_PlonkBn254Proof { } } } - -#[cfg(test)] -mod tests { - use p3_baby_bear::BabyBear; - use p3_field::AbstractField; - use p3_symmetric::Permutation; - - #[cfg(feature = "plonk")] - #[test] - pub fn test_babybear_poseidon2() { - let perm = sp1_core::utils::inner_perm(); - let zeros = [BabyBear::zero(); 16]; - let result = perm.permute(zeros); - println!("{:?}", result); - super::test_babybear_poseidon2(); - } -} diff --git a/recursion/groth16/constraints.json b/recursion/groth16/constraints.json new file mode 100644 index 0000000000..28fc1fdac3 --- /dev/null +++ b/recursion/groth16/constraints.json @@ -0,0 +1 @@ +[{"opcode":"ImmV","args":[["var0"],["100"]]},{"opcode":"Num2BitsV","args":[["var1","var2","var3","var4","var5","var6","var7","var8","var9","var10","var11","var12","var13","var14","var15","var16","var17","var18","var19","var20","var21","var22","var23","var24","var25","var26","var27","var28","var29","var30","var31","var32"],["var0"],["32"]]},{"opcode":"ImmV","args":[["backend0"],["0"]]},{"opcode":"AssertEqV","args":[["var1"],["backend0"]]},{"opcode":"ImmV","args":[["backend1"],["0"]]},{"opcode":"AssertEqV","args":[["var2"],["backend1"]]},{"opcode":"ImmV","args":[["backend2"],["1"]]},{"opcode":"AssertEqV","args":[["var3"],["backend2"]]},{"opcode":"ImmV","args":[["backend3"],["0"]]},{"opcode":"AssertEqV","args":[["var4"],["backend3"]]},{"opcode":"ImmV","args":[["backend4"],["0"]]},{"opcode":"AssertEqV","args":[["var5"],["backend4"]]},{"opcode":"ImmV","args":[["backend5"],["1"]]},{"opcode":"AssertEqV","args":[["var6"],["backend5"]]},{"opcode":"ImmV","args":[["backend6"],["1"]]},{"opcode":"AssertEqV","args":[["var7"],["backend6"]]},{"opcode":"ImmV","args":[["backend7"],["0"]]},{"opcode":"AssertEqV","args":[["var8"],["backend7"]]},{"opcode":"ImmV","args":[["backend8"],["0"]]},{"opcode":"AssertEqV","args":[["var9"],["backend8"]]},{"opcode":"ImmV","args":[["backend9"],["0"]]},{"opcode":"AssertEqV","args":[["var10"],["backend9"]]},{"opcode":"ImmV","args":[["backend10"],["0"]]},{"opcode":"AssertEqV","args":[["var11"],["backend10"]]},{"opcode":"ImmV","args":[["backend11"],["0"]]},{"opcode":"AssertEqV","args":[["var12"],["backend11"]]},{"opcode":"ImmV","args":[["backend12"],["0"]]},{"opcode":"AssertEqV","args":[["var13"],["backend12"]]},{"opcode":"ImmV","args":[["backend13"],["0"]]},{"opcode":"AssertEqV","args":[["var14"],["backend13"]]},{"opcode":"ImmV","args":[["backend14"],["0"]]},{"opcode":"AssertEqV","args":[["var15"],["backend14"]]},{"opcode":"ImmV","args":[["backend15"],["0"]]},{"opcode":"AssertEqV","args":[["var16"],["backend15"]]},{"opcode":"ImmV","args":[["backend16"],["0"]]},{"opcode":"AssertEqV","args":[["var17"],["backend16"]]},{"opcode":"ImmV","args":[["backend17"],["0"]]},{"opcode":"AssertEqV","args":[["var18"],["backend17"]]},{"opcode":"ImmV","args":[["backend18"],["0"]]},{"opcode":"AssertEqV","args":[["var19"],["backend18"]]},{"opcode":"ImmV","args":[["backend19"],["0"]]},{"opcode":"AssertEqV","args":[["var20"],["backend19"]]},{"opcode":"ImmV","args":[["backend20"],["0"]]},{"opcode":"AssertEqV","args":[["var21"],["backend20"]]},{"opcode":"ImmV","args":[["backend21"],["0"]]},{"opcode":"AssertEqV","args":[["var22"],["backend21"]]},{"opcode":"ImmV","args":[["backend22"],["0"]]},{"opcode":"AssertEqV","args":[["var23"],["backend22"]]},{"opcode":"ImmV","args":[["backend23"],["0"]]},{"opcode":"AssertEqV","args":[["var24"],["backend23"]]},{"opcode":"ImmV","args":[["backend24"],["0"]]},{"opcode":"AssertEqV","args":[["var25"],["backend24"]]},{"opcode":"ImmV","args":[["backend25"],["0"]]},{"opcode":"AssertEqV","args":[["var26"],["backend25"]]},{"opcode":"ImmV","args":[["backend26"],["0"]]},{"opcode":"AssertEqV","args":[["var27"],["backend26"]]},{"opcode":"ImmV","args":[["backend27"],["0"]]},{"opcode":"AssertEqV","args":[["var28"],["backend27"]]},{"opcode":"ImmV","args":[["backend28"],["0"]]},{"opcode":"AssertEqV","args":[["var29"],["backend28"]]},{"opcode":"ImmV","args":[["backend29"],["0"]]},{"opcode":"AssertEqV","args":[["var30"],["backend29"]]},{"opcode":"ImmV","args":[["backend30"],["0"]]},{"opcode":"AssertEqV","args":[["var31"],["backend30"]]},{"opcode":"ImmV","args":[["backend31"],["0"]]},{"opcode":"AssertEqV","args":[["var32"],["backend31"]]}] \ No newline at end of file diff --git a/recursion/groth16/lib/libbabybear.a b/recursion/groth16/lib/libbabybear.a new file mode 100644 index 0000000000..e047c94965 Binary files /dev/null and b/recursion/groth16/lib/libbabybear.a differ diff --git a/recursion/groth16/main b/recursion/groth16/main new file mode 100755 index 0000000000..126a88bb45 Binary files /dev/null and b/recursion/groth16/main differ diff --git a/recursion/groth16/witness.json b/recursion/groth16/witness.json new file mode 100644 index 0000000000..ed4386877e --- /dev/null +++ b/recursion/groth16/witness.json @@ -0,0 +1 @@ +{"vars":["999"],"felts":["999"],"exts":[["999","0","0","0"]]} \ No newline at end of file diff --git a/recursion/program/src/machine/compress.rs b/recursion/program/src/machine/compress.rs index 406a7a04cf..f8fbc857bc 100644 --- a/recursion/program/src/machine/compress.rs +++ b/recursion/program/src/machine/compress.rs @@ -236,7 +236,6 @@ where challenger.observe(builder, element); } // verify proof. - let shard_idx = builder.eval(C::N::one()); StarkVerifier::::verify_shard( builder, &vk, @@ -244,7 +243,6 @@ where machine, &mut challenger, &proof, - shard_idx, ); // Load the public values from the proof. diff --git a/recursion/program/src/machine/core.rs b/recursion/program/src/machine/core.rs index 515bb1e7b9..1b86351109 100644 --- a/recursion/program/src/machine/core.rs +++ b/recursion/program/src/machine/core.rs @@ -160,18 +160,12 @@ where let cumulative_sum: Ext<_, _> = builder.eval(C::EF::zero().cons()); let current_pc: Felt<_> = builder.uninit(); let exit_code: Felt<_> = builder.uninit(); - - // Range check that the number of proofs is sufficiently small. - let num_shard_proofs: Var<_> = shard_proofs.len().materialize(builder); - builder.range_check_v(num_shard_proofs, 16); - // Verify proofs, validate transitions, and update accumulation variables. builder.range(0, shard_proofs.len()).for_each(|i, builder| { // Load the proof. let proof = builder.get(&shard_proofs, i); // Verify the shard proof. - let shard_idx = builder.eval(i + C::N::one()); let mut challenger = leaf_challenger.copy(builder); StarkVerifier::::verify_shard( builder, @@ -180,7 +174,6 @@ where machine, &mut challenger, &proof, - shard_idx, ); // Extract public values. @@ -270,9 +263,6 @@ where // Assert that exit code is the same for all proofs. builder.assert_felt_eq(exit_code, public_values.exit_code); - // Assert that the exit code is zero (success) for all proofs. - builder.assert_felt_eq(exit_code, C::F::zero()); - // Assert that the deferred proof digest is the same for all proofs. for (digest, current_digest) in deferred_proofs_digest .iter() diff --git a/recursion/program/src/machine/deferred.rs b/recursion/program/src/machine/deferred.rs index be380516b3..2ae232ab73 100644 --- a/recursion/program/src/machine/deferred.rs +++ b/recursion/program/src/machine/deferred.rs @@ -187,9 +187,7 @@ where let element = builder.get(&proof.public_values, j); challenger.observe(builder, element); } - - // Verify the proof. - let shard_idx = builder.eval(C::N::one()); + // verify the proof. StarkVerifier::::verify_shard( builder, &compress_vk, @@ -197,7 +195,6 @@ where machine, &mut challenger, &proof, - shard_idx, ); // Load the public values from the proof. diff --git a/recursion/program/src/machine/root.rs b/recursion/program/src/machine/root.rs index 4b3cb9e885..8e8eb72c67 100644 --- a/recursion/program/src/machine/root.rs +++ b/recursion/program/src/machine/root.rs @@ -107,16 +107,7 @@ where challenger.observe(builder, element); } // verify proof. - let shard_idx = builder.eval(C::N::one()); - StarkVerifier::::verify_shard( - builder, - &vk, - pcs, - machine, - &mut challenger, - proof, - shard_idx, - ); + StarkVerifier::::verify_shard(builder, &vk, pcs, machine, &mut challenger, proof); // Get the public inputs from the proof. let public_values_elements = (0..RECURSIVE_PROOF_NUM_PV_ELTS) diff --git a/recursion/program/src/stark.rs b/recursion/program/src/stark.rs index f3040fb7f5..aec5d9453d 100644 --- a/recursion/program/src/stark.rs +++ b/recursion/program/src/stark.rs @@ -12,6 +12,7 @@ use sp1_core::stark::StarkMachine; use sp1_core::stark::StarkVerifyingKey; use sp1_recursion_compiler::ir::Array; use sp1_recursion_compiler::ir::Ext; +use sp1_recursion_compiler::ir::ExtConst; use sp1_recursion_compiler::ir::SymbolicExt; use sp1_recursion_compiler::ir::SymbolicVar; use sp1_recursion_compiler::ir::Var; @@ -93,6 +94,48 @@ impl<'a, SC: StarkGenericConfig, A: MachineAir> VerifyingKeyHint<'a, SC } } +impl StarkRecursiveVerifier for StarkMachine +where + C::F: TwoAdicField, + SC: StarkGenericConfig< + Val = C::F, + Challenge = C::EF, + Domain = TwoAdicMultiplicativeCoset, + >, + A: MachineAir + for<'a> Air>, + C::F: TwoAdicField, + C::EF: TwoAdicField, + Com: Into<[SC::Val; DIGEST_SIZE]>, +{ + fn verify_shard( + &self, + builder: &mut Builder, + vk: &VerifyingKeyVariable, + pcs: &TwoAdicFriPcsVariable, + challenger: &mut DuplexChallengerVariable, + proof: &ShardProofVariable, + is_complete: impl Into::N>>, + ) { + // Verify the shard proof. + StarkVerifier::::verify_shard(builder, vk, pcs, self, challenger, proof); + + // Verify that the cumulative sum of the chip is zero if the shard is complete. + let cumulative_sum: Ext<_, _> = builder.uninit(); + builder + .range(0, proof.opened_values.chips.len()) + .for_each(|i, builder| { + let values = builder.get(&proof.opened_values.chips, i); + builder.assign(cumulative_sum, cumulative_sum + values.cumulative_sum); + }); + + builder + .if_eq(is_complete.into(), C::N::one()) + .then(|builder| { + builder.assert_ext_eq(cumulative_sum, C::EF::zero().cons()); + }); + } +} + pub type RecursiveVerifierConstraintFolder<'a, C> = GenericVerifierConstraintFolder< 'a, ::F, @@ -118,7 +161,6 @@ where machine: &StarkMachine, challenger: &mut DuplexChallengerVariable, proof: &ShardProofVariable, - shard_idx: Var, ) where A: MachineAir + for<'a> Air>, C::F: TwoAdicField, @@ -314,39 +356,6 @@ where builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); } - if chip.name() == "MemoryProgram" { - builder.if_eq(shard_idx, C::N::one()).then_or_else( - |builder| { - builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); - }, - |builder| { - builder.assert_var_eq(index, C::N::from_canonical_usize(EMPTY)); - }, - ); - } - - if chip.name() == "MemoryInit" { - builder.if_eq(shard_idx, C::N::one()).then_or_else( - |builder| { - builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); - }, - |builder| { - builder.assert_var_eq(index, C::N::from_canonical_usize(EMPTY)); - }, - ); - } - - if chip.name() == "MemoryFinalize" { - builder.if_eq(shard_idx, C::N::one()).then_or_else( - |builder| { - builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); - }, - |builder| { - builder.assert_var_eq(index, C::N::from_canonical_usize(EMPTY)); - }, - ); - } - builder .if_ne(index, C::N::from_canonical_usize(EMPTY)) .then(|builder| { diff --git a/tests/blake3-compress/Cargo.lock b/tests/blake3-compress/Cargo.lock new file mode 100644 index 0000000000..b45f827d8d --- /dev/null +++ b/tests/blake3-compress/Cargo.lock @@ -0,0 +1,760 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "anyhow" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" + +[[package]] +name = "arrayvec" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" + +[[package]] +name = "autocfg" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" + +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + +[[package]] +name = "blake3-compress-test" +version = "0.1.0" +dependencies = [ + "sp1-zkvm", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "byte-slice-cast" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + +[[package]] +name = "cpufeatures" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-bigint" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" +dependencies = [ + "generic-array", + "rand_core", + "subtle", + "zeroize", +] + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "der" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" +dependencies = [ + "const-oid", + "zeroize", +] + +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "const-oid", + "crypto-common", + "subtle", +] + +[[package]] +name = "ecdsa" +version = "0.16.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" +dependencies = [ + "der", + "digest", + "elliptic-curve", + "rfc6979", + "signature", + "spki", +] + +[[package]] +name = "elliptic-curve" +version = "0.13.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" +dependencies = [ + "base16ct", + "crypto-bigint", + "digest", + "ff", + "generic-array", + "group", + "pkcs8", + "rand_core", + "sec1", + "subtle", + "tap", + "zeroize", +] + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "ff" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" +dependencies = [ + "bitvec", + "rand_core", + "subtle", +] + +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", + "zeroize", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff", + "rand_core", + "subtle", +] + +[[package]] +name = "hashbrown" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + +[[package]] +name = "impl-trait-for-tuples" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d7a9f6330b71fea57921c9b61c47ee6e84f72d394754eff6163ae67e7395eb" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "indexmap" +version = "2.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "k256" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "956ff9b67e26e1a6a866cb758f12c6f8746208489e3e4a4b5580802f2f0a587b" +dependencies = [ + "cfg-if", + "ecdsa", + "elliptic-curve", + "once_cell", + "sha2", + "signature", +] + +[[package]] +name = "libc" +version = "0.2.155" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "memchr" +version = "2.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" + +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "parity-scale-codec" +version = "3.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "881331e34fa842a2fb61cc2db9643a8fedc615e47cfcc52597d1af0db9a7e8fe" +dependencies = [ + "arrayvec", + "byte-slice-cast", + "impl-trait-for-tuples", + "parity-scale-codec-derive", +] + +[[package]] +name = "parity-scale-codec-derive" +version = "3.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be30eaf4b0a9fba5336683b38de57bb86d179a35862ba6bfcf57625d006bde5b" +dependencies = [ + "proc-macro-crate 2.0.2", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro-crate" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" +dependencies = [ + "once_cell", + "toml_edit 0.19.15", +] + +[[package]] +name = "proc-macro-crate" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b00f26d3400549137f92511a46ac1cd8ce37cb5598a96d382381458b992a5d24" +dependencies = [ + "toml_datetime", + "toml_edit 0.20.2", +] + +[[package]] +name = "proc-macro2" +version = "1.0.78" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rfc6979" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" +dependencies = [ + "hmac", + "subtle", +] + +[[package]] +name = "scale-info" +version = "2.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c453e59a955f81fb62ee5d596b450383d699f152d350e9d23a0db2adb78e4c0" +dependencies = [ + "cfg-if", + "derive_more", + "parity-scale-codec", + "scale-info-derive", +] + +[[package]] +name = "scale-info-derive" +version = "2.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18cf6c6447f813ef19eb450e985bcce6705f9ce7660db221b59093d15c79c4b7" +dependencies = [ + "proc-macro-crate 1.3.1", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "sec1" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" +dependencies = [ + "base16ct", + "der", + "generic-array", + "pkcs8", + "subtle", + "zeroize", +] + +[[package]] +name = "serde" +version = "1.0.203" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.203" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.48", +] + +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core", +] + +[[package]] +name = "snowbridge-amcl" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460a9ed63cdf03c1b9847e8a12a5f5ba19c4efd5869e4a737e05be25d7c427e5" +dependencies = [ + "parity-scale-codec", + "scale-info", +] + +[[package]] +name = "sp1-precompiles" +version = "0.1.0" +dependencies = [ + "anyhow", + "bincode", + "cfg-if", + "getrandom", + "hex", + "k256", + "num", + "rand", + "serde", + "snowbridge-amcl", +] + +[[package]] +name = "sp1-zkvm" +version = "0.1.0" +dependencies = [ + "bincode", + "cfg-if", + "getrandom", + "k256", + "libm", + "once_cell", + "rand", + "serde", + "sha2", + "sp1-precompiles", +] + +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "toml_datetime" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" + +[[package]] +name = "toml_edit" +version = "0.19.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" +dependencies = [ + "indexmap", + "toml_datetime", + "winnow", +] + +[[package]] +name = "toml_edit" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "396e4d48bbb2b7554c944bde63101b5ae446cff6ec4a24227428f15eb72ef338" +dependencies = [ + "indexmap", + "toml_datetime", + "winnow", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "winnow" +version = "0.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876" +dependencies = [ + "memchr", +] + +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + +[[package]] +name = "zeroize" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" diff --git a/tests/blake3-compress/Cargo.toml b/tests/blake3-compress/Cargo.toml new file mode 100644 index 0000000000..e5987407cc --- /dev/null +++ b/tests/blake3-compress/Cargo.toml @@ -0,0 +1,8 @@ +[workspace] +[package] +version = "0.1.0" +name = "blake3-compress-test" +edition = "2021" + +[dependencies] +sp1-zkvm = { path = "../../zkvm/entrypoint" } diff --git a/tests/blake3-compress/elf/riscv32im-succinct-zkvm-elf b/tests/blake3-compress/elf/riscv32im-succinct-zkvm-elf new file mode 100755 index 0000000000..4e0fee0235 Binary files /dev/null and b/tests/blake3-compress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/blake3-compress/src/main.rs b/tests/blake3-compress/src/main.rs new file mode 100644 index 0000000000..6bbee4916f --- /dev/null +++ b/tests/blake3-compress/src/main.rs @@ -0,0 +1,42 @@ +#![no_main] +sp1_zkvm::entrypoint!(main); + +extern "C" { + fn syscall_blake3_compress_inner(p: *mut u32, q: *const u32); +} + +pub fn main() { + // The input message and state are simply 0, 1, ..., 95 followed by some fixed constants. + for _i in 0..10 { + let input_message: [u8; 64] = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + ]; + + let mut input_state: [u8; 64] = [ + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, + 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 103, 230, 9, 106, 133, 174, 103, 187, 114, 243, + 110, 60, 58, 245, 79, 165, 96, 0, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 97, 0, 0, 0, + ]; + + unsafe { + syscall_blake3_compress_inner( + input_state.as_mut_ptr() as *mut u32, + input_message.as_ptr() as *const u32, + ); + } + + // The expected output state is the result of compress_inner. + let output_state: [u8; 64] = [ + 239, 181, 94, 129, 58, 124, 80, 104, 126, 210, 5, 157, 255, 58, 238, 89, 252, 106, 170, + 12, 233, 56, 58, 31, 215, 16, 105, 97, 11, 229, 238, 73, 6, 79, 155, 180, 197, 73, 116, + 0, 127, 22, 16, 39, 116, 174, 85, 5, 61, 94, 87, 6, 236, 10, 36, 238, 119, 171, 207, + 171, 189, 216, 43, 250, + ]; + + assert_eq!(input_state, output_state); + } + + println!("done"); +} diff --git a/tests/bls12381-add/elf/riscv32im-succinct-zkvm-elf b/tests/bls12381-add/elf/riscv32im-succinct-zkvm-elf index 6e2c7e6866..fce1bf9ffe 100755 Binary files a/tests/bls12381-add/elf/riscv32im-succinct-zkvm-elf and b/tests/bls12381-add/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bls12381-add/src/main.rs b/tests/bls12381-add/src/main.rs index 681cf39afe..874e9f066e 100644 --- a/tests/bls12381-add/src/main.rs +++ b/tests/bls12381-add/src/main.rs @@ -6,48 +6,44 @@ extern "C" { } pub fn main() { - for _ in 0..4 { - // generator. - // 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 - // 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569 - let mut a: [u8; 96] = [ - 187, 198, 34, 219, 10, 240, 58, 251, 239, 26, 122, 249, 63, 232, 85, 108, 88, 172, 27, - 23, 63, 58, 78, 161, 5, 185, 116, 151, 79, 140, 104, 195, 15, 172, 169, 79, 140, 99, - 149, 38, 148, 215, 151, 49, 167, 211, 241, 23, 225, 231, 197, 70, 41, 35, 170, 12, 228, - 138, 136, 162, 68, 199, 60, 208, 237, 179, 4, 44, 203, 24, 219, 0, 246, 10, 208, 213, - 149, 224, 245, 252, 228, 138, 29, 116, 237, 48, 158, 160, 241, 160, 170, 227, 129, 244, - 179, 8, - ]; + // generator. + // 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 + // 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569 + let mut a: [u8; 96] = [ + 187, 198, 34, 219, 10, 240, 58, 251, 239, 26, 122, 249, 63, 232, 85, 108, 88, 172, 27, 23, + 63, 58, 78, 161, 5, 185, 116, 151, 79, 140, 104, 195, 15, 172, 169, 79, 140, 99, 149, 38, + 148, 215, 151, 49, 167, 211, 241, 23, 225, 231, 197, 70, 41, 35, 170, 12, 228, 138, 136, + 162, 68, 199, 60, 208, 237, 179, 4, 44, 203, 24, 219, 0, 246, 10, 208, 213, 149, 224, 245, + 252, 228, 138, 29, 116, 237, 48, 158, 160, 241, 160, 170, 227, 129, 244, 179, 8, + ]; - // 2 * generator. - // 838589206289216005799424730305866328161735431124665289961769162861615689790485775997575391185127590486775437397838 - // 3450209970729243429733164009999191867485184320918914219895632678707687208996709678363578245114137957452475385814312 - let b: [u8; 96] = [ - 78, 15, 191, 41, 85, 140, 154, 195, 66, 124, 28, 143, 187, 117, 143, 226, 42, 166, 88, - 195, 10, 45, 144, 67, 37, 1, 40, 145, 48, 219, 33, 151, 12, 69, 169, 80, 235, 200, 8, - 136, 70, 103, 77, 144, 234, 203, 114, 5, 40, 157, 116, 121, 25, 136, 134, 186, 27, 189, - 22, 205, 212, 217, 86, 76, 106, 215, 95, 29, 2, 185, 59, 247, 97, 228, 112, 134, 203, - 62, 186, 34, 56, 142, 157, 119, 115, 166, 253, 34, 163, 115, 198, 171, 140, 157, 106, - 22, - ]; + // 2 * generator. + // 838589206289216005799424730305866328161735431124665289961769162861615689790485775997575391185127590486775437397838 + // 3450209970729243429733164009999191867485184320918914219895632678707687208996709678363578245114137957452475385814312 + let b: [u8; 96] = [ + 78, 15, 191, 41, 85, 140, 154, 195, 66, 124, 28, 143, 187, 117, 143, 226, 42, 166, 88, 195, + 10, 45, 144, 67, 37, 1, 40, 145, 48, 219, 33, 151, 12, 69, 169, 80, 235, 200, 8, 136, 70, + 103, 77, 144, 234, 203, 114, 5, 40, 157, 116, 121, 25, 136, 134, 186, 27, 189, 22, 205, + 212, 217, 86, 76, 106, 215, 95, 29, 2, 185, 59, 247, 97, 228, 112, 134, 203, 62, 186, 34, + 56, 142, 157, 119, 115, 166, 253, 34, 163, 115, 198, 171, 140, 157, 106, 22, + ]; - unsafe { - syscall_bls12381_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); - } + unsafe { + syscall_bls12381_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); + } - // 3 * generator. - // 1527649530533633684281386512094328299672026648504329745640827351945739272160755686119065091946435084697047221031460 - // 487897572011753812113448064805964756454529228648704488481988876974355015977479905373670519228592356747638779818193 - let c: [u8; 96] = [ - 36, 82, 78, 2, 201, 192, 210, 150, 155, 23, 162, 44, 11, 122, 116, 129, 249, 63, 91, - 51, 81, 10, 120, 243, 241, 165, 233, 155, 31, 214, 18, 177, 151, 150, 169, 236, 45, 33, - 101, 23, 19, 240, 209, 249, 8, 227, 236, 9, 209, 48, 174, 144, 5, 59, 71, 163, 92, 244, - 74, 99, 108, 37, 69, 231, 230, 59, 212, 15, 49, 39, 156, 157, 127, 9, 195, 171, 221, - 12, 154, 166, 12, 248, 197, 137, 51, 98, 132, 138, 159, 176, 245, 166, 211, 128, 43, 3, - ]; + // 3 * generator. + // 1527649530533633684281386512094328299672026648504329745640827351945739272160755686119065091946435084697047221031460 + // 487897572011753812113448064805964756454529228648704488481988876974355015977479905373670519228592356747638779818193 + let c: [u8; 96] = [ + 36, 82, 78, 2, 201, 192, 210, 150, 155, 23, 162, 44, 11, 122, 116, 129, 249, 63, 91, 51, + 81, 10, 120, 243, 241, 165, 233, 155, 31, 214, 18, 177, 151, 150, 169, 236, 45, 33, 101, + 23, 19, 240, 209, 249, 8, 227, 236, 9, 209, 48, 174, 144, 5, 59, 71, 163, 92, 244, 74, 99, + 108, 37, 69, 231, 230, 59, 212, 15, 49, 39, 156, 157, 127, 9, 195, 171, 221, 12, 154, 166, + 12, 248, 197, 137, 51, 98, 132, 138, 159, 176, 245, 166, 211, 128, 43, 3, + ]; - assert_eq!(a, c); - } + assert_eq!(a, c); println!("done"); } diff --git a/tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf b/tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf index 3a8f2e1872..818954dc49 100755 Binary files a/tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf and b/tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bls12381-decompress/src/main.rs b/tests/bls12381-decompress/src/main.rs index a359fa9f7c..3e93a099f2 100644 --- a/tests/bls12381-decompress/src/main.rs +++ b/tests/bls12381-decompress/src/main.rs @@ -7,22 +7,19 @@ extern "C" { pub fn main() { let compressed_key: [u8; 48] = sp1_zkvm::io::read_vec().try_into().unwrap(); + let mut decompressed_key: [u8; 96] = [0u8; 96]; - for _ in 0..4 { - let mut decompressed_key: [u8; 96] = [0u8; 96]; + decompressed_key[..48].copy_from_slice(&compressed_key); - decompressed_key[..48].copy_from_slice(&compressed_key); + println!("before: {:?}", decompressed_key); - println!("before: {:?}", decompressed_key); + let is_odd = (decompressed_key[0] & 0b_0010_0000) >> 5 == 0; + decompressed_key[0] &= 0b_0001_1111; - let is_odd = (decompressed_key[0] & 0b_0010_0000) >> 5 == 0; - decompressed_key[0] &= 0b_0001_1111; - - unsafe { - syscall_bls12381_decompress(&mut decompressed_key, is_odd); - } - - println!("after: {:?}", decompressed_key); - sp1_zkvm::io::commit_slice(&decompressed_key); + unsafe { + syscall_bls12381_decompress(&mut decompressed_key, is_odd); } + println!("after: {:?}", decompressed_key); + + sp1_zkvm::io::commit_slice(&decompressed_key); } diff --git a/tests/bls12381-double/elf/riscv32im-succinct-zkvm-elf b/tests/bls12381-double/elf/riscv32im-succinct-zkvm-elf index 50470172a8..5c4706b8fa 100755 Binary files a/tests/bls12381-double/elf/riscv32im-succinct-zkvm-elf and b/tests/bls12381-double/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bls12381-mul/elf/riscv32im-succinct-zkvm-elf b/tests/bls12381-mul/elf/riscv32im-succinct-zkvm-elf index d1fe6cdf6f..313a6226f3 100755 Binary files a/tests/bls12381-mul/elf/riscv32im-succinct-zkvm-elf and b/tests/bls12381-mul/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bls12381-mul/src/main.rs b/tests/bls12381-mul/src/main.rs index 89169660a3..9f90906e92 100644 --- a/tests/bls12381-mul/src/main.rs +++ b/tests/bls12381-mul/src/main.rs @@ -6,42 +6,39 @@ use sp1_zkvm::precompiles::utils::AffinePoint; #[sp1_derive::cycle_tracker] pub fn main() { - for _ in 0..4 { - // generator. - // 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 - // 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569 - let a: [u8; 96] = [ - 187, 198, 34, 219, 10, 240, 58, 251, 239, 26, 122, 249, 63, 232, 85, 108, 88, 172, 27, - 23, 63, 58, 78, 161, 5, 185, 116, 151, 79, 140, 104, 195, 15, 172, 169, 79, 140, 99, - 149, 38, 148, 215, 151, 49, 167, 211, 241, 23, 225, 231, 197, 70, 41, 35, 170, 12, 228, - 138, 136, 162, 68, 199, 60, 208, 237, 179, 4, 44, 203, 24, 219, 0, 246, 10, 208, 213, - 149, 224, 245, 252, 228, 138, 29, 116, 237, 48, 158, 160, 241, 160, 170, 227, 129, 244, - 179, 8, - ]; - - let mut a_point = AffinePoint::::from_le_bytes(&a); - - // scalar. - // 3 - let scalar: [u32; 12] = [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - - println!("cycle-tracker-start: bn254_mul"); - a_point.mul_assign(&scalar); - println!("cycle-tracker-end: bn254_mul"); - - // 3 * generator. - // 1527649530533633684281386512094328299672026648504329745640827351945739272160755686119065091946435084697047221031460 - // 487897572011753812113448064805964756454529228648704488481988876974355015977479905373670519228592356747638779818193 - let c: [u8; 96] = [ - 36, 82, 78, 2, 201, 192, 210, 150, 155, 23, 162, 44, 11, 122, 116, 129, 249, 63, 91, - 51, 81, 10, 120, 243, 241, 165, 233, 155, 31, 214, 18, 177, 151, 150, 169, 236, 45, 33, - 101, 23, 19, 240, 209, 249, 8, 227, 236, 9, 209, 48, 174, 144, 5, 59, 71, 163, 92, 244, - 74, 99, 108, 37, 69, 231, 230, 59, 212, 15, 49, 39, 156, 157, 127, 9, 195, 171, 221, - 12, 154, 166, 12, 248, 197, 137, 51, 98, 132, 138, 159, 176, 245, 166, 211, 128, 43, 3, - ]; - - assert_eq!(a_point.to_le_bytes(), c); - } + // generator. + // 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 + // 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569 + let a: [u8; 96] = [ + 187, 198, 34, 219, 10, 240, 58, 251, 239, 26, 122, 249, 63, 232, 85, 108, 88, 172, 27, 23, + 63, 58, 78, 161, 5, 185, 116, 151, 79, 140, 104, 195, 15, 172, 169, 79, 140, 99, 149, 38, + 148, 215, 151, 49, 167, 211, 241, 23, 225, 231, 197, 70, 41, 35, 170, 12, 228, 138, 136, + 162, 68, 199, 60, 208, 237, 179, 4, 44, 203, 24, 219, 0, 246, 10, 208, 213, 149, 224, 245, + 252, 228, 138, 29, 116, 237, 48, 158, 160, 241, 160, 170, 227, 129, 244, 179, 8, + ]; + + let mut a_point = AffinePoint::::from_le_bytes(&a); + + // scalar. + // 3 + let scalar: [u32; 12] = [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + + println!("cycle-tracker-start: bn254_mul"); + a_point.mul_assign(&scalar); + println!("cycle-tracker-end: bn254_mul"); + + // 3 * generator. + // 1527649530533633684281386512094328299672026648504329745640827351945739272160755686119065091946435084697047221031460 + // 487897572011753812113448064805964756454529228648704488481988876974355015977479905373670519228592356747638779818193 + let c: [u8; 96] = [ + 36, 82, 78, 2, 201, 192, 210, 150, 155, 23, 162, 44, 11, 122, 116, 129, 249, 63, 91, 51, + 81, 10, 120, 243, 241, 165, 233, 155, 31, 214, 18, 177, 151, 150, 169, 236, 45, 33, 101, + 23, 19, 240, 209, 249, 8, 227, 236, 9, 209, 48, 174, 144, 5, 59, 71, 163, 92, 244, 74, 99, + 108, 37, 69, 231, 230, 59, 212, 15, 49, 39, 156, 157, 127, 9, 195, 171, 221, 12, 154, 166, + 12, 248, 197, 137, 51, 98, 132, 138, 159, 176, 245, 166, 211, 128, 43, 3, + ]; + + assert_eq!(a_point.to_le_bytes(), c); println!("done"); } diff --git a/tests/bn254-add/elf/riscv32im-succinct-zkvm-elf b/tests/bn254-add/elf/riscv32im-succinct-zkvm-elf index a55b917d17..a45b52cd9d 100755 Binary files a/tests/bn254-add/elf/riscv32im-succinct-zkvm-elf and b/tests/bn254-add/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bn254-add/src/main.rs b/tests/bn254-add/src/main.rs index 406681d656..7e164663da 100644 --- a/tests/bn254-add/src/main.rs +++ b/tests/bn254-add/src/main.rs @@ -6,42 +6,40 @@ extern "C" { } pub fn main() { - for _ in 0..4 { - // generator. - // 1 - // 2 - let mut a: [u8; 64] = [ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - ]; + // generator. + // 1 + // 2 + let mut a: [u8; 64] = [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]; - // 2 * generator. - // 1368015179489954701390400359078579693043519447331113978918064868415326638035 - // 9918110051302171585080402603319702774565515993150576347155970296011118125764 - let b: [u8; 64] = [ - 211, 207, 135, 109, 193, 8, 194, 211, 168, 28, 135, 22, 169, 22, 120, 217, 133, 21, 24, - 104, 91, 4, 133, 155, 2, 26, 19, 46, 231, 68, 6, 3, 196, 162, 24, 90, 122, 191, 62, - 255, 199, 143, 83, 227, 73, 164, 166, 104, 10, 156, 174, 178, 150, 95, 132, 231, 146, - 124, 10, 14, 140, 115, 237, 21, - ]; + // 2 * generator. + // 1368015179489954701390400359078579693043519447331113978918064868415326638035 + // 9918110051302171585080402603319702774565515993150576347155970296011118125764 + let b: [u8; 64] = [ + 211, 207, 135, 109, 193, 8, 194, 211, 168, 28, 135, 22, 169, 22, 120, 217, 133, 21, 24, + 104, 91, 4, 133, 155, 2, 26, 19, 46, 231, 68, 6, 3, 196, 162, 24, 90, 122, 191, 62, 255, + 199, 143, 83, 227, 73, 164, 166, 104, 10, 156, 174, 178, 150, 95, 132, 231, 146, 124, 10, + 14, 140, 115, 237, 21, + ]; - unsafe { - syscall_bn254_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); - } + unsafe { + syscall_bn254_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); + } - // 3 * generator. - // 3353031288059533942658390886683067124040920775575537747144343083137631628272 - // 19321533766552368860946552437480515441416830039777911637913418824951667761761 - let c: [u8; 64] = [ - 240, 171, 21, 25, 150, 85, 211, 242, 121, 230, 184, 21, 71, 216, 21, 147, 21, 189, 182, - 177, 188, 50, 2, 244, 63, 234, 107, 197, 154, 191, 105, 7, 97, 34, 254, 217, 61, 255, - 241, 205, 87, 91, 156, 11, 180, 99, 158, 49, 117, 100, 8, 141, 124, 219, 79, 85, 41, - 148, 72, 224, 190, 153, 183, 42, - ]; + // 3 * generator. + // 3353031288059533942658390886683067124040920775575537747144343083137631628272 + // 19321533766552368860946552437480515441416830039777911637913418824951667761761 + let c: [u8; 64] = [ + 240, 171, 21, 25, 150, 85, 211, 242, 121, 230, 184, 21, 71, 216, 21, 147, 21, 189, 182, + 177, 188, 50, 2, 244, 63, 234, 107, 197, 154, 191, 105, 7, 97, 34, 254, 217, 61, 255, 241, + 205, 87, 91, 156, 11, 180, 99, 158, 49, 117, 100, 8, 141, 124, 219, 79, 85, 41, 148, 72, + 224, 190, 153, 183, 42, + ]; - assert_eq!(a, c); - } + assert_eq!(a, c); println!("done"); } diff --git a/tests/bn254-double/elf/riscv32im-succinct-zkvm-elf b/tests/bn254-double/elf/riscv32im-succinct-zkvm-elf index b571be7344..2c7bcb6231 100755 Binary files a/tests/bn254-double/elf/riscv32im-succinct-zkvm-elf and b/tests/bn254-double/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bn254-mul/elf/riscv32im-succinct-zkvm-elf b/tests/bn254-mul/elf/riscv32im-succinct-zkvm-elf index dd1506ddc7..a414416de7 100755 Binary files a/tests/bn254-mul/elf/riscv32im-succinct-zkvm-elf and b/tests/bn254-mul/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bn254-mul/src/main.rs b/tests/bn254-mul/src/main.rs index 3086c3806f..841de5e4d0 100644 --- a/tests/bn254-mul/src/main.rs +++ b/tests/bn254-mul/src/main.rs @@ -6,38 +6,36 @@ use sp1_zkvm::precompiles::utils::AffinePoint; #[sp1_derive::cycle_tracker] pub fn main() { - for _ in 0..4 { - // generator. - // 1 - // 2 - let a: [u8; 64] = [ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - ]; - - let mut a_point = AffinePoint::::from_le_bytes(&a); - - // scalar. - // 3 - let scalar: [u32; 8] = [3, 0, 0, 0, 0, 0, 0, 0]; - - println!("cycle-tracker-start: bn254_mul"); - a_point.mul_assign(&scalar); - println!("cycle-tracker-end: bn254_mul"); - - // 3 * generator. - // 3353031288059533942658390886683067124040920775575537747144343083137631628272 - // 19321533766552368860946552437480515441416830039777911637913418824951667761761 - let c: [u8; 64] = [ - 240, 171, 21, 25, 150, 85, 211, 242, 121, 230, 184, 21, 71, 216, 21, 147, 21, 189, 182, - 177, 188, 50, 2, 244, 63, 234, 107, 197, 154, 191, 105, 7, 97, 34, 254, 217, 61, 255, - 241, 205, 87, 91, 156, 11, 180, 99, 158, 49, 117, 100, 8, 141, 124, 219, 79, 85, 41, - 148, 72, 224, 190, 153, 183, 42, - ]; - - assert_eq!(a_point.to_le_bytes(), c); - } + // generator. + // 1 + // 2 + let a: [u8; 64] = [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ]; + + let mut a_point = AffinePoint::::from_le_bytes(&a); + + // scalar. + // 3 + let scalar: [u32; 8] = [3, 0, 0, 0, 0, 0, 0, 0]; + + println!("cycle-tracker-start: bn254_mul"); + a_point.mul_assign(&scalar); + println!("cycle-tracker-end: bn254_mul"); + + // 3 * generator. + // 3353031288059533942658390886683067124040920775575537747144343083137631628272 + // 19321533766552368860946552437480515441416830039777911637913418824951667761761 + let c: [u8; 64] = [ + 240, 171, 21, 25, 150, 85, 211, 242, 121, 230, 184, 21, 71, 216, 21, 147, 21, 189, 182, + 177, 188, 50, 2, 244, 63, 234, 107, 197, 154, 191, 105, 7, 97, 34, 254, 217, 61, 255, 241, + 205, 87, 91, 156, 11, 180, 99, 158, 49, 117, 100, 8, 141, 124, 219, 79, 85, 41, 148, 72, + 224, 190, 153, 183, 42, + ]; + + assert_eq!(a_point.to_le_bytes(), c); println!("done"); } diff --git a/tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf b/tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf index 6e2531ad0b..ed3121d5d8 100755 Binary files a/tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf and b/tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/ecrecover/elf/riscv32im-succinct-zkvm-elf b/tests/ecrecover/elf/riscv32im-succinct-zkvm-elf index 58e50a2590..d75d1642e9 100755 Binary files a/tests/ecrecover/elf/riscv32im-succinct-zkvm-elf and b/tests/ecrecover/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/ed-add/elf/riscv32im-succinct-zkvm-elf b/tests/ed-add/elf/riscv32im-succinct-zkvm-elf index 5916c8a800..1f79b12f49 100755 Binary files a/tests/ed-add/elf/riscv32im-succinct-zkvm-elf and b/tests/ed-add/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/ed-add/src/main.rs b/tests/ed-add/src/main.rs index 3deafa0c76..057aea4823 100644 --- a/tests/ed-add/src/main.rs +++ b/tests/ed-add/src/main.rs @@ -6,40 +6,37 @@ extern "C" { } pub fn main() { - for _ in 0..4 { - // 90393249858788985237231628593243673548167146579814268721945474994541877372611 - // 33321104029277118100578831462130550309254424135206412570121538923759338004303 - let mut a: [u8; 64] = [ - 195, 166, 157, 207, 218, 220, 175, 197, 111, 177, 123, 23, 73, 72, 114, 103, 28, 246, - 66, 207, 66, 146, 187, 234, 136, 238, 133, 145, 47, 196, 216, 199, 79, 31, 224, 30, - 179, 122, 51, 84, 116, 12, 4, 189, 198, 198, 190, 22, 71, 201, 143, 249, 92, 56, 147, - 133, 92, 187, 130, 33, 152, 19, 171, 73, - ]; + // 90393249858788985237231628593243673548167146579814268721945474994541877372611 + // 33321104029277118100578831462130550309254424135206412570121538923759338004303 + let mut a: [u8; 64] = [ + 195, 166, 157, 207, 218, 220, 175, 197, 111, 177, 123, 23, 73, 72, 114, 103, 28, 246, 66, + 207, 66, 146, 187, 234, 136, 238, 133, 145, 47, 196, 216, 199, 79, 31, 224, 30, 179, 122, + 51, 84, 116, 12, 4, 189, 198, 198, 190, 22, 71, 201, 143, 249, 92, 56, 147, 133, 92, 187, + 130, 33, 152, 19, 171, 73, + ]; - // 61717728572175158701898635111983295176935961585742968051419350619945173564869 - // 28137966556353620208933066709998005335145594788896528644015312259959272398451 - let b: [u8; 64] = [ - 197, 189, 200, 77, 201, 212, 57, 105, 191, 133, 123, 170, 167, 50, 114, 38, 37, 102, - 188, 29, 215, 227, 157, 142, 252, 31, 129, 67, 24, 255, 114, 136, 115, 94, 94, 55, 43, - 200, 117, 224, 139, 251, 238, 45, 80, 154, 70, 213, 219, 78, 201, 108, 73, 203, 72, 45, - 167, 131, 199, 47, 82, 134, 53, 62, - ]; + // 61717728572175158701898635111983295176935961585742968051419350619945173564869 + // 28137966556353620208933066709998005335145594788896528644015312259959272398451 + let b: [u8; 64] = [ + 197, 189, 200, 77, 201, 212, 57, 105, 191, 133, 123, 170, 167, 50, 114, 38, 37, 102, 188, + 29, 215, 227, 157, 142, 252, 31, 129, 67, 24, 255, 114, 136, 115, 94, 94, 55, 43, 200, 117, + 224, 139, 251, 238, 45, 80, 154, 70, 213, 219, 78, 201, 108, 73, 203, 72, 45, 167, 131, + 199, 47, 82, 134, 53, 62, + ]; - unsafe { - syscall_ed_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); - } - - // 36213413123116753589144482590359479011148956763279542162278577842046663495729 - // 17093345531692682197799066694073110060588941459686871373458223451938707761683 - let c: [u8; 64] = [ - 49, 144, 129, 197, 86, 163, 62, 48, 222, 208, 213, 200, 219, 90, 163, 54, 211, 248, - 178, 224, 238, 167, 235, 219, 251, 247, 189, 239, 194, 16, 16, 80, 19, 106, 20, 198, - 72, 56, 103, 111, 68, 201, 29, 107, 75, 208, 193, 232, 181, 186, 175, 22, 213, 187, - 253, 125, 44, 80, 222, 209, 159, 125, 202, 37, - ]; - - assert_eq!(a, c); + unsafe { + syscall_ed_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); } + // 36213413123116753589144482590359479011148956763279542162278577842046663495729 + // 17093345531692682197799066694073110060588941459686871373458223451938707761683 + let c: [u8; 64] = [ + 49, 144, 129, 197, 86, 163, 62, 48, 222, 208, 213, 200, 219, 90, 163, 54, 211, 248, 178, + 224, 238, 167, 235, 219, 251, 247, 189, 239, 194, 16, 16, 80, 19, 106, 20, 198, 72, 56, + 103, 111, 68, 201, 29, 107, 75, 208, 193, 232, 181, 186, 175, 22, 213, 187, 253, 125, 44, + 80, 222, 209, 159, 125, 202, 37, + ]; + + assert_eq!(a, c); println!("done"); } diff --git a/tests/ed-decompress/elf/riscv32im-succinct-zkvm-elf b/tests/ed-decompress/elf/riscv32im-succinct-zkvm-elf index 233f1ab1cb..10bbf5e06e 100755 Binary files a/tests/ed-decompress/elf/riscv32im-succinct-zkvm-elf and b/tests/ed-decompress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/ed-decompress/src/main.rs b/tests/ed-decompress/src/main.rs index 0b6929dde4..32f4eef659 100644 --- a/tests/ed-decompress/src/main.rs +++ b/tests/ed-decompress/src/main.rs @@ -8,28 +8,26 @@ extern "C" { } pub fn main() { - for _ in 0..4 { - let pub_bytes = hex!("ec172b93ad5e563bf4932c70e1245034c35467ef2efd4d64ebf819683467e2bf"); + let pub_bytes = hex!("ec172b93ad5e563bf4932c70e1245034c35467ef2efd4d64ebf819683467e2bf"); - let mut decompressed = [0_u8; 64]; - decompressed[32..].copy_from_slice(&pub_bytes); + let mut decompressed = [0_u8; 64]; + decompressed[32..].copy_from_slice(&pub_bytes); - println!("before: {:?}", decompressed); + println!("before: {:?}", decompressed); - unsafe { - syscall_ed_decompress(decompressed.as_mut_ptr()); - } + unsafe { + syscall_ed_decompress(decompressed.as_mut_ptr()); + } - let expected: [u8; 64] = [ - 47, 252, 114, 91, 153, 234, 110, 201, 201, 153, 152, 14, 68, 231, 90, 221, 137, 110, - 250, 67, 10, 64, 37, 70, 163, 101, 111, 223, 185, 1, 180, 88, 236, 23, 43, 147, 173, - 94, 86, 59, 244, 147, 44, 112, 225, 36, 80, 52, 195, 84, 103, 239, 46, 253, 77, 100, - 235, 248, 25, 104, 52, 103, 226, 63, - ]; + let expected: [u8; 64] = [ + 47, 252, 114, 91, 153, 234, 110, 201, 201, 153, 152, 14, 68, 231, 90, 221, 137, 110, 250, + 67, 10, 64, 37, 70, 163, 101, 111, 223, 185, 1, 180, 88, 236, 23, 43, 147, 173, 94, 86, 59, + 244, 147, 44, 112, 225, 36, 80, 52, 195, 84, 103, 239, 46, 253, 77, 100, 235, 248, 25, 104, + 52, 103, 226, 63, + ]; - assert_eq!(decompressed, expected); - println!("after: {:?}", decompressed); - } + assert_eq!(decompressed, expected); + println!("after: {:?}", decompressed); println!("done"); } diff --git a/tests/ed25519/elf/riscv32im-succinct-zkvm-elf b/tests/ed25519/elf/riscv32im-succinct-zkvm-elf index 5f149617c0..88c83e3c0a 100755 Binary files a/tests/ed25519/elf/riscv32im-succinct-zkvm-elf and b/tests/ed25519/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/fibonacci/elf/riscv32im-succinct-zkvm-elf b/tests/fibonacci/elf/riscv32im-succinct-zkvm-elf index 7a61102c17..1c59449d83 100755 Binary files a/tests/fibonacci/elf/riscv32im-succinct-zkvm-elf and b/tests/fibonacci/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/hint-io/elf/riscv32im-succinct-zkvm-elf b/tests/hint-io/elf/riscv32im-succinct-zkvm-elf index ac7a2fc293..69fc40b116 100755 Binary files a/tests/hint-io/elf/riscv32im-succinct-zkvm-elf and b/tests/hint-io/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/keccak-permute/elf/riscv32im-succinct-zkvm-elf b/tests/keccak-permute/elf/riscv32im-succinct-zkvm-elf index 15dee99151..a843a07799 100755 Binary files a/tests/keccak-permute/elf/riscv32im-succinct-zkvm-elf and b/tests/keccak-permute/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/keccak256/elf/riscv32im-succinct-zkvm-elf b/tests/keccak256/elf/riscv32im-succinct-zkvm-elf index 311da32c16..48e4965b34 100755 Binary files a/tests/keccak256/elf/riscv32im-succinct-zkvm-elf and b/tests/keccak256/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/panic/elf/riscv32im-succinct-zkvm-elf b/tests/panic/elf/riscv32im-succinct-zkvm-elf index 8debb2189a..e68a4a4dc9 100755 Binary files a/tests/panic/elf/riscv32im-succinct-zkvm-elf and b/tests/panic/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-add/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-add/elf/riscv32im-succinct-zkvm-elf index bf7a3db101..339003c773 100755 Binary files a/tests/secp256k1-add/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-add/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-add/src/main.rs b/tests/secp256k1-add/src/main.rs index 9640e4e8c7..c45601bcc8 100644 --- a/tests/secp256k1-add/src/main.rs +++ b/tests/secp256k1-add/src/main.rs @@ -6,43 +6,41 @@ extern "C" { } pub fn main() { - for _ in 0..4 { - // generator. - // 55066263022277343669578718895168534326250603453777594175500187360389116729240 - // 32670510020758816978083085130507043184471273380659243275938904335757337482424 - let mut a: [u8; 64] = [ - 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, - 206, 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, - 208, 71, 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, - 101, 196, 163, 38, 119, 218, 58, 72, - ]; + // generator. + // 55066263022277343669578718895168534326250603453777594175500187360389116729240 + // 32670510020758816978083085130507043184471273380659243275938904335757337482424 + let mut a: [u8; 64] = [ + 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, 206, + 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, 208, 71, + 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, 101, 196, 163, + 38, 119, 218, 58, 72, + ]; - // 2 * generator. - // 89565891926547004231252920425935692360644145829622209833684329913297188986597 - // 12158399299693830322967808612713398636155367887041628176798871954788371653930 - let b: [u8; 64] = [ - 229, 158, 112, 92, 185, 9, 172, 171, 167, 60, 239, 140, 75, 142, 119, 92, 216, 124, - 192, 149, 110, 64, 69, 48, 109, 125, 237, 65, 148, 127, 4, 198, 42, 229, 207, 80, 169, - 49, 100, 35, 225, 208, 102, 50, 101, 50, 246, 247, 238, 234, 108, 70, 25, 132, 197, - 163, 57, 195, 61, 166, 254, 104, 225, 26, - ]; + // 2 * generator. + // 89565891926547004231252920425935692360644145829622209833684329913297188986597 + // 12158399299693830322967808612713398636155367887041628176798871954788371653930 + let b: [u8; 64] = [ + 229, 158, 112, 92, 185, 9, 172, 171, 167, 60, 239, 140, 75, 142, 119, 92, 216, 124, 192, + 149, 110, 64, 69, 48, 109, 125, 237, 65, 148, 127, 4, 198, 42, 229, 207, 80, 169, 49, 100, + 35, 225, 208, 102, 50, 101, 50, 246, 247, 238, 234, 108, 70, 25, 132, 197, 163, 57, 195, + 61, 166, 254, 104, 225, 26, + ]; - unsafe { - syscall_secp256k1_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); - } + unsafe { + syscall_secp256k1_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); + } - // 3 * generator. - // 112711660439710606056748659173929673102114977341539408544630613555209775888121 - // 25583027980570883691656905877401976406448868254816295069919888960541586679410 - let c: [u8; 64] = [ - 249, 54, 224, 188, 19, 241, 1, 134, 176, 153, 111, 131, 69, 200, 49, 181, 41, 82, 157, - 248, 133, 79, 52, 73, 16, 195, 88, 146, 1, 138, 48, 249, 114, 230, 184, 132, 117, 253, - 185, 108, 27, 35, 194, 52, 153, 169, 0, 101, 86, 243, 55, 42, 230, 55, 227, 15, 20, - 232, 45, 99, 15, 123, 143, 56, - ]; + // 3 * generator. + // 112711660439710606056748659173929673102114977341539408544630613555209775888121 + // 25583027980570883691656905877401976406448868254816295069919888960541586679410 + let c: [u8; 64] = [ + 249, 54, 224, 188, 19, 241, 1, 134, 176, 153, 111, 131, 69, 200, 49, 181, 41, 82, 157, 248, + 133, 79, 52, 73, 16, 195, 88, 146, 1, 138, 48, 249, 114, 230, 184, 132, 117, 253, 185, 108, + 27, 35, 194, 52, 153, 169, 0, 101, 86, 243, 55, 42, 230, 55, 227, 15, 20, 232, 45, 99, 15, + 123, 143, 56, + ]; - assert_eq!(a, c); - } + assert_eq!(a, c); println!("done"); } diff --git a/tests/secp256k1-decompress/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-decompress/elf/riscv32im-succinct-zkvm-elf index 2fae11204b..e06da48d78 100755 Binary files a/tests/secp256k1-decompress/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-decompress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-decompress/src/main.rs b/tests/secp256k1-decompress/src/main.rs index a603986e00..6dc18a25ce 100644 --- a/tests/secp256k1-decompress/src/main.rs +++ b/tests/secp256k1-decompress/src/main.rs @@ -8,22 +8,20 @@ extern "C" { pub fn main() { let compressed_key: [u8; 33] = sp1_zkvm::io::read_vec().try_into().unwrap(); - for _ in 0..4 { - let mut decompressed_key: [u8; 64] = [0; 64]; - decompressed_key[..32].copy_from_slice(&compressed_key[1..]); - let is_odd = match compressed_key[0] { - 2 => false, - 3 => true, - _ => panic!("Invalid compressed key"), - }; - unsafe { - syscall_secp256k1_decompress(&mut decompressed_key, is_odd); - } + let mut decompressed_key: [u8; 64] = [0; 64]; + decompressed_key[..32].copy_from_slice(&compressed_key[1..]); + let is_odd = match compressed_key[0] { + 2 => false, + 3 => true, + _ => panic!("Invalid compressed key"), + }; + unsafe { + syscall_secp256k1_decompress(&mut decompressed_key, is_odd); + } - let mut result: [u8; 65] = [0; 65]; - result[0] = 4; - result[1..].copy_from_slice(&decompressed_key); + let mut result: [u8; 65] = [0; 65]; + result[0] = 4; + result[1..].copy_from_slice(&decompressed_key); - sp1_zkvm::io::commit_slice(&result); - } + sp1_zkvm::io::commit_slice(&result); } diff --git a/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf index 79a156fcab..6ad007626d 100755 Binary files a/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-mul/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-mul/elf/riscv32im-succinct-zkvm-elf index d3e17ead66..ec0db8bd02 100755 Binary files a/tests/secp256k1-mul/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-mul/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-mul/src/main.rs b/tests/secp256k1-mul/src/main.rs index a2fb6a3dd3..731a81b381 100644 --- a/tests/secp256k1-mul/src/main.rs +++ b/tests/secp256k1-mul/src/main.rs @@ -6,39 +6,37 @@ use sp1_zkvm::precompiles::utils::AffinePoint; #[sp1_derive::cycle_tracker] pub fn main() { - for _ in 0..4 { - // generator. - // 55066263022277343669578718895168534326250603453777594175500187360389116729240 - // 32670510020758816978083085130507043184471273380659243275938904335757337482424 - let a: [u8; 64] = [ - 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, - 206, 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, - 208, 71, 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, - 101, 196, 163, 38, 119, 218, 58, 72, - ]; - - let mut a_point = AffinePoint::::from_le_bytes(&a); - - // scalar. - // 3 - let scalar: [u32; 8] = [3, 0, 0, 0, 0, 0, 0, 0]; - - println!("cycle-tracker-start: secp256k1_mul"); - a_point.mul_assign(&scalar); - println!("cycle-tracker-end: secp256k1_mul"); - - // 3 * generator. - // 112711660439710606056748659173929673102114977341539408544630613555209775888121 - // 25583027980570883691656905877401976406448868254816295069919888960541586679410 - let c: [u8; 64] = [ - 249, 54, 224, 188, 19, 241, 1, 134, 176, 153, 111, 131, 69, 200, 49, 181, 41, 82, 157, - 248, 133, 79, 52, 73, 16, 195, 88, 146, 1, 138, 48, 249, 114, 230, 184, 132, 117, 253, - 185, 108, 27, 35, 194, 52, 153, 169, 0, 101, 86, 243, 55, 42, 230, 55, 227, 15, 20, - 232, 45, 99, 15, 123, 143, 56, - ]; - - assert_eq!(a_point.to_le_bytes(), c); - } + // generator. + // 55066263022277343669578718895168534326250603453777594175500187360389116729240 + // 32670510020758816978083085130507043184471273380659243275938904335757337482424 + let a: [u8; 64] = [ + 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, 206, + 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, 208, 71, + 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, 101, 196, 163, + 38, 119, 218, 58, 72, + ]; + + let mut a_point = AffinePoint::::from_le_bytes(&a); + + // scalar. + // 3 + let scalar: [u32; 8] = [3, 0, 0, 0, 0, 0, 0, 0]; + + println!("cycle-tracker-start: secp256k1_mul"); + a_point.mul_assign(&scalar); + println!("cycle-tracker-end: secp256k1_mul"); + + // 3 * generator. + // 112711660439710606056748659173929673102114977341539408544630613555209775888121 + // 25583027980570883691656905877401976406448868254816295069919888960541586679410 + let c: [u8; 64] = [ + 249, 54, 224, 188, 19, 241, 1, 134, 176, 153, 111, 131, 69, 200, 49, 181, 41, 82, 157, 248, + 133, 79, 52, 73, 16, 195, 88, 146, 1, 138, 48, 249, 114, 230, 184, 132, 117, 253, 185, 108, + 27, 35, 194, 52, 153, 169, 0, 101, 86, 243, 55, 42, 230, 55, 227, 15, 20, 232, 45, 99, 15, + 123, 143, 56, + ]; + + assert_eq!(a_point.to_le_bytes(), c); println!("done"); } diff --git a/tests/sha-compress/elf/riscv32im-succinct-zkvm-elf b/tests/sha-compress/elf/riscv32im-succinct-zkvm-elf index f10443e120..97126f881c 100755 Binary files a/tests/sha-compress/elf/riscv32im-succinct-zkvm-elf and b/tests/sha-compress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/sha-compress/src/main.rs b/tests/sha-compress/src/main.rs index 3c306966c2..bdddab1662 100644 --- a/tests/sha-compress/src/main.rs +++ b/tests/sha-compress/src/main.rs @@ -6,10 +6,6 @@ use sp1_zkvm::syscalls::syscall_sha256_compress; pub fn main() { let mut w = [1u32; 64]; let mut state = [1u32; 8]; - - for _ in 0..4 { - syscall_sha256_compress(w.as_mut_ptr(), state.as_mut_ptr()); - } - + syscall_sha256_compress(w.as_mut_ptr(), state.as_mut_ptr()); println!("{:?}", state); } diff --git a/tests/sha-extend/elf/riscv32im-succinct-zkvm-elf b/tests/sha-extend/elf/riscv32im-succinct-zkvm-elf index d584e1c358..7b8774766b 100755 Binary files a/tests/sha-extend/elf/riscv32im-succinct-zkvm-elf and b/tests/sha-extend/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/sha2/elf/riscv32im-succinct-zkvm-elf b/tests/sha2/elf/riscv32im-succinct-zkvm-elf index 2c63e6648b..ff4661defc 100755 Binary files a/tests/sha2/elf/riscv32im-succinct-zkvm-elf and b/tests/sha2/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/tendermint-benchmark/elf/riscv32im-succinct-zkvm-elf b/tests/tendermint-benchmark/elf/riscv32im-succinct-zkvm-elf index 526fc2f836..d67e8be4d5 100755 Binary files a/tests/tendermint-benchmark/elf/riscv32im-succinct-zkvm-elf and b/tests/tendermint-benchmark/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/uint256-arith/elf/riscv32im-succinct-zkvm-elf b/tests/uint256-arith/elf/riscv32im-succinct-zkvm-elf index 25b450b668..83a521f1b2 100755 Binary files a/tests/uint256-arith/elf/riscv32im-succinct-zkvm-elf and b/tests/uint256-arith/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/uint256-arith/src/main.rs b/tests/uint256-arith/src/main.rs index 4df9947bd7..872af3bf60 100644 --- a/tests/uint256-arith/src/main.rs +++ b/tests/uint256-arith/src/main.rs @@ -36,29 +36,27 @@ pub fn main() { let a = U256::from(3u8); let b = U256::from(2u8); - for _ in 0..4 { - println!("cycle-tracker-start: uint256_add"); - let add = uint256_add(black_box(a), black_box(b)); - assert_eq!(add, U256::from(5u8)); - println!("cycle-tracker-end: uint256_add"); - println!("{:?}", add); - - println!("cycle-tracker-start: uint256_sub"); - let sub = uint256_sub(black_box(a), black_box(b)); - assert_eq!(sub, U256::from(1u8)); - println!("cycle-tracker-end: uint256_sub"); - println!("{:?}", sub); - - println!("cycle-tracker-start: uint256_div"); - let div = uint256_div(black_box(a), black_box(b)); - assert_eq!(div, U256::from(1u8)); - println!("cycle-tracker-end: uint256_div"); - println!("{:?}", div); - - println!("cycle-tracker-start: uint256_mul"); - let mul = uint256_mul(black_box(a), black_box(b)); - assert_eq!(mul, U256::from(6u8)); - println!("cycle-tracker-end: uint256_mul"); - println!("{:?}", mul); - } + println!("cycle-tracker-start: uint256_add"); + let add = uint256_add(black_box(a), black_box(b)); + assert_eq!(add, U256::from(5u8)); + println!("cycle-tracker-end: uint256_add"); + println!("{:?}", add); + + println!("cycle-tracker-start: uint256_sub"); + let sub = uint256_sub(black_box(a), black_box(b)); + assert_eq!(sub, U256::from(1u8)); + println!("cycle-tracker-end: uint256_sub"); + println!("{:?}", sub); + + println!("cycle-tracker-start: uint256_div"); + let div = uint256_div(black_box(a), black_box(b)); + assert_eq!(div, U256::from(1u8)); + println!("cycle-tracker-end: uint256_div"); + println!("{:?}", div); + + println!("cycle-tracker-start: uint256_mul"); + let mul = uint256_mul(black_box(a), black_box(b)); + assert_eq!(mul, U256::from(6u8)); + println!("cycle-tracker-end: uint256_mul"); + println!("{:?}", mul); } diff --git a/tests/uint256-div/elf/riscv32im-succinct-zkvm-elf b/tests/uint256-div/elf/riscv32im-succinct-zkvm-elf index 067b69ca73..ca4f0641f3 100755 Binary files a/tests/uint256-div/elf/riscv32im-succinct-zkvm-elf and b/tests/uint256-div/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf b/tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf index 2ab53b56d0..b49313ca61 100755 Binary files a/tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf and b/tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/verify-proof/elf/riscv32im-succinct-zkvm-elf b/tests/verify-proof/elf/riscv32im-succinct-zkvm-elf index 93c15811cb..63f478712f 100755 Binary files a/tests/verify-proof/elf/riscv32im-succinct-zkvm-elf and b/tests/verify-proof/elf/riscv32im-succinct-zkvm-elf differ diff --git a/zkvm/precompiles/src/uint256_div.rs b/zkvm/precompiles/src/uint256_div.rs index b10e0116c0..c12a07b908 100644 --- a/zkvm/precompiles/src/uint256_div.rs +++ b/zkvm/precompiles/src/uint256_div.rs @@ -11,6 +11,7 @@ use num::{BigUint, Integer}; /// represented as arrays of bytes in little-endian order. It returns the quotient /// of the division as a 256-bit unsigned integer in the same byte array format. pub fn uint256_div(x: &mut [u8; 32], y: &[u8; 32]) -> [u8; 32] { + // TODO: this will panic now. // Assert that the divisor is not zero. assert!(y != &[0; 32], "division by zero"); cfg_if::cfg_if! {