diff --git a/Cargo.lock b/Cargo.lock index 6cbc50c5b5..3a27e31e30 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3483,6 +3483,9 @@ name = "sp1-prover" version = "0.1.0" dependencies = [ "bincode", + "p3-baby-bear", + "p3-challenger", + "p3-commit", "sp1-core", "sp1-recursion-compiler", "sp1-recursion-core", diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index eda62809b9..92caafe312 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -72,6 +72,16 @@ impl>> MachineStark { &self.chips } + /// Returns the id of all chips in the machine that have preprocessed columns. + pub fn preprocessed_chip_ids(&self) -> Vec { + self.chips + .iter() + .enumerate() + .filter(|(_, chip)| chip.preprocessed_width() > 0) + .map(|(i, _)| i) + .collect() + } + pub fn shard_chips<'a, 'b>( &'a self, shard: &'b A::Record, diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 407be5070e..6098a78acc 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -10,4 +10,7 @@ sp1-recursion-program = { path = "../recursion/program" } sp1-recursion-compiler = { path = "../recursion/compiler" } sp1-recursion-core = { path = "../recursion/core" } sp1-core = { path = "../core" } +p3-challenger = { workspace = true } +p3-baby-bear = { workspace = true } +p3-commit = { workspace = true } bincode = "1.3.3" diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 267667d2d7..28736028d6 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -1,84 +1,140 @@ #![allow(incomplete_features)] #![feature(generic_const_exprs)] +#![allow(deprecated)] use std::time::Instant; +use p3_baby_bear::BabyBear; +use p3_challenger::CanObserve; +use p3_commit::TwoAdicMultiplicativeCoset; use sp1_core::{ + air::{MachineAir, PublicValues, Word}, runtime::Program, - stark::{Proof, RiscvAir, StarkGenericConfig, VerifyingKey}, + stark::{LocalProver, Proof, RiscvAir, ShardProof, StarkGenericConfig}, utils::{run_and_prove, BabyBearPoseidon2}, }; -use sp1_recursion_core::runtime::Runtime; -use sp1_recursion_program::compress::build_compress; +use sp1_recursion_core::{runtime::Runtime, stark::RecursionAir}; +use sp1_recursion_program::{hints::Hintable, reduce::build_reduce, stark::EMPTY}; type InnerSC = BabyBearPoseidon2; type InnerF = ::Val; type InnerEF = ::Challenge; type InnerA = RiscvAir; -pub fn prove_sp1() -> (Proof, VerifyingKey) { - let elf = include_bytes!("../../examples/fibonacci-io/program/elf/riscv32im-succinct-zkvm-elf"); - - let config = InnerSC::default(); - let machine = RiscvAir::machine(config.clone()); - let program = Program::from(elf); - let stdin = [bincode::serialize::(&6).unwrap()]; - let (_, vk) = machine.setup(&program); - let (proof, _) = run_and_prove(program, &stdin, config); - let mut challenger_ver = machine.config().challenger(); - machine.verify(&vk, &proof, &mut challenger_ver).unwrap(); - println!("Proof generated successfully"); - - (proof, vk) -} - -pub fn prove_compress(sp1_proof: Proof, vk: VerifyingKey) { - let (program, witness_stream) = build_compress(sp1_proof, vk); - - let config = InnerSC::default(); - let machine = InnerA::machine(config); - let mut runtime = Runtime::::new(&program, machine.config().perm.clone()); - runtime.witness_stream = witness_stream; - - let time = Instant::now(); - runtime.run(); - let elapsed = time.elapsed(); - runtime.print_stats(); - println!("Execution took: {:?}", elapsed); - - // let config = BabyBearPoseidon2::new(); - // let machine = RecursionAir::machine(config); - // let (pk, vk) = machine.setup(&program); - - // let mut challenger = machine.config().challenger(); - // let record_clone = runtime.record.clone(); - // machine.debug_constraints(&pk, record_clone, &mut challenger); - - // let start = Instant::now(); - // let mut challenger = machine.config().challenger(); - // let proof = machine.prove::>(&pk, runtime.record, &mut challenger); - // let duration = start.elapsed().as_secs(); +pub struct SP1ProverImpl; + +impl SP1ProverImpl { + pub fn prove(elf: &[u8], stdin: &[Vec]) -> Proof { + let config = InnerSC::default(); + let machine = RiscvAir::machine(config.clone()); + let program = Program::from(elf); + let (_, vk) = machine.setup(&program); + let (proof, _) = run_and_prove(program, stdin, config); + let mut challenger_ver = machine.config().challenger(); + machine.verify(&vk, &proof, &mut challenger_ver).unwrap(); + proof + } - // let mut challenger = machine.config().challenger(); - // machine.verify(&vk, &proof, &mut challenger).unwrap(); - // println!("proving duration = {}", duration); + pub fn reduce(elf: &[u8], proof: Proof) -> Vec> { + let config = InnerSC::default(); + let machine = RiscvAir::machine(config.clone()); + let program = Program::from(elf); + let (_, vk) = machine.setup(&program); + let reduce_program = build_reduce(); + println!("nb_shards {}", proof.shard_proofs.len()); + let config = InnerSC::default(); + + let is_recursive_flags: Vec = proof.shard_proofs.iter().map(|_| 0).collect(); + let sorted_indices: Vec> = proof + .shard_proofs + .iter() + .map(|p| { + machine + .chips_sorted_indices(p) + .into_iter() + .map(|x| match x { + Some(x) => x, + None => EMPTY, + }) + .collect() + }) + .collect(); + + let mut challenger = machine.config().challenger(); + challenger.observe(vk.commit); + let reconstruct_challenger = challenger.clone(); + for proof in proof.shard_proofs.iter() { + challenger.observe(proof.commitment.main_commit); + let public_values = PublicValues::, BabyBear>::new(proof.public_values); + challenger.observe_slice(&public_values.to_vec()); + } + + let chips = machine.chips(); + let ordering = vk.chip_ordering.clone(); + let (prep_sorted_indices, prep_domains): ( + Vec, + Vec>, + ) = machine + .preprocessed_chip_ids() + .into_iter() + .map(|chip_idx| { + let name = chips[chip_idx].name().clone(); + let prep_sorted_idx = ordering[&name]; + (prep_sorted_idx, vk.chip_information[prep_sorted_idx].1) + }) + .unzip(); + + // Generate inputs. + let mut witness_stream = Vec::new(); + witness_stream.extend(proof.shard_proofs.write()); + witness_stream.extend(is_recursive_flags.write()); + witness_stream.extend(sorted_indices.write()); + witness_stream.extend(challenger.write()); + witness_stream.extend(reconstruct_challenger.write()); + witness_stream.extend(prep_sorted_indices.write()); + witness_stream.extend(prep_domains.write()); + witness_stream.extend(vk.write()); + witness_stream.extend(vk.write()); + + // Execute runtime. + let machine = InnerA::machine(config); + let mut runtime = + Runtime::::new(&reduce_program, machine.config().perm.clone()); + runtime.witness_stream = witness_stream; + runtime.run(); + runtime.print_stats(); + + // Generate proof. + let config = BabyBearPoseidon2::new(); + let machine = RecursionAir::machine(config); + let (pk, _) = machine.setup(&reduce_program); + // let mut challenger = machine.config().challenger(); + // let record_clone = runtime.record.clone(); + // machine.debug_constraints(&pk, record_clone, &mut challenger); + let start = Instant::now(); + let mut challenger = machine.config().challenger(); + let proof = machine.prove::>(&pk, runtime.record, &mut challenger); + let duration = start.elapsed().as_secs(); + println!("proving duration = {}", duration); + + // let mut challenger = machine.config().challenger(); + // machine.verify(&vk, &proof, &mut challenger).unwrap(); + + proof.shard_proofs + } } -pub fn prove_reduce() {} - -pub fn prove_snark() {} - #[cfg(test)] mod tests { - use sp1_core::utils::setup_logger; - use super::*; - + use sp1_core::utils::setup_logger; #[test] fn test_prove_sp1() { setup_logger(); - - let (sp1_proof, vk) = prove_sp1(); - prove_compress(sp1_proof, vk); + let elf = + include_bytes!("../../examples/fibonacci-io/program/elf/riscv32im-succinct-zkvm-elf"); + let stdin = [bincode::serialize::(&6).unwrap()]; + let proof = SP1ProverImpl::prove(elf, &stdin); + SP1ProverImpl::reduce(elf, proof); } } diff --git a/recursion/compiler/src/ir/var.rs b/recursion/compiler/src/ir/var.rs index 0c332cb24e..3cceb76fba 100644 --- a/recursion/compiler/src/ir/var.rs +++ b/recursion/compiler/src/ir/var.rs @@ -29,7 +29,9 @@ pub struct MemIndex { pub trait MemVariable: Variable { fn size_of() -> usize; + /// Loads the variable from the heap. fn load(&self, ptr: Ptr, index: MemIndex, builder: &mut Builder); + /// Stores the variable to the heap. fn store(&self, ptr: Ptr, index: MemIndex, builder: &mut Builder); } diff --git a/recursion/core/src/cpu/columns/opcode.rs b/recursion/core/src/cpu/columns/opcode.rs index 5c8041e48f..28cc959205 100644 --- a/recursion/core/src/cpu/columns/opcode.rs +++ b/recursion/core/src/cpu/columns/opcode.rs @@ -95,7 +95,7 @@ impl OpcodeSelectorCols { Opcode::PrintF => self.is_noop = F::one(), Opcode::PrintE => self.is_noop = F::one(), Opcode::FRIFold => self.is_fri_fold = F::one(), - _ => unreachable!(), + _ => {} } } } diff --git a/recursion/program/src/challenger.rs b/recursion/program/src/challenger.rs index da09decfa7..b57f313e1f 100644 --- a/recursion/program/src/challenger.rs +++ b/recursion/program/src/challenger.rs @@ -1,6 +1,10 @@ use p3_field::AbstractField; -use sp1_recursion_compiler::prelude::{Array, Builder, Config, Ext, Felt, Usize, Var}; +use sp1_recursion_compiler::prelude::MemIndex; +use sp1_recursion_compiler::prelude::MemVariable; +use sp1_recursion_compiler::prelude::Ptr; +use sp1_recursion_compiler::prelude::Variable; +use sp1_recursion_compiler::prelude::{Array, Builder, Config, DslVariable, Ext, Felt, Usize, Var}; use sp1_recursion_core::runtime::{DIGEST_SIZE, PERMUTATION_WIDTH}; use crate::fri::types::DigestVariable; @@ -37,7 +41,7 @@ pub trait CanSampleBitsVariable { } /// Reference: https://github.com/Plonky3/Plonky3/blob/4809fa7bedd9ba8f6f5d3267b1592618e3776c57/challenger/src/duplex_challenger.rs#L10 -#[derive(Clone)] +#[derive(Clone, DslVariable)] pub struct DuplexChallengerVariable { pub sponge_state: Array>, pub nb_inputs: Var, @@ -57,6 +61,34 @@ impl DuplexChallengerVariable { } } + /// Creates a new challenger with the same state as an existing challenger. + pub fn as_clone(&self, builder: &mut Builder) -> Self { + let mut sponge_state = builder.dyn_array(PERMUTATION_WIDTH); + builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| { + let element = builder.get(&self.sponge_state, i); + builder.set(&mut sponge_state, i, element); + }); + let nb_inputs = builder.eval(self.nb_inputs); + let mut input_buffer = builder.dyn_array(PERMUTATION_WIDTH); + builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| { + let element = builder.get(&self.input_buffer, i); + builder.set(&mut input_buffer, i, element); + }); + let nb_outputs = builder.eval(self.nb_outputs); + let mut output_buffer = builder.dyn_array(PERMUTATION_WIDTH); + builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| { + let element = builder.get(&self.output_buffer, i); + builder.set(&mut output_buffer, i, element); + }); + DuplexChallengerVariable:: { + sponge_state, + nb_inputs, + input_buffer, + nb_outputs, + output_buffer, + } + } + /// Reference: https://github.com/Plonky3/Plonky3/blob/4809fa7bedd9ba8f6f5d3267b1592618e3776c57/challenger/src/duplex_challenger.rs#L38 pub fn duplexing(&mut self, builder: &mut Builder) { builder.range(0, self.nb_inputs).for_each(|i, builder| { diff --git a/recursion/program/src/compress.rs b/recursion/program/src/compress.rs deleted file mode 100644 index 7262a91649..0000000000 --- a/recursion/program/src/compress.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::time::Instant; - -use crate::challenger::CanObserveVariable; -use crate::challenger::DuplexChallengerVariable; -use crate::fri::types::FriConfigVariable; -use crate::fri::TwoAdicFriPcsVariable; -use crate::fri::TwoAdicMultiplicativeCosetVariable; -use crate::hints::Hintable; -use crate::stark::StarkVerifier; -use crate::stark::EMPTY; -use crate::types::ShardCommitmentVariable; -use p3_baby_bear::BabyBear; -use p3_baby_bear::DiffusionMatrixBabybear; -use p3_challenger::CanObserve; -use p3_commit::ExtensionMmcs; -use p3_commit::TwoAdicMultiplicativeCoset; -use p3_field::extension::BinomialExtensionField; -use p3_field::AbstractField; -use p3_field::Field; -use p3_field::TwoAdicField; -use p3_fri::FriConfig; -use p3_merkle_tree::FieldMerkleTreeMmcs; -use p3_poseidon2::Poseidon2; -use p3_poseidon2::Poseidon2ExternalMatrixGeneral; -use p3_symmetric::PaddingFreeSponge; -use p3_symmetric::TruncatedPermutation; -use sp1_core::air::PublicValues; -use sp1_core::air::Word; -use sp1_core::stark::Proof; -use sp1_core::stark::ShardProof; -use sp1_core::stark::VerifyingKey; -use sp1_core::stark::{RiscvAir, StarkGenericConfig}; -use sp1_recursion_compiler::asm::AsmConfig; -use sp1_recursion_compiler::asm::VmBuilder; -use sp1_recursion_compiler::ir::Array; -use sp1_recursion_compiler::ir::Builder; -use sp1_recursion_compiler::ir::Felt; -use sp1_recursion_core::air::Block; -use sp1_recursion_core::runtime::Program as RecursionProgram; -use sp1_recursion_core::runtime::DIGEST_SIZE; -use sp1_recursion_core::stark::config::inner_fri_config; -use sp1_sdk::utils::BabyBearPoseidon2; - -type SC = BabyBearPoseidon2; -type F = ::Val; -type EF = ::Challenge; -type C = AsmConfig; - -type Val = BabyBear; -type Challenge = BinomialExtensionField; -type Perm = Poseidon2; -type Hash = PaddingFreeSponge; -type Compress = TruncatedPermutation; -type ValMmcs = - FieldMerkleTreeMmcs<::Packing, ::Packing, Hash, Compress, 8>; -type ChallengeMmcs = ExtensionMmcs; -type RecursionConfig = AsmConfig; -type RecursionBuilder = Builder; - -pub fn const_fri_config( - builder: &mut RecursionBuilder, - config: FriConfig, -) -> FriConfigVariable { - let two_addicity = Val::TWO_ADICITY; - let mut generators = builder.dyn_array(two_addicity); - let mut subgroups = builder.dyn_array(two_addicity); - for i in 0..two_addicity { - let constant_generator = Val::two_adic_generator(i); - builder.set(&mut generators, i, constant_generator); - - let constant_domain = TwoAdicMultiplicativeCoset { - log_n: i, - shift: Val::one(), - }; - let domain_value: TwoAdicMultiplicativeCosetVariable<_> = builder.constant(constant_domain); - builder.set(&mut subgroups, i, domain_value); - } - FriConfigVariable { - log_blowup: config.log_blowup, - num_queries: config.num_queries, - proof_of_work_bits: config.proof_of_work_bits, - subgroups, - generators, - } -} - -// TODO: proof is only necessary now because it's a constant, it should be I/O soon -pub fn build_compress( - proof: Proof, - vk: VerifyingKey, -) -> (RecursionProgram, Vec>>) { - let machine = RiscvAir::machine(SC::default()); - - let mut challenger_val = machine.config().challenger(); - challenger_val.observe(vk.commit); - proof.shard_proofs.iter().for_each(|proof| { - challenger_val.observe(proof.commitment.main_commit); - let public_values_field = PublicValues::, F>::new(proof.public_values); - challenger_val.observe_slice(&public_values_field.to_vec()); - }); - - let time = Instant::now(); - let mut builder = VmBuilder::::default(); - let config = const_fri_config(&mut builder, inner_fri_config()); - let pcs = TwoAdicFriPcsVariable { config }; - - let mut challenger = DuplexChallengerVariable::new(&mut builder); - - let preprocessed_commit_val: [F; DIGEST_SIZE] = vk.commit.into(); - let preprocessed_commit: Array = builder.constant(preprocessed_commit_val.to_vec()); - challenger.observe(&mut builder, preprocessed_commit); - - let mut witness_stream = Vec::new(); - let mut shard_proofs = vec![]; - let mut sorted_indices = vec![]; - for proof_val in proof.shard_proofs { - witness_stream.extend(proof_val.write()); - let sorted_indices_raw: Vec = machine - .chips_sorted_indices(&proof_val) - .into_iter() - .map(|x| match x { - Some(x) => x, - None => EMPTY, - }) - .collect(); - witness_stream.extend(sorted_indices_raw.write()); - let proof = ShardProof::<_>::read(&mut builder); - let sorted_indices_arr = Vec::::read(&mut builder); - builder - .range(0, sorted_indices_arr.len()) - .for_each(|i, builder| { - let el = builder.get(&sorted_indices_arr, i); - builder.print_v(el); - }); - let ShardCommitmentVariable { main_commit, .. } = &proof.commitment; - challenger.observe(&mut builder, main_commit.clone()); - let public_values_field = PublicValues::, F>::new(proof_val.public_values); - let public_values_felt: Vec> = public_values_field - .to_vec() - .iter() - .map(|x| builder.eval(*x)) - .collect(); - challenger.observe_slice(&mut builder, &public_values_felt); - shard_proofs.push(proof); - sorted_indices.push(sorted_indices_arr); - } - - for (proof, sorted_indices) in shard_proofs.iter().zip(sorted_indices) { - StarkVerifier::::verify_shard( - &mut builder, - &vk, - &pcs, - &machine, - &mut challenger.clone(), - proof, - sorted_indices, - ); - } - - let program = builder.compile(); - let elapsed = time.elapsed(); - println!("Building took: {:?}", elapsed); - (program, witness_stream) -} diff --git a/recursion/program/src/hints.rs b/recursion/program/src/hints.rs index de243ccae5..74925e0312 100644 --- a/recursion/program/src/hints.rs +++ b/recursion/program/src/hints.rs @@ -1,4 +1,14 @@ +use crate::challenger::DuplexChallengerVariable; +use crate::fri::TwoAdicMultiplicativeCosetVariable; +use crate::types::{ + AirOpenedValuesVariable, ChipOpenedValuesVariable, PublicValuesVariable, + ShardCommitmentVariable, ShardOpenedValuesVariable, ShardProofVariable, VerifyingKeyVariable, +}; +use p3_challenger::DuplexChallenger; +use p3_commit::TwoAdicMultiplicativeCoset; +use p3_field::TwoAdicField; use p3_field::{AbstractExtensionField, AbstractField}; +use sp1_core::stark::VerifyingKey; use sp1_core::{ air::{PublicValues, Word}, stark::{AirOpenedValues, ChipOpenedValues, ShardCommitment, ShardOpenedValues, ShardProof}, @@ -9,15 +19,12 @@ use sp1_recursion_compiler::{ }; use sp1_recursion_core::{ air::Block, - stark::config::{InnerChallenge, InnerDigest, InnerDigestHash, InnerPcsProof, InnerVal}, + stark::config::{ + InnerChallenge, InnerDigest, InnerDigestHash, InnerPcsProof, InnerPerm, InnerVal, + }, }; use sp1_sdk::utils::BabyBearPoseidon2; -use crate::types::{ - AirOpenedValuesVariable, ChipOpenedValuesVariable, PublicValuesVariable, - ShardCommitmentVariable, ShardOpenedValuesVariable, ShardProofVariable, -}; - pub trait Hintable { type HintVariable: MemVariable; @@ -64,6 +71,68 @@ impl Hintable for InnerChallenge { } } +impl Hintable for TwoAdicMultiplicativeCoset { + type HintVariable = TwoAdicMultiplicativeCosetVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let log_n = usize::read(builder); + let shift = InnerVal::read(builder); + let g_val = InnerVal::read(builder); + let size = usize::read(builder); + + // Initialize a domain. + TwoAdicMultiplicativeCosetVariable:: { + log_n, + size, + shift, + g: g_val, + } + } + + fn write(&self) -> Vec::F>>> { + let mut vec = Vec::new(); + vec.extend(usize::write(&self.log_n)); + vec.extend(InnerVal::write(&self.shift)); + vec.extend(InnerVal::write(&InnerVal::two_adic_generator(self.log_n))); + vec.extend(usize::write(&(1usize << (self.log_n)))); + vec + } +} + +trait VecAutoHintable: Hintable {} + +impl VecAutoHintable for ShardProof {} +impl VecAutoHintable for TwoAdicMultiplicativeCoset {} +impl VecAutoHintable for Vec {} + +impl> Hintable for Vec { + type HintVariable = Array; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let len = builder.hint_var(); + let mut arr = builder.dyn_array(len); + builder.range(0, len).for_each(|i, builder| { + let hint = I::read(builder); + builder.set(&mut arr, i, hint); + }); + arr + } + + fn write(&self) -> Vec::F>>> { + let mut stream = Vec::new(); + + let len = InnerVal::from_canonical_usize(self.len()); + stream.push(vec![len.into()]); + + self.iter().for_each(|i| { + let comm = I::write(i); + stream.extend(comm); + }); + + stream + } +} + impl Hintable for Vec { type HintVariable = Array>; @@ -289,6 +358,51 @@ impl Hintable for PublicValues { } } +impl Hintable for DuplexChallenger { + type HintVariable = DuplexChallengerVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let sponge_state = builder.hint_felts(); + let nb_inputs = builder.hint_var(); + let input_buffer = builder.hint_felts(); + let nb_outputs = builder.hint_var(); + let output_buffer = builder.hint_felts(); + DuplexChallengerVariable { + sponge_state, + nb_inputs, + input_buffer, + nb_outputs, + output_buffer, + } + } + + fn write(&self) -> Vec::F>>> { + let mut stream = Vec::new(); + stream.extend(self.sponge_state.to_vec().write()); + stream.extend(self.input_buffer.len().write()); + stream.extend(self.input_buffer.write()); + stream.extend(self.output_buffer.len().write()); + stream.extend(self.output_buffer.write()); + stream + } +} + +impl Hintable for VerifyingKey { + type HintVariable = VerifyingKeyVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let commitment = InnerDigest::read(builder); + VerifyingKeyVariable { commitment } + } + + fn write(&self) -> Vec::F>>> { + let mut stream = Vec::new(); + let h: InnerDigest = self.commit.into(); + stream.extend(h.write()); + stream + } +} + impl Hintable for ShardProof { type HintVariable = ShardProofVariable; diff --git a/recursion/program/src/lib.rs b/recursion/program/src/lib.rs index df2daa54e6..6a5366f755 100644 --- a/recursion/program/src/lib.rs +++ b/recursion/program/src/lib.rs @@ -5,10 +5,10 @@ #![allow(clippy::too_many_arguments)] pub mod challenger; pub mod commit; -pub mod compress; pub mod constraints; pub mod folder; pub mod fri; pub mod hints; +pub mod reduce; pub mod stark; pub mod types; diff --git a/recursion/program/src/reduce.rs b/recursion/program/src/reduce.rs new file mode 100644 index 0000000000..e135d2456d --- /dev/null +++ b/recursion/program/src/reduce.rs @@ -0,0 +1,192 @@ +use std::time::Instant; + +use crate::challenger::CanObserveVariable; +use crate::challenger::DuplexChallengerVariable; +use crate::fri::types::FriConfigVariable; +use crate::fri::TwoAdicFriPcsVariable; +use crate::fri::TwoAdicMultiplicativeCosetVariable; +use crate::hints::Hintable; +use crate::stark::StarkVerifier; +use p3_baby_bear::BabyBear; +use p3_baby_bear::DiffusionMatrixBabybear; +use p3_challenger::DuplexChallenger; +use p3_commit::ExtensionMmcs; +use p3_commit::TwoAdicMultiplicativeCoset; +use p3_field::extension::BinomialExtensionField; +use p3_field::AbstractField; +use p3_field::Field; +use p3_field::TwoAdicField; +use p3_fri::FriConfig; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon2::Poseidon2; +use p3_poseidon2::Poseidon2ExternalMatrixGeneral; +use p3_symmetric::PaddingFreeSponge; +use p3_symmetric::TruncatedPermutation; +use sp1_core::stark::ShardProof; +use sp1_core::stark::VerifyingKey; +use sp1_core::stark::{RiscvAir, StarkGenericConfig}; +use sp1_recursion_compiler::asm::AsmConfig; +use sp1_recursion_compiler::asm::VmBuilder; +use sp1_recursion_compiler::ir::Builder; +use sp1_recursion_compiler::ir::Felt; +use sp1_recursion_compiler::ir::MemVariable; +use sp1_recursion_compiler::ir::Usize; +use sp1_recursion_compiler::ir::Var; +use sp1_recursion_core::runtime::Program as RecursionProgram; +use sp1_recursion_core::runtime::DIGEST_SIZE; +use sp1_recursion_core::stark::config::inner_fri_config; +use sp1_recursion_core::stark::RecursionAir; +use sp1_sdk::utils::BabyBearPoseidon2; + +type SC = BabyBearPoseidon2; +type F = ::Val; +type EF = ::Challenge; +type C = AsmConfig; + +type Val = BabyBear; +type Challenge = BinomialExtensionField; +type Perm = Poseidon2; +type Hash = PaddingFreeSponge; +type Compress = TruncatedPermutation; +type ValMmcs = + FieldMerkleTreeMmcs<::Packing, ::Packing, Hash, Compress, 8>; +type ChallengeMmcs = ExtensionMmcs; +type RecursionConfig = AsmConfig; +type RecursionBuilder = Builder; + +pub fn const_fri_config( + builder: &mut RecursionBuilder, + config: FriConfig, +) -> FriConfigVariable { + let two_addicity = Val::TWO_ADICITY; + let mut generators = builder.dyn_array(two_addicity); + let mut subgroups = builder.dyn_array(two_addicity); + for i in 0..two_addicity { + let constant_generator = Val::two_adic_generator(i); + builder.set(&mut generators, i, constant_generator); + + let constant_domain = TwoAdicMultiplicativeCoset { + log_n: i, + shift: Val::one(), + }; + let domain_value: TwoAdicMultiplicativeCosetVariable<_> = builder.constant(constant_domain); + builder.set(&mut subgroups, i, domain_value); + } + FriConfigVariable { + log_blowup: config.log_blowup, + num_queries: config.num_queries, + proof_of_work_bits: config.proof_of_work_bits, + subgroups, + generators, + } +} + +fn clone>(builder: &mut RecursionBuilder, var: &T) -> T { + let mut arr = builder.dyn_array(1); + builder.set(&mut arr, 0, var.clone()); + builder.get(&arr, 0) +} + +pub fn build_reduce() -> RecursionProgram { + let sp1_machine = RiscvAir::machine(SC::default()); + let _recursion_machine = RecursionAir::machine(SC::default()); + + let time = Instant::now(); + let mut builder = VmBuilder::::default(); + let config = const_fri_config(&mut builder, inner_fri_config()); + let pcs = TwoAdicFriPcsVariable { config }; + + // Read witness inputs + let proofs = Vec::>::read(&mut builder); + let is_recursive_flags = Vec::::read(&mut builder); + let sorted_indices = Vec::>::read(&mut builder); + let sp1_challenger = DuplexChallenger::read(&mut builder); + let mut reconstruct_challenger = DuplexChallenger::read(&mut builder); + // let recursion_challenger = DuplexChallenger::read(&mut builder); + let prep_sorted_indices = Vec::::read(&mut builder); + let prep_domains = Vec::>::read(&mut builder); + // let recursion_prep_sorted_indices = Vec::::read(&mut builder); + // let recursion_prep_domains = Vec::>::read(&mut builder); + let sp1_vk = VerifyingKey::::read(&mut builder); + let _recursion_vk = VerifyingKey::::read(&mut builder); + let num_proofs = proofs.len(); + + let _pre_start_challenger = clone(&mut builder, &sp1_challenger); + let _pre_reconstruct_challenger = clone(&mut builder, &reconstruct_challenger); + let zero: Var<_> = builder.constant(F::zero()); + let one: Var<_> = builder.constant(F::one()); + let _one_felt: Felt<_> = builder.constant(F::one()); + builder + .range(Usize::Const(0), num_proofs) + .for_each(|i, builder| { + let proof = builder.get(&proofs, i); + let sorted_indices = builder.get(&sorted_indices, i); + let is_recursive = builder.get(&is_recursive_flags, i); + builder.if_eq(is_recursive, zero).then_or_else( + // Non-recursive proof + |builder| { + let shard_bits = builder.num2bits_f(proof.public_values.shard); + let shard = builder.bits2num_v(&shard_bits); + builder.if_eq(shard, one).then(|builder| { + // Initialize the current challenger + // let h: [BabyBear; DIGEST_SIZE] = sp1_vk.commit.into(); + // let const_commit: DigestVariable = builder.eval_const(h.to_vec()); + reconstruct_challenger = DuplexChallengerVariable::new(builder); + reconstruct_challenger.observe(builder, sp1_vk.commitment.clone()); + }); + for j in 0..DIGEST_SIZE { + let element = builder.get(&proof.commitment.main_commit, j); + reconstruct_challenger.observe(builder, element); + // TODO: observe public values + // challenger.observe_slice(&public_values.to_vec()); + } + // reconstruct_challenger + // .observe_slice(builder, &proof.commitment.main_commit.vec()); + let mut current_challenger = sp1_challenger.as_clone(builder); + StarkVerifier::::verify_shard( + builder, + &sp1_vk.clone(), + &pcs, + &sp1_machine, + &mut current_challenger, + &proof, + sorted_indices.clone(), + prep_sorted_indices.clone(), + prep_domains.clone(), + ); + }, + // Recursive proof + |_builder| { + // let mut current_challenger = recursion_challenger.as_clone(builder); + // StarkVerifier::::verify_shard( + // builder, + // &recursion_vk.clone(), + // &pcs, + // &recursion_machine, + // &mut current_challenger, + // &proof, + // sorted_indices.clone(), + // prep_sorted_indices.clone(), + // prep_domains.clone(), + // ); + }, + ); + }); + + // Public values: + // ( + // final current_challenger, + // reconstruct_challenger, + // pre_challenger, + // pre_reconstruct_challenger, + // verify_start_challenger, + // recursion_vk, + // ) + // Note we still need to check that verify_start_challenger matches final reconstruct_challenger + // after observing pv_digest at the end. + + let program = builder.compile(); + let elapsed = time.elapsed(); + println!("Building took: {:?}", elapsed); + program +} diff --git a/recursion/program/src/stark.rs b/recursion/program/src/stark.rs index a015c681bd..28e23ca397 100644 --- a/recursion/program/src/stark.rs +++ b/recursion/program/src/stark.rs @@ -1,3 +1,14 @@ +use crate::challenger::CanObserveVariable; +use crate::challenger::DuplexChallengerVariable; +use crate::challenger::FeltChallenger; +use crate::commit::PolynomialSpaceVariable; +use crate::folder::RecursiveVerifierConstraintFolder; +use crate::fri::types::TwoAdicPcsMatsVariable; +use crate::fri::types::TwoAdicPcsRoundVariable; +use crate::fri::TwoAdicMultiplicativeCosetVariable; +use crate::types::ShardCommitmentVariable; +use crate::types::VerifyingKeyVariable; +use crate::{commit::PcsVariable, fri::TwoAdicFriPcsVariable, types::ShardProofVariable}; use p3_air::Air; use p3_commit::TwoAdicMultiplicativeCoset; use p3_field::AbstractField; @@ -6,24 +17,12 @@ use sp1_core::air::MachineAir; use sp1_core::stark::Com; use sp1_core::stark::MachineStark; use sp1_core::stark::StarkGenericConfig; -use sp1_core::stark::VerifyingKey; use sp1_recursion_compiler::ir::Array; use sp1_recursion_compiler::ir::Ext; use sp1_recursion_compiler::ir::Var; use sp1_recursion_compiler::ir::{Builder, Config, Usize}; use sp1_recursion_core::runtime::DIGEST_SIZE; -use crate::challenger::CanObserveVariable; -use crate::challenger::DuplexChallengerVariable; -use crate::challenger::FeltChallenger; -use crate::commit::PolynomialSpaceVariable; -use crate::folder::RecursiveVerifierConstraintFolder; -use crate::fri::types::TwoAdicPcsMatsVariable; -use crate::fri::types::TwoAdicPcsRoundVariable; -use crate::fri::TwoAdicMultiplicativeCosetVariable; -use crate::types::ShardCommitmentVariable; -use crate::{commit::PcsVariable, fri::TwoAdicFriPcsVariable, types::ShardProofVariable}; - pub const EMPTY: usize = 0x_1111_1111; #[derive(Debug, Clone, Copy)] @@ -42,12 +41,14 @@ where { pub fn verify_shard( builder: &mut Builder, - vk: &VerifyingKey, + vk: &VerifyingKeyVariable, pcs: &TwoAdicFriPcsVariable, machine: &MachineStark, challenger: &mut DuplexChallengerVariable, proof: &ShardProofVariable, - sorted_indices: Array>, + chip_sorted_idxs: Array>, + preprocessed_sorted_idxs: Array>, + prep_domains: Array>, ) where A: MachineAir + for<'a> Air>, C::F: TwoAdicField, @@ -92,7 +93,7 @@ where let log_quotient_degree = C::N::from_canonical_usize(log_quotient_degree_val); let num_quotient_chunks_val = 1 << log_quotient_degree_val; - let num_preprocessed_chips = vk.chip_information.len(); + let num_preprocessed_chips = machine.preprocessed_chip_ids().len(); let mut prep_mats: Array<_, TwoAdicPcsMatsVariable<_>> = builder.dyn_array(num_preprocessed_chips); @@ -106,18 +107,17 @@ where let mut qc_points = builder.dyn_array::>(1); builder.set(&mut qc_points, 0, zeta); - // TODO FIX: There is something weird going on here because the number of chips may not match - // the number of chips in a shard. - for (i, (name, domain, _)) in vk.chip_information.iter().enumerate() { - let chip_idx = machine - .chips() - .iter() - .rposition(|chip| &chip.name() == name) - .unwrap(); - let index = builder.get(&sorted_indices, chip_idx); - let opening = builder.get(&opened_values.chips, index); + // Iterate through machine.chips filtered for preprocessed chips. + for (preprocessed_id, chip_id) in machine.preprocessed_chip_ids().into_iter().enumerate() { + // Get index within sorted preprocessed chips. + let preprocessed_sorted_id = builder.get(&preprocessed_sorted_idxs, preprocessed_id); + // Get domain from witnessed domains. Array is ordered by machine.chips ordering. + let domain = builder.get(&prep_domains, preprocessed_id); - let domain: TwoAdicMultiplicativeCosetVariable<_> = builder.constant(*domain); + // Get index within all sorted chips. + let chip_sorted_id = builder.get(&chip_sorted_idxs, chip_id); + // Get opening from proof. + let opening = builder.get(&opened_values.chips, chip_sorted_id); let mut trace_points = builder.dyn_array::>(2); let zeta_next = domain.next_point(builder, zeta); @@ -133,7 +133,7 @@ where values: prep_values, points: trace_points.clone(), }; - builder.set(&mut prep_mats, i, main_mat); + builder.set(&mut prep_mats, preprocessed_sorted_id, main_mat); } builder.range(0, num_shard_chips).for_each(|i, builder| { @@ -197,8 +197,7 @@ where // Create the pcs rounds. let mut rounds = builder.dyn_array::>(4); - let prep_commit_val: [SC::Val; DIGEST_SIZE] = vk.commit.clone().into(); - let prep_commit = builder.constant(prep_commit_val.to_vec()); + let prep_commit = vk.commitment.clone(); let prep_round = TwoAdicPcsRoundVariable { batch_commit: prep_commit, mats: prep_mats, @@ -223,8 +222,14 @@ where // Verify the pcs proof pcs.verify(builder, rounds, opening_proof.clone(), challenger); + // TODO CONSTRAIN: that the preprocessed chips get called with verify_constraints. for (i, chip) in machine.chips().iter().enumerate() { - let index = builder.get(&sorted_indices, i); + let index = builder.get(&chip_sorted_idxs, i); + + if chip.preprocessed_width() > 0 { + builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); + } + builder .if_ne(index, C::N::from_canonical_usize(EMPTY)) .then(|builder| { @@ -252,19 +257,19 @@ where #[cfg(test)] pub(crate) mod tests { - use std::time::Instant; - use crate::challenger::CanObserveVariable; + use crate::challenger::DuplexChallengerVariable; use crate::challenger::FeltChallenger; use crate::hints::Hintable; use crate::stark::Ext; - use crate::stark::EMPTY; use crate::types::ShardCommitmentVariable; use p3_challenger::{CanObserve, FieldChallenger}; use p3_field::AbstractField; use rand::Rng; use sp1_core::air::PublicValues; + use sp1_core::air::Word; use sp1_core::runtime::Program; + use sp1_core::stark::LocalProver; use sp1_core::{ stark::{RiscvAir, ShardProof, StarkGenericConfig}, utils::BabyBearPoseidon2, @@ -277,22 +282,12 @@ pub(crate) mod tests { ir::{Builder, ExtConst}, }; use sp1_recursion_core::runtime::{Runtime, DIGEST_SIZE}; - use sp1_recursion_core::stark::config::inner_fri_config; use sp1_recursion_core::stark::config::InnerChallenge; use sp1_recursion_core::stark::config::InnerVal; - use sp1_sdk::{SP1Prover, SP1Stdin}; - - use sp1_core::air::Word; - - use crate::{ - challenger::DuplexChallengerVariable, - fri::{const_fri_config, TwoAdicFriPcsVariable}, - stark::StarkVerifier, - }; - - use sp1_core::stark::LocalProver; use sp1_recursion_core::stark::RecursionAir; use sp1_sdk::utils::setup_logger; + use sp1_sdk::{SP1Prover, SP1Stdin}; + use std::time::Instant; type SC = BabyBearPoseidon2; type F = InnerVal; @@ -371,105 +366,6 @@ pub(crate) mod tests { ); } - #[test] - fn test_recursive_verify_shard() { - // Generate a dummy proof. - sp1_core::utils::setup_logger(); - - let elf = - include_bytes!("../../../examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf"); - - let machine = A::machine(SC::default()); - - let (_, vk) = machine.setup(&Program::from(elf)); - let proof = SP1Prover::prove_with_config(elf, SP1Stdin::new(), machine.config().clone()) - .unwrap() - .proof; - let mut challenger_ver = machine.config().challenger(); - machine.verify(&vk, &proof, &mut challenger_ver).unwrap(); - println!("Proof generated successfully"); - - let mut challenger_val = machine.config().challenger(); - challenger_val.observe(vk.commit); - proof.shard_proofs.iter().for_each(|proof| { - challenger_val.observe(proof.commitment.main_commit); - let public_values_field = PublicValues::, F>::new(proof.public_values); - challenger_val.observe_slice(&public_values_field.to_vec()); - }); - - let time = Instant::now(); - let mut builder = Builder::::default(); - let config = const_fri_config(&mut builder, inner_fri_config()); - let pcs = TwoAdicFriPcsVariable { config }; - - let mut challenger = DuplexChallengerVariable::new(&mut builder); - - let preprocessed_commit_val: [F; DIGEST_SIZE] = vk.commit.into(); - let preprocessed_commit: Array = builder.constant(preprocessed_commit_val.to_vec()); - challenger.observe(&mut builder, preprocessed_commit); - - let mut witness_stream = Vec::new(); - let mut shard_proofs = vec![]; - let mut sorted_indices = vec![]; - for proof_val in proof.shard_proofs { - witness_stream.extend(proof_val.write()); - let sorted_indices_raw: Vec = machine - .chips_sorted_indices(&proof_val) - .into_iter() - .map(|x| match x { - Some(x) => x, - None => EMPTY, - }) - .collect(); - witness_stream.extend(sorted_indices_raw.write()); - let proof = ShardProof::<_>::read(&mut builder); - let sorted_indices_arr = Vec::::read(&mut builder); - builder - .range(0, sorted_indices_arr.len()) - .for_each(|i, builder| { - let el = builder.get(&sorted_indices_arr, i); - builder.print_v(el); - }); - let ShardCommitmentVariable { main_commit, .. } = &proof.commitment; - challenger.observe(&mut builder, main_commit.clone()); - let public_values_field = PublicValues::, F>::new(proof_val.public_values); - let public_values_felt: Vec> = public_values_field - .to_vec() - .iter() - .map(|x| builder.eval(*x)) - .collect(); - challenger.observe_slice(&mut builder, &public_values_felt); - shard_proofs.push(proof); - sorted_indices.push(sorted_indices_arr); - } - - let code = builder.eval(InnerVal::two()); - builder.print_v(code); - for (proof, sorted_indices) in shard_proofs.iter().zip(sorted_indices) { - StarkVerifier::::verify_shard( - &mut builder, - &vk, - &pcs, - &machine, - &mut challenger.clone(), - proof, - sorted_indices, - ); - } - - let program = builder.compile(); - let elapsed = time.elapsed(); - println!("Building took: {:?}", elapsed); - - let time = Instant::now(); - let mut runtime = Runtime::::new(&program, machine.config().perm.clone()); - runtime.witness_stream = witness_stream; - runtime.run(); - let elapsed = time.elapsed(); - runtime.print_stats(); - println!("Execution took: {:?}", elapsed); - } - #[test] #[ignore] fn test_kitchen_sink() { diff --git a/recursion/program/src/types.rs b/recursion/program/src/types.rs index 0b6fcaa9f5..c8211ced72 100644 --- a/recursion/program/src/types.rs +++ b/recursion/program/src/types.rs @@ -76,6 +76,12 @@ pub struct ShardProofVariable { pub public_values: PublicValuesVariable, } +/// Reference: https://github.com/succinctlabs/sp1/blob/b5d5473c010ab0630102652146e16c014a1eddf6/core/src/stark/machine.rs#L63 +#[derive(DslVariable, Clone)] +pub struct VerifyingKeyVariable { + pub commitment: DigestVariable, +} + #[derive(DslVariable, Clone)] pub struct ShardCommitmentVariable { pub main_commit: DigestVariable,