Skip to content

Commit

Permalink
wip(recursion): reduce program (#497)
Browse files Browse the repository at this point in the history
  • Loading branch information
ctian1 authored Apr 9, 2024
1 parent c159795 commit ce00cf6
Show file tree
Hide file tree
Showing 13 changed files with 526 additions and 376 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions core/src/stark/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> MachineStark<SC, A> {
&self.chips
}

/// Returns the id of all chips in the machine that have preprocessed columns.
pub fn preprocessed_chip_ids(&self) -> Vec<usize> {
self.chips
.iter()
.enumerate()
.filter(|(_, chip)| chip.preprocessed_width() > 0)
.map(|(i, _)| i)
.collect()
}

pub fn shard_chips<'a, 'b>(
&'a self,
shard: &'b A::Record,
Expand Down
3 changes: 3 additions & 0 deletions prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ sp1-recursion-program = { path = "../recursion/program" }
sp1-recursion-compiler = { path = "../recursion/compiler" }
sp1-recursion-core = { path = "../recursion/core" }
sp1-core = { path = "../core" }
p3-challenger = { workspace = true }
p3-baby-bear = { workspace = true }
p3-commit = { workspace = true }
bincode = "1.3.3"
172 changes: 114 additions & 58 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,84 +1,140 @@
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
#![allow(deprecated)]

use std::time::Instant;

use p3_baby_bear::BabyBear;
use p3_challenger::CanObserve;
use p3_commit::TwoAdicMultiplicativeCoset;
use sp1_core::{
air::{MachineAir, PublicValues, Word},
runtime::Program,
stark::{Proof, RiscvAir, StarkGenericConfig, VerifyingKey},
stark::{LocalProver, Proof, RiscvAir, ShardProof, StarkGenericConfig},
utils::{run_and_prove, BabyBearPoseidon2},
};
use sp1_recursion_core::runtime::Runtime;
use sp1_recursion_program::compress::build_compress;
use sp1_recursion_core::{runtime::Runtime, stark::RecursionAir};
use sp1_recursion_program::{hints::Hintable, reduce::build_reduce, stark::EMPTY};

type InnerSC = BabyBearPoseidon2;
type InnerF = <InnerSC as StarkGenericConfig>::Val;
type InnerEF = <InnerSC as StarkGenericConfig>::Challenge;
type InnerA = RiscvAir<InnerF>;

pub fn prove_sp1() -> (Proof<InnerSC>, VerifyingKey<InnerSC>) {
let elf = include_bytes!("../../examples/fibonacci-io/program/elf/riscv32im-succinct-zkvm-elf");

let config = InnerSC::default();
let machine = RiscvAir::machine(config.clone());
let program = Program::from(elf);
let stdin = [bincode::serialize::<u32>(&6).unwrap()];
let (_, vk) = machine.setup(&program);
let (proof, _) = run_and_prove(program, &stdin, config);
let mut challenger_ver = machine.config().challenger();
machine.verify(&vk, &proof, &mut challenger_ver).unwrap();
println!("Proof generated successfully");

(proof, vk)
}

pub fn prove_compress(sp1_proof: Proof<InnerSC>, vk: VerifyingKey<InnerSC>) {
let (program, witness_stream) = build_compress(sp1_proof, vk);

let config = InnerSC::default();
let machine = InnerA::machine(config);
let mut runtime = Runtime::<InnerF, InnerEF, _>::new(&program, machine.config().perm.clone());
runtime.witness_stream = witness_stream;

let time = Instant::now();
runtime.run();
let elapsed = time.elapsed();
runtime.print_stats();
println!("Execution took: {:?}", elapsed);

// let config = BabyBearPoseidon2::new();
// let machine = RecursionAir::machine(config);
// let (pk, vk) = machine.setup(&program);

// let mut challenger = machine.config().challenger();
// let record_clone = runtime.record.clone();
// machine.debug_constraints(&pk, record_clone, &mut challenger);

// let start = Instant::now();
// let mut challenger = machine.config().challenger();
// let proof = machine.prove::<LocalProver<_, _>>(&pk, runtime.record, &mut challenger);
// let duration = start.elapsed().as_secs();
pub struct SP1ProverImpl;

impl SP1ProverImpl {
pub fn prove(elf: &[u8], stdin: &[Vec<u8>]) -> Proof<InnerSC> {
let config = InnerSC::default();
let machine = RiscvAir::machine(config.clone());
let program = Program::from(elf);
let (_, vk) = machine.setup(&program);
let (proof, _) = run_and_prove(program, stdin, config);
let mut challenger_ver = machine.config().challenger();
machine.verify(&vk, &proof, &mut challenger_ver).unwrap();
proof
}

// let mut challenger = machine.config().challenger();
// machine.verify(&vk, &proof, &mut challenger).unwrap();
// println!("proving duration = {}", duration);
pub fn reduce(elf: &[u8], proof: Proof<InnerSC>) -> Vec<ShardProof<InnerSC>> {
let config = InnerSC::default();
let machine = RiscvAir::machine(config.clone());
let program = Program::from(elf);
let (_, vk) = machine.setup(&program);
let reduce_program = build_reduce();
println!("nb_shards {}", proof.shard_proofs.len());
let config = InnerSC::default();

let is_recursive_flags: Vec<usize> = proof.shard_proofs.iter().map(|_| 0).collect();
let sorted_indices: Vec<Vec<usize>> = proof
.shard_proofs
.iter()
.map(|p| {
machine
.chips_sorted_indices(p)
.into_iter()
.map(|x| match x {
Some(x) => x,
None => EMPTY,
})
.collect()
})
.collect();

let mut challenger = machine.config().challenger();
challenger.observe(vk.commit);
let reconstruct_challenger = challenger.clone();
for proof in proof.shard_proofs.iter() {
challenger.observe(proof.commitment.main_commit);
let public_values = PublicValues::<Word<BabyBear>, BabyBear>::new(proof.public_values);
challenger.observe_slice(&public_values.to_vec());
}

let chips = machine.chips();
let ordering = vk.chip_ordering.clone();
let (prep_sorted_indices, prep_domains): (
Vec<usize>,
Vec<TwoAdicMultiplicativeCoset<BabyBear>>,
) = machine
.preprocessed_chip_ids()
.into_iter()
.map(|chip_idx| {
let name = chips[chip_idx].name().clone();
let prep_sorted_idx = ordering[&name];
(prep_sorted_idx, vk.chip_information[prep_sorted_idx].1)
})
.unzip();

// Generate inputs.
let mut witness_stream = Vec::new();
witness_stream.extend(proof.shard_proofs.write());
witness_stream.extend(is_recursive_flags.write());
witness_stream.extend(sorted_indices.write());
witness_stream.extend(challenger.write());
witness_stream.extend(reconstruct_challenger.write());
witness_stream.extend(prep_sorted_indices.write());
witness_stream.extend(prep_domains.write());
witness_stream.extend(vk.write());
witness_stream.extend(vk.write());

// Execute runtime.
let machine = InnerA::machine(config);
let mut runtime =
Runtime::<InnerF, InnerEF, _>::new(&reduce_program, machine.config().perm.clone());
runtime.witness_stream = witness_stream;
runtime.run();
runtime.print_stats();

// Generate proof.
let config = BabyBearPoseidon2::new();
let machine = RecursionAir::machine(config);
let (pk, _) = machine.setup(&reduce_program);
// let mut challenger = machine.config().challenger();
// let record_clone = runtime.record.clone();
// machine.debug_constraints(&pk, record_clone, &mut challenger);
let start = Instant::now();
let mut challenger = machine.config().challenger();
let proof = machine.prove::<LocalProver<_, _>>(&pk, runtime.record, &mut challenger);
let duration = start.elapsed().as_secs();
println!("proving duration = {}", duration);

// let mut challenger = machine.config().challenger();
// machine.verify(&vk, &proof, &mut challenger).unwrap();

proof.shard_proofs
}
}

pub fn prove_reduce() {}

pub fn prove_snark() {}

#[cfg(test)]
mod tests {
use sp1_core::utils::setup_logger;

use super::*;

use sp1_core::utils::setup_logger;
#[test]
fn test_prove_sp1() {
setup_logger();

let (sp1_proof, vk) = prove_sp1();
prove_compress(sp1_proof, vk);
let elf =
include_bytes!("../../examples/fibonacci-io/program/elf/riscv32im-succinct-zkvm-elf");
let stdin = [bincode::serialize::<u32>(&6).unwrap()];
let proof = SP1ProverImpl::prove(elf, &stdin);
SP1ProverImpl::reduce(elf, proof);
}
}
2 changes: 2 additions & 0 deletions recursion/compiler/src/ir/var.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ pub struct MemIndex<N> {

pub trait MemVariable<C: Config>: Variable<C> {
fn size_of() -> usize;
/// Loads the variable from the heap.
fn load(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>);
/// Stores the variable to the heap.
fn store(&self, ptr: Ptr<C::N>, index: MemIndex<C::N>, builder: &mut Builder<C>);
}

Expand Down
2 changes: 1 addition & 1 deletion recursion/core/src/cpu/columns/opcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl<F: Field> OpcodeSelectorCols<F> {
Opcode::PrintF => self.is_noop = F::one(),
Opcode::PrintE => self.is_noop = F::one(),
Opcode::FRIFold => self.is_fri_fold = F::one(),
_ => unreachable!(),
_ => {}
}
}
}
Expand Down
36 changes: 34 additions & 2 deletions recursion/program/src/challenger.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use p3_field::AbstractField;

use sp1_recursion_compiler::prelude::{Array, Builder, Config, Ext, Felt, Usize, Var};
use sp1_recursion_compiler::prelude::MemIndex;
use sp1_recursion_compiler::prelude::MemVariable;
use sp1_recursion_compiler::prelude::Ptr;
use sp1_recursion_compiler::prelude::Variable;
use sp1_recursion_compiler::prelude::{Array, Builder, Config, DslVariable, Ext, Felt, Usize, Var};
use sp1_recursion_core::runtime::{DIGEST_SIZE, PERMUTATION_WIDTH};

use crate::fri::types::DigestVariable;
Expand Down Expand Up @@ -37,7 +41,7 @@ pub trait CanSampleBitsVariable<C: Config> {
}

/// Reference: https://github.com/Plonky3/Plonky3/blob/4809fa7bedd9ba8f6f5d3267b1592618e3776c57/challenger/src/duplex_challenger.rs#L10
#[derive(Clone)]
#[derive(Clone, DslVariable)]
pub struct DuplexChallengerVariable<C: Config> {
pub sponge_state: Array<C, Felt<C::F>>,
pub nb_inputs: Var<C::N>,
Expand All @@ -57,6 +61,34 @@ impl<C: Config> DuplexChallengerVariable<C> {
}
}

/// Creates a new challenger with the same state as an existing challenger.
pub fn as_clone(&self, builder: &mut Builder<C>) -> Self {
let mut sponge_state = builder.dyn_array(PERMUTATION_WIDTH);
builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
let element = builder.get(&self.sponge_state, i);
builder.set(&mut sponge_state, i, element);
});
let nb_inputs = builder.eval(self.nb_inputs);
let mut input_buffer = builder.dyn_array(PERMUTATION_WIDTH);
builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
let element = builder.get(&self.input_buffer, i);
builder.set(&mut input_buffer, i, element);
});
let nb_outputs = builder.eval(self.nb_outputs);
let mut output_buffer = builder.dyn_array(PERMUTATION_WIDTH);
builder.range(0, PERMUTATION_WIDTH).for_each(|i, builder| {
let element = builder.get(&self.output_buffer, i);
builder.set(&mut output_buffer, i, element);
});
DuplexChallengerVariable::<C> {
sponge_state,
nb_inputs,
input_buffer,
nb_outputs,
output_buffer,
}
}

/// Reference: https://github.com/Plonky3/Plonky3/blob/4809fa7bedd9ba8f6f5d3267b1592618e3776c57/challenger/src/duplex_challenger.rs#L38
pub fn duplexing(&mut self, builder: &mut Builder<C>) {
builder.range(0, self.nb_inputs).for_each(|i, builder| {
Expand Down
Loading

0 comments on commit ce00cf6

Please sign in to comment.