Skip to content

Commit

Permalink
for loop optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
jtguibas committed Apr 9, 2024
1 parent b5d5473 commit 2671e46
Show file tree
Hide file tree
Showing 12 changed files with 161 additions and 36 deletions.
2 changes: 1 addition & 1 deletion examples/fibonacci-io/script/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn main() {
println!("b: {}", b);

// Verify proof and public values
SP1Verifier::verify(ELF, &proof).expect("verification failed");
SP1Verifier::verify(ELF, &proof).expect("verification failed"s);

let mut pv_hasher = Sha256::new();
pv_hasher.update(n.to_le_bytes());
Expand Down
14 changes: 2 additions & 12 deletions examples/fibonacci/program/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf
Binary file not shown.
34 changes: 28 additions & 6 deletions recursion/compiler/src/asm/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -744,12 +744,17 @@ impl<'a, F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField>
let loop_label = self.compiler.block_label();
// The loop body.
f(self.loop_var, self.compiler);
// Increment the loop variable.
self.compiler.push(AsmInstruction::ADDI(
self.loop_var.fp(),
self.loop_var.fp(),
self.step_size,
));

if self.step_size == F::one() {
self.jump_to_loop_body_inc(loop_label);
} else {
// Increment the loop variable.
self.compiler.push(AsmInstruction::ADDI(
self.loop_var.fp(),
self.loop_var.fp(),
self.step_size,
));
}

// Add a basic block for the loop condition.
self.compiler.basic_block();
Expand Down Expand Up @@ -815,4 +820,21 @@ impl<'a, F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField>
}
}
}

