From 300762cdfdf524aec0fd52f8c77209f5d9cbd592 Mon Sep 17 00:00:00 2001 From: John Guibas Date: Tue, 4 Jun 2024 11:09:43 -0700 Subject: [PATCH 1/5] feat: shrink/wrap multi opt --- recursion/core/src/fri_fold/mod.rs | 13 ++++++++----- recursion/core/src/multi/mod.rs | 10 ++++++++-- recursion/core/src/poseidon2/external.rs | 3 +++ recursion/core/src/poseidon2/trace.rs | 12 +++++++----- recursion/core/src/stark/mod.rs | 1 + 5 files changed, 27 insertions(+), 12 deletions(-) diff --git a/recursion/core/src/fri_fold/mod.rs b/recursion/core/src/fri_fold/mod.rs index d59668d1e9..5c345d0469 100644 --- a/recursion/core/src/fri_fold/mod.rs +++ b/recursion/core/src/fri_fold/mod.rs @@ -25,6 +25,7 @@ pub const NUM_FRI_FOLD_COLS: usize = core::mem::size_of::>(); #[derive(Default)] pub struct FriFoldChip { pub fixed_log2_rows: Option, + pub pad: bool, } #[derive(Debug, Clone)] @@ -143,11 +144,13 @@ impl MachineAir for FriFoldChip .collect_vec(); // Pad the trace to a power of two. - pad_rows_fixed( - &mut rows, - || [F::zero(); NUM_FRI_FOLD_COLS], - self.fixed_log2_rows, - ); + if self.pad { + pad_rows_fixed( + &mut rows, + || [F::zero(); NUM_FRI_FOLD_COLS], + self.fixed_log2_rows, + ); + } // Convert the trace to a row major matrix. let trace = RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_FRI_FOLD_COLS); diff --git a/recursion/core/src/multi/mod.rs b/recursion/core/src/multi/mod.rs index 10a9c3db0d..f04fc0aacf 100644 --- a/recursion/core/src/multi/mod.rs +++ b/recursion/core/src/multi/mod.rs @@ -66,8 +66,14 @@ impl MachineAir for MultiChip { input: &ExecutionRecord, output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let fri_fold_chip = FriFoldChip::<3>::default(); - let poseidon2 = Poseidon2Chip::default(); + let fri_fold_chip = FriFoldChip::<3> { + fixed_log2_rows: None, + pad: false, + }; + let poseidon2 = Poseidon2Chip { + fixed_log2_rows: None, + pad: false, + }; let fri_fold_trace = fri_fold_chip.generate_trace(input, output); let mut poseidon2_trace = poseidon2.generate_trace(input, output); diff --git a/recursion/core/src/poseidon2/external.rs b/recursion/core/src/poseidon2/external.rs index d340ba2b41..21f56edb0b 100644 --- a/recursion/core/src/poseidon2/external.rs +++ b/recursion/core/src/poseidon2/external.rs @@ -24,6 +24,7 @@ pub const WIDTH: usize = 16; #[derive(Default)] pub struct Poseidon2Chip { pub fixed_log2_rows: Option, + pub pad: bool, } impl BaseAir for Poseidon2Chip { @@ -449,6 +450,7 @@ mod tests { fn generate_trace() { let chip = Poseidon2Chip { fixed_log2_rows: None, + pad: true, }; let rng = &mut rand::thread_rng(); @@ -498,6 +500,7 @@ mod tests { let chip = Poseidon2Chip { fixed_log2_rows: None, + pad: true, }; let trace: RowMajorMatrix = chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); diff --git a/recursion/core/src/poseidon2/trace.rs b/recursion/core/src/poseidon2/trace.rs index 2d5639edd2..567c09fc7d 100644 --- a/recursion/core/src/poseidon2/trace.rs +++ b/recursion/core/src/poseidon2/trace.rs @@ -174,11 +174,13 @@ impl MachineAir for Poseidon2Chip { let num_real_rows = rows.len(); // Pad the trace to a power of two. - pad_rows_fixed( - &mut rows, - || [F::zero(); NUM_POSEIDON2_COLS], - self.fixed_log2_rows, - ); + if self.pad { + pad_rows_fixed( + &mut rows, + || [F::zero(); NUM_POSEIDON2_COLS], + self.fixed_log2_rows, + ); + } let mut round_num = 0; for row in rows[num_real_rows..].iter_mut() { diff --git a/recursion/core/src/stark/mod.rs b/recursion/core/src/stark/mod.rs index 540dcb613b..88e04b9a2e 100644 --- a/recursion/core/src/stark/mod.rs +++ b/recursion/core/src/stark/mod.rs @@ -78,6 +78,7 @@ impl, const DEGREE: usize> RecursionAi }))) .chain(once(RecursionAir::FriFold(FriFoldChip:: { fixed_log2_rows: None, + pad: true, }))) .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default()))) .collect() From 402fb18cb77a29f32415a1d488eeb7721a5311a6 Mon Sep 17 00:00:00 2001 From: John Guibas Date: Tue, 4 Jun 2024 11:32:33 -0700 Subject: [PATCH 2/5] fix --- recursion/core/src/stark/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recursion/core/src/stark/mod.rs b/recursion/core/src/stark/mod.rs index 88e04b9a2e..7753c73f83 100644 --- a/recursion/core/src/stark/mod.rs +++ b/recursion/core/src/stark/mod.rs @@ -110,7 +110,7 @@ impl, const DEGREE: usize> RecursionAi fixed_log2_rows: Some(19), }))) .chain(once(RecursionAir::Multi(MultiChip { - fixed_log2_rows: Some(20), + fixed_log2_rows: Some(19), }))) .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default()))) .collect() From f72833d4df16ddc4c644da1f662e8ced6df12cc8 Mon Sep 17 00:00:00 2001 From: John Guibas Date: Tue, 4 Jun 2024 11:57:22 -0700 Subject: [PATCH 3/5] remove test --- recursion/core/src/multi/mod.rs | 82 --------------------------------- 1 file changed, 82 deletions(-) diff --git a/recursion/core/src/multi/mod.rs b/recursion/core/src/multi/mod.rs index f04fc0aacf..fb7bfdd6f2 100644 --- a/recursion/core/src/multi/mod.rs +++ b/recursion/core/src/multi/mod.rs @@ -220,85 +220,3 @@ impl MultiCols { unsafe { &self.instruction.poseidon2 } } } - -#[cfg(test)] -mod tests { - use itertools::Itertools; - use std::time::Instant; - - use p3_baby_bear::BabyBear; - use p3_baby_bear::DiffusionMatrixBabyBear; - use p3_field::AbstractField; - use p3_matrix::{dense::RowMajorMatrix, Matrix}; - use p3_poseidon2::Poseidon2; - use p3_poseidon2::Poseidon2ExternalMatrixGeneral; - use sp1_core::stark::StarkGenericConfig; - use sp1_core::utils::inner_perm; - use sp1_core::{ - air::MachineAir, - utils::{uni_stark_prove, uni_stark_verify, BabyBearPoseidon2}, - }; - - use crate::multi::MultiChip; - use crate::{poseidon2::Poseidon2Event, runtime::ExecutionRecord}; - use p3_symmetric::Permutation; - - #[test] - fn prove_babybear() { - let config = BabyBearPoseidon2::compressed(); - let mut challenger = config.challenger(); - - let chip = MultiChip::<5> { - fixed_log2_rows: None, - }; - - let test_inputs = (0..16) - .map(|i| [BabyBear::from_canonical_u32(i); 16]) - .collect_vec(); - - let gt: Poseidon2< - BabyBear, - Poseidon2ExternalMatrixGeneral, - DiffusionMatrixBabyBear, - 16, - 7, - > = inner_perm(); - - let expected_outputs = test_inputs - .iter() - .map(|input| gt.permute(*input)) - .collect::>(); - - let mut input_exec = ExecutionRecord::::default(); - for (input, output) in test_inputs.into_iter().zip_eq(expected_outputs) { - input_exec - .poseidon2_events - .push(Poseidon2Event::dummy_from_input(input, output)); - } - let trace: RowMajorMatrix = - chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); - println!( - "trace dims is width: {:?}, height: {:?}", - trace.width(), - trace.height() - ); - - let start = Instant::now(); - let proof = uni_stark_prove(&config, &chip, &mut challenger, trace); - let duration = start.elapsed().as_secs_f64(); - println!("proof duration = {:?}", duration); - - let mut challenger: p3_challenger::DuplexChallenger< - BabyBear, - Poseidon2, - 16, - 8, - > = config.challenger(); - let start = Instant::now(); - uni_stark_verify(&config, &chip, &mut challenger, &proof) - .expect("expected proof to be valid"); - - let duration = start.elapsed().as_secs_f64(); - println!("verify duration = {:?}", duration); - } -} From 224db99d80c0e784bf3bd3949e27498fbfb421c7 Mon Sep 17 00:00:00 2001 From: John Guibas Date: Tue, 4 Jun 2024 12:31:27 -0700 Subject: [PATCH 4/5] hm --- recursion/core/src/multi/mod.rs | 83 ++++++++++++++++++++++++++++++++- 1 file changed, 82 insertions(+), 1 deletion(-) diff --git a/recursion/core/src/multi/mod.rs b/recursion/core/src/multi/mod.rs index fb7bfdd6f2..c3a7acb1ef 100644 --- a/recursion/core/src/multi/mod.rs +++ b/recursion/core/src/multi/mod.rs @@ -158,7 +158,6 @@ where .assert_zero(next_is_real.clone()); // Next, verify that all fri fold rows are before the poseidon2 rows within the real rows section. - builder.when_first_row().assert_one(local.is_fri_fold); builder .when_transition() .when(next_is_real) @@ -220,3 +219,85 @@ impl MultiCols { unsafe { &self.instruction.poseidon2 } } } + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use std::time::Instant; + + use p3_baby_bear::BabyBear; + use p3_baby_bear::DiffusionMatrixBabyBear; + use p3_field::AbstractField; + use p3_matrix::{dense::RowMajorMatrix, Matrix}; + use p3_poseidon2::Poseidon2; + use p3_poseidon2::Poseidon2ExternalMatrixGeneral; + use sp1_core::stark::StarkGenericConfig; + use sp1_core::utils::inner_perm; + use sp1_core::{ + air::MachineAir, + utils::{uni_stark_prove, uni_stark_verify, BabyBearPoseidon2}, + }; + + use crate::multi::MultiChip; + use crate::{poseidon2::Poseidon2Event, runtime::ExecutionRecord}; + use p3_symmetric::Permutation; + + #[test] + fn prove_babybear() { + let config = BabyBearPoseidon2::compressed(); + let mut challenger = config.challenger(); + + let chip = MultiChip::<5> { + fixed_log2_rows: None, + }; + + let test_inputs = (0..16) + .map(|i| [BabyBear::from_canonical_u32(i); 16]) + .collect_vec(); + + let gt: Poseidon2< + BabyBear, + Poseidon2ExternalMatrixGeneral, + DiffusionMatrixBabyBear, + 16, + 7, + > = inner_perm(); + + let expected_outputs = test_inputs + .iter() + .map(|input| gt.permute(*input)) + .collect::>(); + + let mut input_exec = ExecutionRecord::::default(); + for (input, output) in test_inputs.into_iter().zip_eq(expected_outputs) { + input_exec + .poseidon2_events + .push(Poseidon2Event::dummy_from_input(input, output)); + } + let trace: RowMajorMatrix = + chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); + println!( + "trace dims is width: {:?}, height: {:?}", + trace.width(), + trace.height() + ); + + let start = Instant::now(); + let proof = uni_stark_prove(&config, &chip, &mut challenger, trace); + let duration = start.elapsed().as_secs_f64(); + println!("proof duration = {:?}", duration); + + let mut challenger: p3_challenger::DuplexChallenger< + BabyBear, + Poseidon2, + 16, + 8, + > = config.challenger(); + let start = Instant::now(); + uni_stark_verify(&config, &chip, &mut challenger, &proof) + .expect("expected proof to be valid"); + + let duration = start.elapsed().as_secs_f64(); + println!("verify duration = {:?}", duration); + } +} From 3afd6ff859459b445269dab4f421e8c82d41c3b3 Mon Sep 17 00:00:00 2001 From: John Guibas Date: Tue, 4 Jun 2024 13:24:17 -0700 Subject: [PATCH 5/5] fix things --- recursion/core/src/multi/mod.rs | 1 - recursion/program/Cargo.toml | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/recursion/core/src/multi/mod.rs b/recursion/core/src/multi/mod.rs index c3a7acb1ef..0b93f52aa2 100644 --- a/recursion/core/src/multi/mod.rs +++ b/recursion/core/src/multi/mod.rs @@ -151,7 +151,6 @@ where // all the fri fold rows are first, then the posiedon2 rows, and finally any padded (non-real) rows. // First verify that all real rows are contiguous. - builder.when_first_row().assert_one(local_is_real.clone()); builder .when_transition() .when_not(local_is_real.clone()) diff --git a/recursion/program/Cargo.toml b/recursion/program/Cargo.toml index 47dd52380c..7bfb8a3792 100644 --- a/recursion/program/Cargo.toml +++ b/recursion/program/Cargo.toml @@ -24,3 +24,6 @@ itertools = "0.12.1" serde = { version = "1.0.201", features = ["derive"] } rand = "0.8.5" tracing = "0.1.40" + +[features] +debug = ["sp1-core/debug"] \ No newline at end of file