Skip to content

Commit

Permalink
feat: shrink/wrap multi opt (#886)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtguibas authored Jun 4, 2024
2 parents b31b27c + 3afd6ff commit 667ff1b
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 15 deletions.
13 changes: 8 additions & 5 deletions recursion/core/src/fri_fold/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub const NUM_FRI_FOLD_COLS: usize = core::mem::size_of::<FriFoldCols<u8>>();
#[derive(Default)]
pub struct FriFoldChip<const DEGREE: usize> {
pub fixed_log2_rows: Option<usize>,
pub pad: bool,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -143,11 +144,13 @@ impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for FriFoldChip<DEGREE>
.collect_vec();

// Pad the trace to a power of two.
pad_rows_fixed(
&mut rows,
|| [F::zero(); NUM_FRI_FOLD_COLS],
self.fixed_log2_rows,
);
if self.pad {
pad_rows_fixed(
&mut rows,
|| [F::zero(); NUM_FRI_FOLD_COLS],
self.fixed_log2_rows,
);
}

// Convert the trace to a row major matrix.
let trace = RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_FRI_FOLD_COLS);
Expand Down
12 changes: 8 additions & 4 deletions recursion/core/src/multi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,14 @@ impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for MultiChip<DEGREE> {
input: &ExecutionRecord<F>,
output: &mut ExecutionRecord<F>,
) -> RowMajorMatrix<F> {
let fri_fold_chip = FriFoldChip::<3>::default();
let poseidon2 = Poseidon2Chip::default();
let fri_fold_chip = FriFoldChip::<3> {
fixed_log2_rows: None,
pad: false,
};
let poseidon2 = Poseidon2Chip {
fixed_log2_rows: None,
pad: false,
};
let fri_fold_trace = fri_fold_chip.generate_trace(input, output);
let mut poseidon2_trace = poseidon2.generate_trace(input, output);

Expand Down Expand Up @@ -145,14 +151,12 @@ where
// all the fri fold rows are first, then the posiedon2 rows, and finally any padded (non-real) rows.

// First verify that all real rows are contiguous.
builder.when_first_row().assert_one(local_is_real.clone());
builder
.when_transition()
.when_not(local_is_real.clone())
.assert_zero(next_is_real.clone());

// Next, verify that all fri fold rows are before the poseidon2 rows within the real rows section.
builder.when_first_row().assert_one(local.is_fri_fold);
builder
.when_transition()
.when(next_is_real)
Expand Down
3 changes: 3 additions & 0 deletions recursion/core/src/poseidon2/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub const WIDTH: usize = 16;
#[derive(Default)]
pub struct Poseidon2Chip {
pub fixed_log2_rows: Option<usize>,
pub pad: bool,
}

impl<F> BaseAir<F> for Poseidon2Chip {
Expand Down Expand Up @@ -449,6 +450,7 @@ mod tests {
fn generate_trace() {
let chip = Poseidon2Chip {
fixed_log2_rows: None,
pad: true,
};

let rng = &mut rand::thread_rng();
Expand Down Expand Up @@ -498,6 +500,7 @@ mod tests {

let chip = Poseidon2Chip {
fixed_log2_rows: None,
pad: true,
};
let trace: RowMajorMatrix<BabyBear> =
chip.generate_trace(&input_exec, &mut ExecutionRecord::<BabyBear>::default());
Expand Down
12 changes: 7 additions & 5 deletions recursion/core/src/poseidon2/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,13 @@ impl<F: PrimeField32> MachineAir<F> for Poseidon2Chip {
let num_real_rows = rows.len();

// Pad the trace to a power of two.
pad_rows_fixed(
&mut rows,
|| [F::zero(); NUM_POSEIDON2_COLS],
self.fixed_log2_rows,
);
if self.pad {
pad_rows_fixed(
&mut rows,
|| [F::zero(); NUM_POSEIDON2_COLS],
self.fixed_log2_rows,
);
}

let mut round_num = 0;
for row in rows[num_real_rows..].iter_mut() {
Expand Down
3 changes: 2 additions & 1 deletion recursion/core/src/stark/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> RecursionAi
})))
.chain(once(RecursionAir::FriFold(FriFoldChip::<DEGREE> {
fixed_log2_rows: None,
pad: true,
})))
.chain(once(RecursionAir::RangeCheck(RangeCheckChip::default())))
.collect()
Expand Down Expand Up @@ -109,7 +110,7 @@ impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> RecursionAi
fixed_log2_rows: Some(19),
})))
.chain(once(RecursionAir::Multi(MultiChip {
fixed_log2_rows: Some(20),
fixed_log2_rows: Some(19),
})))
.chain(once(RecursionAir::RangeCheck(RangeCheckChip::default())))
.collect()
Expand Down
3 changes: 3 additions & 0 deletions recursion/program/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ itertools = "0.12.1"
serde = { version = "1.0.201", features = ["derive"] }
rand = "0.8.5"
tracing = "0.1.40"

[features]
debug = ["sp1-core/debug"]

0 comments on commit 667ff1b

Please sign in to comment.