diff --git a/crates/brainfuck_prover/src/components/instruction/table.rs b/crates/brainfuck_prover/src/components/instruction/table.rs index 6b05180..dd38fc5 100644 --- a/crates/brainfuck_prover/src/components/instruction/table.rs +++ b/crates/brainfuck_prover/src/components/instruction/table.rs @@ -1,9 +1,9 @@ -use brainfuck_vm::{machine::ProgramMemory, registers::Registers}; +use brainfuck_vm::{ + instruction::VALID_INSTRUCTIONS_BF, machine::ProgramMemory, registers::Registers, +}; use num_traits::Zero; use stwo_prover::core::fields::m31::BaseField; -use crate::utils::VALID_INSTRUCTIONS; - /// Represents a single row in the Instruction Table. /// /// The Instruction Table stores: @@ -94,7 +94,7 @@ impl From<(Vec, &ProgramMemory)> for InstructionTable { let code = program_memory.code(); for (index, ci) in code.iter().enumerate() { - if !VALID_INSTRUCTIONS.contains(ci) { + if !VALID_INSTRUCTIONS_BF.contains(ci) { continue; } @@ -128,12 +128,9 @@ impl From<(Vec, &ProgramMemory)> for InstructionTable { #[cfg(test)] mod tests { use super::*; - use crate::utils::{ - DECREMENT_INSTRUCTION_BF, INCREMENT_INSTRUCTION_BF, INPUT_INSTRUCTION_BF, - JUMP_IF_NON_ZERO_INSTRUCTION_BF, JUMP_IF_ZERO_INSTRUCTION_BF, OUTPUT_INSTRUCTION_BF, - SHIFT_LEFT_INSTRUCTION_BF, SHIFT_RIGHT_INSTRUCTION_BF, + use brainfuck_vm::{ + compiler::Compiler, instruction::InstructionType, test_helper::create_test_machine, }; - use brainfuck_vm::{compiler::Compiler, test_helper::create_test_machine}; use num_traits::{One, Zero}; #[test] @@ -218,60 +215,60 @@ mod tests { // Create the expected `InstructionTable` let ins_0 = InstructionTableRow { ip: BaseField::zero(), - ci: INCREMENT_INSTRUCTION_BF, - ni: SHIFT_RIGHT_INSTRUCTION_BF, + ci: InstructionType::Plus.to_base_field(), + ni: InstructionType::Right.to_base_field(), }; let ins_1 = InstructionTableRow { ip: BaseField::one(), - ci: SHIFT_RIGHT_INSTRUCTION_BF, - ni: INPUT_INSTRUCTION_BF, + ci: InstructionType::Right.to_base_field(), + ni: InstructionType::ReadChar.to_base_field(), }; let ins_2 = InstructionTableRow { ip: BaseField::from(2), - ci: INPUT_INSTRUCTION_BF, - ni: SHIFT_LEFT_INSTRUCTION_BF, + ci: InstructionType::ReadChar.to_base_field(), + ni: InstructionType::Left.to_base_field(), }; let ins_3 = InstructionTableRow { ip: BaseField::from(3), - ci: SHIFT_LEFT_INSTRUCTION_BF, - ni: JUMP_IF_ZERO_INSTRUCTION_BF, + ci: InstructionType::Left.to_base_field(), + ni: InstructionType::JumpIfZero.to_base_field(), }; let ins_4 = InstructionTableRow { ip: BaseField::from(4), - ci: JUMP_IF_ZERO_INSTRUCTION_BF, + ci: InstructionType::JumpIfZero.to_base_field(), ni: BaseField::from(12), }; let ins_6 = InstructionTableRow { ip: BaseField::from(6), - ci: SHIFT_RIGHT_INSTRUCTION_BF, - ni: INCREMENT_INSTRUCTION_BF, + ci: InstructionType::Right.to_base_field(), + ni: InstructionType::Plus.to_base_field(), }; let ins_7 = InstructionTableRow { ip: BaseField::from(7), - ci: INCREMENT_INSTRUCTION_BF, - ni: OUTPUT_INSTRUCTION_BF, + ci: InstructionType::Plus.to_base_field(), + ni: InstructionType::PutChar.to_base_field(), }; let ins_8 = InstructionTableRow { ip: BaseField::from(8), - ci: OUTPUT_INSTRUCTION_BF, - ni: SHIFT_LEFT_INSTRUCTION_BF, + ci: InstructionType::PutChar.to_base_field(), + ni: InstructionType::Left.to_base_field(), }; let ins_9 = InstructionTableRow { ip: BaseField::from(9), - ci: SHIFT_LEFT_INSTRUCTION_BF, - ni: DECREMENT_INSTRUCTION_BF, + ci: InstructionType::Left.to_base_field(), + ni: InstructionType::Minus.to_base_field(), }; let inst_10 = InstructionTableRow { ip: BaseField::from(10), - ci: DECREMENT_INSTRUCTION_BF, - ni: JUMP_IF_NON_ZERO_INSTRUCTION_BF, + ci: InstructionType::Minus.to_base_field(), + ni: InstructionType::JumpIfNotZero.to_base_field(), }; let ins_11 = InstructionTableRow { ip: BaseField::from(11), - ci: JUMP_IF_NON_ZERO_INSTRUCTION_BF, + ci: InstructionType::JumpIfNotZero.to_base_field(), ni: BaseField::from(6), }; @@ -336,19 +333,19 @@ mod tests { let ins_0 = InstructionTableRow { ip: BaseField::zero(), - ci: JUMP_IF_ZERO_INSTRUCTION_BF, + ci: InstructionType::JumpIfZero.to_base_field(), ni: BaseField::from(4), }; let ins_2 = InstructionTableRow { ip: BaseField::from(2), - ci: DECREMENT_INSTRUCTION_BF, - ni: JUMP_IF_NON_ZERO_INSTRUCTION_BF, + ci: InstructionType::Minus.to_base_field(), + ni: InstructionType::JumpIfNotZero.to_base_field(), }; let ins_3 = InstructionTableRow { ip: BaseField::from(3), - ci: JUMP_IF_NON_ZERO_INSTRUCTION_BF, + ci: InstructionType::JumpIfNotZero.to_base_field(), ni: BaseField::from(2), }; diff --git a/crates/brainfuck_prover/src/components/io/table.rs b/crates/brainfuck_prover/src/components/io/table.rs index a152b0a..f680e05 100644 --- a/crates/brainfuck_prover/src/components/io/table.rs +++ b/crates/brainfuck_prover/src/components/io/table.rs @@ -1,5 +1,4 @@ -use crate::utils::{INPUT_INSTRUCTION, OUTPUT_INSTRUCTION}; -use brainfuck_vm::registers::Registers; +use brainfuck_vm::{instruction::InstructionType, registers::Registers}; use stwo_prover::core::fields::m31::BaseField; /// Represents a single row in the I/O Table. @@ -112,13 +111,13 @@ impl From> for IOTable { /// /// This table is made of the memory values (`mv` register) corresponding to /// inputs (when the current instruction `ci` equals ','). -pub type InputTable = IOTable; +pub type InputTable = IOTable<{ InstructionType::ReadChar.to_u32() }>; /// Output table (trace) for the Output component. /// /// This table is made of the memory values (`mv` register) corresponding to /// outputs (when the current instruction `ci` equals '.'). -pub type OutputTable = IOTable; +pub type OutputTable = IOTable<{ InstructionType::PutChar.to_u32() }>; #[cfg(test)] mod tests { @@ -206,12 +205,12 @@ mod tests { let reg1 = Registers::default(); let reg2 = Registers { mv: BaseField::one(), - ci: BaseField::from(INPUT_INSTRUCTION), + ci: BaseField::from(InstructionType::ReadChar.to_base_field()), ..Default::default() }; let reg3 = Registers { mv: BaseField::from(5), - ci: BaseField::from(OUTPUT_INSTRUCTION), + ci: BaseField::from(InstructionType::PutChar.to_base_field()), ..Default::default() }; let registers: Vec = vec![reg3, reg1, reg2]; @@ -230,12 +229,12 @@ mod tests { let reg1 = Registers::default(); let reg2 = Registers { mv: BaseField::one(), - ci: BaseField::from(INPUT_INSTRUCTION), + ci: BaseField::from(InstructionType::ReadChar.to_base_field()), ..Default::default() }; let reg3 = Registers { mv: BaseField::from(5), - ci: BaseField::from(OUTPUT_INSTRUCTION), + ci: BaseField::from(InstructionType::PutChar.to_base_field()), ..Default::default() }; let registers: Vec = vec![reg3, reg1, reg2]; diff --git a/crates/brainfuck_prover/src/lib.rs b/crates/brainfuck_prover/src/lib.rs index d75d437..8913003 100644 --- a/crates/brainfuck_prover/src/lib.rs +++ b/crates/brainfuck_prover/src/lib.rs @@ -1,3 +1,2 @@ pub mod brainfuck_air; pub mod components; -pub mod utils; diff --git a/crates/brainfuck_prover/src/utils/mod.rs b/crates/brainfuck_prover/src/utils/mod.rs deleted file mode 100644 index 3c333b6..0000000 --- a/crates/brainfuck_prover/src/utils/mod.rs +++ /dev/null @@ -1,36 +0,0 @@ -use stwo_prover::core::fields::m31::BaseField; - -pub const SHIFT_RIGHT_INSTRUCTION: u32 = b'>' as u32; -pub const SHIFT_LEFT_INSTRUCTION: u32 = b'<' as u32; -pub const INCREMENT_INSTRUCTION: u32 = b'+' as u32; -pub const DECREMENT_INSTRUCTION: u32 = b'-' as u32; -pub const INPUT_INSTRUCTION: u32 = b',' as u32; -pub const OUTPUT_INSTRUCTION: u32 = b'.' as u32; -pub const JUMP_IF_ZERO_INSTRUCTION: u32 = b'[' as u32; -pub const JUMP_IF_NON_ZERO_INSTRUCTION: u32 = b']' as u32; - -pub const SHIFT_RIGHT_INSTRUCTION_BF: BaseField = - BaseField::from_u32_unchecked(SHIFT_RIGHT_INSTRUCTION); -pub const SHIFT_LEFT_INSTRUCTION_BF: BaseField = - BaseField::from_u32_unchecked(SHIFT_LEFT_INSTRUCTION); -pub const INCREMENT_INSTRUCTION_BF: BaseField = - BaseField::from_u32_unchecked(INCREMENT_INSTRUCTION); -pub const DECREMENT_INSTRUCTION_BF: BaseField = - BaseField::from_u32_unchecked(DECREMENT_INSTRUCTION); -pub const INPUT_INSTRUCTION_BF: BaseField = BaseField::from_u32_unchecked(INPUT_INSTRUCTION); -pub const OUTPUT_INSTRUCTION_BF: BaseField = BaseField::from_u32_unchecked(OUTPUT_INSTRUCTION); -pub const JUMP_IF_ZERO_INSTRUCTION_BF: BaseField = - BaseField::from_u32_unchecked(JUMP_IF_ZERO_INSTRUCTION); -pub const JUMP_IF_NON_ZERO_INSTRUCTION_BF: BaseField = - BaseField::from_u32_unchecked(JUMP_IF_NON_ZERO_INSTRUCTION); - -pub const VALID_INSTRUCTIONS: [BaseField; 8] = [ - SHIFT_RIGHT_INSTRUCTION_BF, - SHIFT_LEFT_INSTRUCTION_BF, - INCREMENT_INSTRUCTION_BF, - DECREMENT_INSTRUCTION_BF, - INPUT_INSTRUCTION_BF, - OUTPUT_INSTRUCTION_BF, - JUMP_IF_ZERO_INSTRUCTION_BF, - JUMP_IF_NON_ZERO_INSTRUCTION_BF, -]; diff --git a/crates/brainfuck_vm/src/instruction.rs b/crates/brainfuck_vm/src/instruction.rs index b1457f9..54a1baa 100644 --- a/crates/brainfuck_vm/src/instruction.rs +++ b/crates/brainfuck_vm/src/instruction.rs @@ -1,6 +1,7 @@ // Taken from rkdud007 brainfuck-zkvm https://github.com/rkdud007/brainfuck-zkvm/blob/main/src/instruction.rs use std::{fmt::Display, str::FromStr}; +use stwo_prover::core::fields::m31::BaseField; use thiserror::Error; /// Custom error type for instructions @@ -59,6 +60,39 @@ impl FromStr for InstructionType { } } +impl InstructionType { + /// Convert an [`InstructionType`] to its corresponding u32 representation + pub const fn to_u32(&self) -> u32 { + match self { + Self::Right => b'>' as u32, + Self::Left => b'<' as u32, + Self::Plus => b'+' as u32, + Self::Minus => b'-' as u32, + Self::PutChar => b'.' as u32, + Self::ReadChar => b',' as u32, + Self::JumpIfZero => b'[' as u32, + Self::JumpIfNotZero => b']' as u32, + } + } + + /// Convert an [`InstructionType`] to a [`BaseField`] + pub const fn to_base_field(&self) -> BaseField { + BaseField::from_u32_unchecked(self.to_u32()) + } +} + +/// Define all valid instructions as [`BaseField`] values +pub const VALID_INSTRUCTIONS_BF: [BaseField; 8] = [ + InstructionType::Right.to_base_field(), + InstructionType::Left.to_base_field(), + InstructionType::Plus.to_base_field(), + InstructionType::Minus.to_base_field(), + InstructionType::PutChar.to_base_field(), + InstructionType::ReadChar.to_base_field(), + InstructionType::JumpIfZero.to_base_field(), + InstructionType::JumpIfNotZero.to_base_field(), +]; + impl Display for InstructionType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let symbol = match self { @@ -143,6 +177,54 @@ mod tests { assert_eq!(result, Err(InstructionError::Conversion('x'))); } + #[test] + fn test_instruction_type_to_u32() { + assert_eq!(InstructionType::Right.to_u32(), b'>'.into()); + assert_eq!(InstructionType::Left.to_u32(), b'<'.into()); + assert_eq!(InstructionType::Plus.to_u32(), b'+'.into()); + assert_eq!(InstructionType::Minus.to_u32(), b'-'.into()); + assert_eq!(InstructionType::PutChar.to_u32(), b'.'.into()); + assert_eq!(InstructionType::ReadChar.to_u32(), b','.into()); + assert_eq!(InstructionType::JumpIfZero.to_u32(), b'['.into()); + assert_eq!(InstructionType::JumpIfNotZero.to_u32(), b']'.into()); + } + + #[test] + fn test_instruction_type_to_base_field() { + assert_eq!( + InstructionType::Right.to_base_field(), + BaseField::from_u32_unchecked(b'>'.into()) + ); + assert_eq!( + InstructionType::Left.to_base_field(), + BaseField::from_u32_unchecked(b'<'.into()) + ); + assert_eq!( + InstructionType::Plus.to_base_field(), + BaseField::from_u32_unchecked(b'+'.into()) + ); + assert_eq!( + InstructionType::Minus.to_base_field(), + BaseField::from_u32_unchecked(b'-'.into()) + ); + assert_eq!( + InstructionType::PutChar.to_base_field(), + BaseField::from_u32_unchecked(b'.'.into()) + ); + assert_eq!( + InstructionType::ReadChar.to_base_field(), + BaseField::from_u32_unchecked(b','.into()) + ); + assert_eq!( + InstructionType::JumpIfZero.to_base_field(), + BaseField::from_u32_unchecked(b'['.into()) + ); + assert_eq!( + InstructionType::JumpIfNotZero.to_base_field(), + BaseField::from_u32_unchecked(b']'.into()) + ); + } + // Test Instruction struct creation #[test] fn test_instruction_creation() {