fn jump_to_loop_body_inc(&mut self, loop_label: F) {
match self.end {
Usize::Const(end) => {
let instr = AsmInstruction::BNEIINC(
loop_label,
self.loop_var.fp(),
F::from_canonical_usize(end),
);
self.compiler.push(instr);
}
Usize::Var(end) => {
let instr = AsmInstruction::BNEINC(loop_label, self.loop_var.fp(), end.fp());
self.compiler.push(instr);
}
}
}
}
49 changes: 49 additions & 0 deletions recursion/compiler/src/asm/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,11 @@ pub enum AsmInstruction<F, EF> {
JALR(i32, i32, i32),
/// Branch not equal
BNE(F, i32, i32),
/// Branch not equal increment c by 1.
BNEINC(F, i32, i32),
/// Branch not equal immediate
BNEI(F, i32, F),
BNEIINC(F, i32, F),
/// Branch equal
BEQ(F, i32, i32),
/// Branch equal immediate
Expand Down Expand Up @@ -683,6 +686,20 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
true,
)
}
AsmInstruction::BNEINC(label, lhs, rhs) => {
let offset =
F::from_canonical_usize(label_to_pc[&label]) - F::from_canonical_usize(pc);
Instruction::new(
Opcode::BNEINC,
i32_f(lhs),
i32_f_arr(rhs),
f_u32(offset),
F::zero(),
F::zero(),
false,
true,
)
}
AsmInstruction::BNEI(label, lhs, rhs) => {
let offset =
F::from_canonical_usize(label_to_pc[&label]) - F::from_canonical_usize(pc);
Expand All @@ -697,6 +714,20 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
true,
)
}
AsmInstruction::BNEIINC(label, lhs, rhs) => {
let offset =
F::from_canonical_usize(label_to_pc[&label]) - F::from_canonical_usize(pc);
Instruction::new(
Opcode::BNEINC,
i32_f(lhs),
f_u32(rhs),
f_u32(offset),
F::zero(),
F::zero(),
true,
true,
)
}
AsmInstruction::EBNE(label, lhs, rhs) => {
let offset =
F::from_canonical_usize(label_to_pc[&label]) - F::from_canonical_usize(pc);
Expand Down Expand Up @@ -1100,6 +1131,24 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
rhs
)
}
AsmInstruction::BNEINC(label, lhs, rhs) => {
write!(
f,
"bneinc {}, ({})fp, {}",
labels.get(label).unwrap_or(&format!(".L{}", label)),
lhs,
rhs
)
}
AsmInstruction::BNEIINC(label, lhs, rhs) => {
write!(
f,
"bneiinc {}, ({})fp, {}",
labels.get(label).unwrap_or(&format!(".L{}", label)),
lhs,
rhs
)
}
AsmInstruction::BEQ(label, lhs, rhs) => {
write!(
f,
Expand Down
7 changes: 4 additions & 3 deletions recursion/compiler/src/ir/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ impl<C: Config> Builder<C> {
builder.range(0, HASH_RATE).for_each(|j, builder| {
let index: Var<_> = builder.eval(i + j);
let element = builder.get(array, index);
builder.set(&mut state, j, element);
builder.set_value(&mut state, j, element);
builder.if_eq(index, last_index).then(|builder| {
builder.assign(break_flag, C::N::one());
builder.break_loop();
Expand Down Expand Up @@ -481,11 +481,12 @@ impl<C: Config> Builder<C> {
self.range(0, bit_len).for_each(|i, builder| {
let index: Var<C::N> = builder.eval(bit_len - i - C::N::one());
let entry = builder.get(index_bits, index);
builder.set(&mut result_bits, i, entry);
builder.set_value(&mut result_bits, i, entry);
});

let zero = self.eval(C::N::zero());
self.range(bit_len, NUM_BITS).for_each(|i, builder| {
builder.set(&mut result_bits, i, C::N::zero());
builder.set_value(&mut result_bits, i, zero);
});

result_bits
Expand Down
23 changes: 23 additions & 0 deletions recursion/compiler/src/ir/collections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,29 @@ impl<C: Config> Builder<C> {
}
}
}

pub fn set_value<V: MemVariable<C>, I: Into<Usize<C::N>>>(
&mut self,
slice: &mut Array<C, V>,
index: I,
value: V,
) {
let index = index.into();

match slice {
Array::Fixed(_) => {
todo!()
}
Array::Dyn(ptr, _) => {
let index = MemIndex {
index,
offset: 0,
size: V::size_of(),
};
self.store(*ptr, index, value);
}
}
}
}

impl<C: Config, T: MemVariable<C>> Variable<C> for Array<C, T> {
Expand Down
28 changes: 28 additions & 0 deletions recursion/compiler/tests/for_loops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,31 @@ fn test_compiler_step_by() {
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
}

#[test]
fn test_compiler_bneinc() {
type SC = BabyBearPoseidon2;
type F = <SC as StarkGenericConfig>::Val;
type EF = <SC as StarkGenericConfig>::Challenge;
let mut builder = VmBuilder::<F, EF>::default();

let n_val = BabyBear::from_canonical_u32(20);

let zero: Var<_> = builder.eval(F::zero());
let n: Var<_> = builder.eval(n_val);

let i_counter: Var<_> = builder.eval(F::zero());
builder.range(zero, n).step_by(1).for_each(|_, builder| {
builder.assign(i_counter, i_counter + F::one());
});

let code = builder.clone().compile_to_asm();

println!("{}", code);

let program = builder.compile();

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
}
9 changes: 9 additions & 0 deletions recursion/core/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,15 @@ where
next_pc = self.pc + c_offset;
}
}
Opcode::BNEINC => {
let (mut a_val, b_val, c_offset) = self.branch_rr(&instruction);
a_val.0[0] += F::one();
if a_val.0[0] != b_val.0[0] {
next_pc = self.pc + c_offset;
}
self.mw(self.fp + instruction.op_a, a_val, MemoryAccessPosition::A);
(a, b, c) = (a_val, b_val, Block::from(c_offset));
}
Opcode::EBEQ => {
let (a_val, b_val, c_offset) = self.branch_rr(&instruction);
(a, b, c) = (a_val, b_val, Block::from(c_offset));
Expand Down
1 change: 1 addition & 0 deletions recursion/core/src/runtime/opcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub enum Opcode {
HintLen = 37,
Hint = 38,
Poseidon2Compress = 39,
BNEINC = 40,
}

impl Opcode {
Expand Down
16 changes: 8 additions & 8 deletions recursion/program/src/fri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,19 @@ where
let index_pair = index_bits.shift(builder, i_plus_one);

let mut evals: Array<C, Ext<C::F, C::EF>> = builder.array(2);
builder.set(&mut evals, 0, folded_eval);
builder.set(&mut evals, 1, folded_eval);
builder.set(&mut evals, index_sibling_mod_2, step.sibling_value);
builder.set_value(&mut evals, 0, folded_eval);
builder.set_value(&mut evals, 1, folded_eval);
builder.set_value(&mut evals, index_sibling_mod_2, step.sibling_value);

let two: Var<C::N> = builder.eval(C::N::from_canonical_u32(2));
let dims = DimensionsVariable::<C> {
height: builder.exp(two, log_folded_height),
};
let mut dims_slice: Array<C, DimensionsVariable<C>> = builder.array(1);
builder.set(&mut dims_slice, 0, dims);
builder.set_value(&mut dims_slice, 0, dims);

let mut opened_values = builder.array(1);
builder.set(&mut opened_values, 0, evals.clone());
builder.set_value(&mut opened_values, 0, evals.clone());
verify_batch::<C, 4>(
builder,
&commit,
Expand All @@ -192,8 +192,8 @@ where

let mut xs: Array<C, Ext<C::F, C::EF>> = builder.array(2);
let two_adic_generator_one = config.get_two_adic_generator(builder, Usize::Const(1));
builder.set(&mut xs, 0, x);
builder.set(&mut xs, 1, x);
builder.set_value(&mut xs, 0, x);
builder.set_value(&mut xs, 1, x);
builder.set(&mut xs, index_sibling_mod_2, x * two_adic_generator_one);

let xs_0 = builder.get(&xs, 0);
Expand Down Expand Up @@ -306,7 +306,7 @@ pub fn reduce<C: Config, const D: usize>(
let opened_value_flat = builder.ext2felt(opened_value);
for k in 0..D {
let base = builder.get(&opened_value_flat, k);
builder.set(&mut flattened_opened_values, nb_opened_values, base);
builder.set_value(&mut flattened_opened_values, nb_opened_values, base);
builder.assign(nb_opened_values, nb_opened_values + C::N::one());
}
});
Expand Down
14 changes: 8 additions & 6 deletions recursion/program/src/fri/two_adic_pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@ pub fn verify_two_adic_pcs<C: Config>(

let mut ro: Array<C, Ext<C::F, C::EF>> = builder.array(32);
let mut alpha_pow: Array<C, Ext<C::F, C::EF>> = builder.array(32);
let zero_ef = builder.eval(C::EF::zero().cons());
for j in 0..32 {
builder.set(&mut ro, j, C::EF::zero().cons());
builder.set_value(&mut ro, j, zero_ef);
}
let one_ef = builder.eval(C::EF::one().cons());
for j in 0..32 {
builder.set(&mut alpha_pow, j, C::EF::one().cons());
builder.set_value(&mut alpha_pow, j, one_ef);
}

builder.range(0, rounds.len()).for_each(|j, builder| {
Expand All @@ -72,15 +74,15 @@ pub fn verify_two_adic_pcs<C: Config>(
builder.range(0, mats.len()).for_each(|k, builder| {
let mat = builder.get(&mats, k);
let height_log2: Var<_> = builder.eval(mat.domain.log_n + log_blowup);
builder.set(&mut batch_heights_log2, k, height_log2);
builder.set_value(&mut batch_heights_log2, k, height_log2);
});
let mut batch_dims: Array<C, DimensionsVariable<C>> = builder.array(mats.len());
builder.range(0, mats.len()).for_each(|k, builder| {
let mat = builder.get(&mats, k);
let dim = DimensionsVariable::<C> {
height: builder.eval(mat.domain.size() * blowup),
};
builder.set(&mut batch_dims, k, dim);
builder.set_value(&mut batch_dims, k, dim);
});

let log_batch_max_height = builder.get(&batch_heights_log2, 0);
Expand Down Expand Up @@ -136,7 +138,7 @@ pub fn verify_two_adic_pcs<C: Config>(
};

let mut input_ptr = builder.array::<FriFoldInput<_>>(1);
builder.set(&mut input_ptr, 0, input);
builder.set_value(&mut input_ptr, 0, input);

builder.range(0, ps_at_z.len()).for_each(|m, builder| {
builder.push(DslIR::FriFold(m, input_ptr.clone()));
Expand All @@ -145,7 +147,7 @@ pub fn verify_two_adic_pcs<C: Config>(
});
});

builder.set(&mut reduced_openings, i, ro);
builder.set_value(&mut reduced_openings, i, ro);
});

verify_challenges(
Expand Down

0 comments on commit 2671e46

Please sign in to comment.