Skip to content

Commit

Permalink
Merge pull request #2813 from o1-labs/o1vm/batch-inversion
Browse files Browse the repository at this point in the history
o1vm/mips: use batch_inversion for the witness generation
  • Loading branch information
dannywillems authored Nov 25, 2024
2 parents bc8456e + a11a350 commit 1f1ceb8
Show file tree
Hide file tree
Showing 12 changed files with 202 additions and 37 deletions.
43 changes: 36 additions & 7 deletions o1vm/src/interpreters/mips/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use kimchi_msm::{
use std::ops::{Index, IndexMut};
use strum::EnumCount;

use super::{ITypeInstruction, JTypeInstruction, RTypeInstruction};
pub use super::{
witness::SCRATCH_SIZE_INVERSE, ITypeInstruction, JTypeInstruction, RTypeInstruction,
};

/// The number of hashes performed so far in the block
pub(crate) const MIPS_HASH_COUNTER_OFF: usize = 80;
Expand All @@ -35,7 +37,7 @@ pub(crate) const MIPS_CHUNK_BYTES_LEN: usize = 4;
pub(crate) const MIPS_PREIMAGE_KEY: usize = 97;

/// The number of columns used for relation witness in the MIPS circuit
pub const N_MIPS_REL_COLS: usize = SCRATCH_SIZE + 2;
pub const N_MIPS_REL_COLS: usize = SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2;

/// The number of witness columns used to store the instruction selectors.
pub const N_MIPS_SEL_COLS: usize =
Expand All @@ -50,6 +52,9 @@ pub const N_MIPS_COLS: usize = N_MIPS_REL_COLS + N_MIPS_SEL_COLS;
pub enum ColumnAlias {
// Can be seen as the abstract indexed variable X_{i}
ScratchState(usize),
// A column whose value needs to be inverted in the final witness.
// We're keeping a separate column to perform a batch inversion at the end.
ScratchStateInverse(usize),
InstructionCounter,
Selector(usize),
}
Expand All @@ -66,8 +71,12 @@ impl From<ColumnAlias> for usize {
assert!(i < SCRATCH_SIZE);
i
}
ColumnAlias::InstructionCounter => SCRATCH_SIZE,
ColumnAlias::Selector(s) => SCRATCH_SIZE + 1 + s,
ColumnAlias::ScratchStateInverse(i) => {
assert!(i < SCRATCH_SIZE_INVERSE);
SCRATCH_SIZE + i
}
ColumnAlias::InstructionCounter => SCRATCH_SIZE + SCRATCH_SIZE_INVERSE,
ColumnAlias::Selector(s) => SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 1 + s,
}
}
}
Expand Down Expand Up @@ -132,16 +141,36 @@ impl<T: Clone> IndexMut<ColumnAlias> for MIPSWitness<T> {

impl ColumnIndexer for ColumnAlias {
const N_COL: usize = N_MIPS_COLS;

fn to_column(self) -> Column {
match self {
Self::ScratchState(ss) => {
assert!(ss < SCRATCH_SIZE);
assert!(
ss < SCRATCH_SIZE,
"The maximum index is {}, got {}",
SCRATCH_SIZE,
ss
);
Column::Relation(ss)
}
Self::InstructionCounter => Column::Relation(SCRATCH_SIZE),
Self::ScratchStateInverse(ss) => {
assert!(
ss < SCRATCH_SIZE_INVERSE,
"The maximum index is {}, got {}",
SCRATCH_SIZE_INVERSE,
ss
);
Column::Relation(SCRATCH_SIZE + ss)
}
Self::InstructionCounter => Column::Relation(SCRATCH_SIZE + SCRATCH_SIZE_INVERSE),
// TODO: what happens with error? It does not have a corresponding alias
Self::Selector(s) => {
assert!(s < N_MIPS_SEL_COLS);
assert!(
s < N_MIPS_SEL_COLS,
"The maximum index is {}, got {}",
N_MIPS_SEL_COLS,
s
);
Column::DynamicSelector(s)
}
}
Expand Down
11 changes: 10 additions & 1 deletion o1vm/src/interpreters/mips/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use super::column::N_MIPS_SEL_COLS;
/// The environment keeping the constraints between the different polynomials
pub struct Env<Fp> {
scratch_state_idx: usize,
scratch_state_idx_inverse: usize,
/// A list of constraints, which are multi-variate polynomials over a field,
/// represented using the expression framework of `kimchi`.
constraints: Vec<E<Fp>>,
Expand All @@ -37,6 +38,7 @@ impl<Fp: Field> Default for Env<Fp> {
fn default() -> Self {
Self {
scratch_state_idx: 0,
scratch_state_idx_inverse: 0,
constraints: Vec::new(),
lookups: Vec::new(),
selector: None,
Expand All @@ -62,6 +64,12 @@ impl<Fp: Field> InterpreterEnv for Env<Fp> {
MIPSColumn::ScratchState(scratch_idx)
}

fn alloc_scratch_inverse(&mut self) -> Self::Position {
let scratch_idx = self.scratch_state_idx_inverse;
self.scratch_state_idx_inverse += 1;
MIPSColumn::ScratchStateInverse(scratch_idx)
}

type Variable = E<Fp>;

fn variable(&self, column: Self::Position) -> Self::Variable {
Expand Down Expand Up @@ -219,7 +227,7 @@ impl<Fp: Field> InterpreterEnv for Env<Fp> {
unsafe { self.test_zero(x, pos) }
};
let x_inv_or_zero = {
let pos = self.alloc_scratch();
let pos = self.alloc_scratch_inverse();
self.variable(pos)
};
// If x = 0, then res = 1 and x_inv_or_zero = 0
Expand Down Expand Up @@ -623,6 +631,7 @@ impl<Fp: Field> InterpreterEnv for Env<Fp> {

fn reset(&mut self) {
self.scratch_state_idx = 0;
self.scratch_state_idx_inverse = 0;
self.constraints.clear();
self.lookups.clear();
self.selector = None;
Expand Down
2 changes: 2 additions & 0 deletions o1vm/src/interpreters/mips/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ pub trait InterpreterEnv {
/// [crate::interpreters::mips::witness::SCRATCH_SIZE]
fn alloc_scratch(&mut self) -> Self::Position;

fn alloc_scratch_inverse(&mut self) -> Self::Position;

type Variable: Clone
+ std::ops::Add<Self::Variable, Output = Self::Variable>
+ std::ops::Sub<Self::Variable, Output = Self::Variable>
Expand Down
1 change: 1 addition & 0 deletions o1vm/src/interpreters/mips/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ mod rtype {
// that condition would generate an infinite loop instead)
while dummy_env.registers.preimage_offset < total_length {
dummy_env.reset_scratch_state();
dummy_env.reset_scratch_state_inverse();

// Set maximum number of bytes to read in this call
dummy_env.registers[6] = rng.gen_range(1..=4);
Expand Down
4 changes: 4 additions & 0 deletions o1vm/src/interpreters/mips/tests_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use std::{fs, path::PathBuf};
// FIXME: we should parametrize the tests with different fields.
use ark_bn254::Fr as Fp;

use super::witness::SCRATCH_SIZE_INVERSE;

const PAGE_INDEX_EXECUTABLE_MEMORY: u32 = 1;

pub(crate) struct OnDiskPreImageOracle;
Expand Down Expand Up @@ -87,7 +89,9 @@ where
registers: Registers::default(),
registers_write_index: Registers::default(),
scratch_state_idx: 0,
scratch_state_idx_inverse: 0,
scratch_state: [Fp::from(0); SCRATCH_SIZE],
scratch_state_inverse: [Fp::from(0); SCRATCH_SIZE_INVERSE],
selector: crate::interpreters::mips::column::N_MIPS_SEL_COLS,
halt: false,
// Keccak related
Expand Down
54 changes: 37 additions & 17 deletions o1vm/src/interpreters/mips/witness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,17 @@ pub const NUM_INSTRUCTION_LOOKUP_TERMS: usize = 5;
pub const NUM_LOOKUP_TERMS: usize =
NUM_GLOBAL_LOOKUP_TERMS + NUM_DECODING_LOOKUP_TERMS + NUM_INSTRUCTION_LOOKUP_TERMS;
// TODO: Delete and use a vector instead
// FIXME: since the introduction of the scratch size inverse, the value below
// can be decreased. It implies to change the offsets defined in [column]. At
// the moment, it incurs an overhead we could avoid as some columns are zeroes.
// MIPS + hash_counter + byte_counter + eof + num_bytes_read + chunk + bytes
// + length + has_n_bytes + chunk_bytes + preimage
pub const SCRATCH_SIZE: usize = 98;

/// Number of columns used by the MIPS interpreter to keep values to be
/// inverted.
pub const SCRATCH_SIZE_INVERSE: usize = 12;

#[derive(Clone, Default)]
pub struct SyscallEnv {
pub last_hint: Option<Vec<u8>>,
Expand Down Expand Up @@ -81,7 +88,9 @@ pub struct Env<Fp, PreImageOracle: PreImageOracleT> {
pub registers: Registers<u32>,
pub registers_write_index: Registers<u64>,
pub scratch_state_idx: usize,
pub scratch_state_idx_inverse: usize,
pub scratch_state: [Fp; SCRATCH_SIZE],
pub scratch_state_inverse: [Fp; SCRATCH_SIZE_INVERSE],
pub halt: bool,
pub syscall_env: SyscallEnv,
pub selector: usize,
Expand All @@ -106,6 +115,12 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreI
Column::ScratchState(scratch_idx)
}

fn alloc_scratch_inverse(&mut self) -> Self::Position {
let scratch_idx = self.scratch_state_idx_inverse;
self.scratch_state_idx_inverse += 1;
Column::ScratchStateInverse(scratch_idx)
}

type Variable = u64;

fn variable(&self, _column: Self::Position) -> Self::Variable {
Expand Down Expand Up @@ -314,17 +329,17 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreI

fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable {
// write the result
let pos = self.alloc_scratch();
let res = if *x == 0 { 1 } else { 0 };
self.write_column(pos, res);
let res = {
let pos = self.alloc_scratch();
unsafe { self.test_zero(x, pos) }
};
// write the non deterministic advice inv_or_zero
let pos = self.alloc_scratch();
let inv_or_zero = if *x == 0 {
Fp::zero()
let pos = self.alloc_scratch_inverse();
if *x == 0 {
self.write_field_column(pos, Fp::zero());
} else {
Fp::inverse(&Fp::from(*x)).unwrap()
self.write_field_column(pos, Fp::from(*x));
};
self.write_field_column(pos, inv_or_zero);
// return the result
res
}
Expand All @@ -339,15 +354,11 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> InterpreterEnv for Env<Fp, PreI
self.write_column(pos, is_zero);
is_zero
};
let _to_zero_test_inv_or_zero = {
let pos = self.alloc_scratch();
let inv_or_zero = if to_zero_test == Fp::zero() {
Fp::zero()
} else {
Fp::inverse(&to_zero_test).unwrap()
};
self.write_field_column(pos, inv_or_zero);
1 // Placeholder value
let pos = self.alloc_scratch_inverse();
if to_zero_test == Fp::zero() {
self.write_field_column(pos, Fp::zero());
} else {
self.write_field_column(pos, to_zero_test);
};
res
}
Expand Down Expand Up @@ -878,7 +889,9 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> Env<Fp, PreImageOracle> {
registers: initial_registers.clone(),
registers_write_index: Registers::default(),
scratch_state_idx: 0,
scratch_state_idx_inverse: 0,
scratch_state: fresh_scratch_state(),
scratch_state_inverse: fresh_scratch_state(),
halt: state.exited,
syscall_env,
selector,
Expand All @@ -897,13 +910,19 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> Env<Fp, PreImageOracle> {
self.selector = N_MIPS_SEL_COLS;
}

pub fn reset_scratch_state_inverse(&mut self) {
self.scratch_state_idx_inverse = 0;
self.scratch_state_inverse = fresh_scratch_state();
}

pub fn write_column(&mut self, column: Column, value: u64) {
self.write_field_column(column, value.into())
}

pub fn write_field_column(&mut self, column: Column, value: Fp) {
match column {
Column::ScratchState(idx) => self.scratch_state[idx] = value,
Column::ScratchStateInverse(idx) => self.scratch_state_inverse[idx] = value,
Column::InstructionCounter => panic!("Cannot overwrite the column {:?}", column),
Column::Selector(s) => self.selector = s,
}
Expand Down Expand Up @@ -1138,6 +1157,7 @@ impl<Fp: Field, PreImageOracle: PreImageOracleT> Env<Fp, PreImageOracle> {
start: &Start,
) -> Instruction {
self.reset_scratch_state();
self.reset_scratch_state_inverse();
let (opcode, _instruction) = self.decode_instruction();

self.pp_info(&config.info_at, metadata, start);
Expand Down
28 changes: 20 additions & 8 deletions o1vm/src/pickles/column_env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use ark_poly::{Evaluations, Radix2EvaluationDomain};
use kimchi_msm::columns::Column;

use crate::{
interpreters::mips::{column::N_MIPS_SEL_COLS, witness::SCRATCH_SIZE},
interpreters::mips::{
column::N_MIPS_SEL_COLS,
witness::{SCRATCH_SIZE, SCRATCH_SIZE_INVERSE},
},
pickles::proof::WitnessColumns,
};
use kimchi::circuits::{
Expand Down Expand Up @@ -36,8 +39,9 @@ pub struct ColumnEnvironment<'a, F: FftField> {
}

pub fn get_all_columns() -> Vec<Column> {
let mut cols = Vec::<Column>::with_capacity(SCRATCH_SIZE + 2 + N_MIPS_SEL_COLS);
for i in 0..SCRATCH_SIZE + 2 {
let mut cols =
Vec::<Column>::with_capacity(SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2 + N_MIPS_SEL_COLS);
for i in 0..SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2 {
cols.push(Column::Relation(i));
}
for i in 0..N_MIPS_SEL_COLS {
Expand All @@ -53,26 +57,34 @@ impl<G> WitnessColumns<G, [G; N_MIPS_SEL_COLS]> {
if i < SCRATCH_SIZE {
let res = &self.scratch[i];
Some(res)
} else if i == SCRATCH_SIZE {
} else if i < SCRATCH_SIZE + SCRATCH_SIZE_INVERSE {
let res = &self.scratch_inverse[i - SCRATCH_SIZE];
Some(res)
} else if i == SCRATCH_SIZE + SCRATCH_SIZE_INVERSE {
let res = &self.instruction_counter;
Some(res)
} else if i == SCRATCH_SIZE + 1 {
} else if i == SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 1 {
let res = &self.error;
Some(res)
} else {
panic!("We should not have that many relation columns");
panic!("We should not have that many relation columns. We have {} columns and index {} was given", SCRATCH_SIZE + SCRATCH_SIZE_INVERSE + 2, i);
}
}
Column::DynamicSelector(i) => {
assert!(
i < N_MIPS_SEL_COLS,
"We do not have that many dynamic selector columns"
"We do not have that many dynamic selector columns. We have {} columns and index {} was given",
N_MIPS_SEL_COLS,
i
);
let res = &self.selector[i];
Some(res)
}
_ => {
panic!("We should not have any other type of columns")
panic!(
"We should not have any other type of columns. The column {:?} was given",
col
);
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions o1vm/src/pickles/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ pub fn main() -> ExitCode {
{
scratch_chunk.push(*scratch);
}
for (scratch, scratch_chunk) in mips_wit_env
.scratch_state_inverse
.iter()
.zip(curr_proof_inputs.evaluations.scratch_inverse.iter_mut())
{
scratch_chunk.push(*scratch);
}
curr_proof_inputs
.evaluations
.instruction_counter
Expand Down
7 changes: 6 additions & 1 deletion o1vm/src/pickles/proof.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use kimchi::{curve::KimchiCurve, proof::PointEvaluations};
use poly_commitment::{ipa::OpeningProof, PolyComm};

use crate::interpreters::mips::{column::N_MIPS_SEL_COLS, witness::SCRATCH_SIZE};
use crate::interpreters::mips::{
column::N_MIPS_SEL_COLS,
witness::{SCRATCH_SIZE, SCRATCH_SIZE_INVERSE},
};

pub struct WitnessColumns<G, S> {
pub scratch: [G; SCRATCH_SIZE],
pub scratch_inverse: [G; SCRATCH_SIZE_INVERSE],
pub instruction_counter: G,
pub error: G,
pub selector: S,
Expand All @@ -19,6 +23,7 @@ impl<G: KimchiCurve> ProofInputs<G> {
ProofInputs {
evaluations: WitnessColumns {
scratch: std::array::from_fn(|_| Vec::with_capacity(domain_size)),
scratch_inverse: std::array::from_fn(|_| Vec::with_capacity(domain_size)),
instruction_counter: Vec::with_capacity(domain_size),
error: Vec::with_capacity(domain_size),
selector: Vec::with_capacity(domain_size),
Expand Down
Loading

0 comments on commit 1f1ceb8

Please sign in to comment.