From 5a57be728165fdb5954a86dd326a8f752b83e3c4 Mon Sep 17 00:00:00 2001
From: "Chris T."
Date: Wed, 30 Oct 2024 10:16:36 -0700
Subject: [PATCH] feat: executor optimizations (#1712)

---
 Cargo.lock                                     |   1 +
 crates/core/executor/src/dependencies.rs       | 119 +++--
 crates/core/executor/src/disassembler/rrs.rs   |  57 +-
 crates/core/executor/src/events/alu.rs         |   4 +-
 crates/core/executor/src/events/byte.rs        |   2 +-
 crates/core/executor/src/events/cpu.rs         |   8 -
 crates/core/executor/src/events/memory.rs      |   8 +-
 crates/core/executor/src/events/utils.rs       |  37 +-
 crates/core/executor/src/executor.rs           | 486 ++++++++----------
 crates/core/executor/src/instruction.rs        |   4 +-
 crates/core/executor/src/memory.rs             |  27 +-
 crates/core/executor/src/opcode.rs             |   2 +-
 crates/core/executor/src/program.rs            |   7 +
 crates/core/executor/src/record.rs             |  71 ++-
 crates/core/executor/src/register.rs           |   4 +-
 .../src/syscalls/precompiles/sha256/extend.rs  |  10 +-
 crates/core/machine/Cargo.toml                 |   1 +
 crates/core/machine/src/alu/add_sub/mod.rs     |  66 +--
 crates/core/machine/src/alu/divrem/mod.rs      |  28 +-
 crates/core/machine/src/alu/lt/mod.rs          |  51 +-
 crates/core/machine/src/alu/mul/mod.rs         | 283 +++++-----
 crates/core/machine/src/alu/sr/mod.rs          |  73 +--
 .../machine/src/cpu/columns/instruction.rs     |   6 +-
 crates/core/machine/src/cpu/columns/opcode.rs  |   2 +-
 crates/core/machine/src/cpu/trace.rs           | 213 ++++----
 crates/core/machine/src/memory/local.rs        |  76 +--
 crates/core/machine/src/memory/program.rs      |  61 ++-
 crates/core/machine/src/program/mod.rs         |  58 ++-
 crates/core/machine/src/runtime/utils.rs       |   8 +-
 .../precompiles/edwards/ed_decompress.rs       |   2 +-
 crates/prover/src/lib.rs                       |   8 +-
 crates/prover/src/shapes.rs                    |  21 +-
 .../recursion/circuit/src/machine/deferred.rs  |   2 +-
 .../compiler/src/circuit/compiler.rs           |   5 +-
 .../core/src/chips/poseidon2_skinny/trace.rs   |   4 +
 .../core/src/chips/poseidon2_wide/trace.rs     |   4 +
 crates/recursion/core/src/machine.rs           |  16 +-
 crates/recursion/core/src/runtime/memory.rs    |   2 +-
 crates/recursion/core/src/runtime/mod.rs       |  16 +
 39 files changed, 924 insertions(+), 929 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 22a73f10b5..748f845ec7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -6215,6 +6215,7 @@ dependencies = [
  "tracing-forest",
  "tracing-subscriber",
  "typenum",
+ "vec_map",
  "web-time",
 ]

diff --git a/crates/core/executor/src/dependencies.rs b/crates/core/executor/src/dependencies.rs
index 194d8d0eb2..6b0bbbe33c 100644
--- a/crates/core/executor/src/dependencies.rs
+++ b/crates/core/executor/src/dependencies.rs
@@ -1,5 +1,5 @@
 use crate::{
-    events::{create_alu_lookups, AluEvent, CpuEvent},
+    events::AluEvent,
     utils::{get_msb, get_quotient_and_remainder, is_signed_operation},
     Executor, Opcode,
 };
@@ -7,6 +7,7 @@ use crate::{
 /// Emits the dependencies for division and remainder operations.
 #[allow(clippy::too_many_lines)]
 pub fn emit_divrem_dependencies(executor: &mut Executor, event: AluEvent) {
+    let shard = executor.shard();
     let (quotient, remainder) = get_quotient_and_remainder(event.b, event.c, event.opcode);
     let c_msb = get_msb(event.c);
     let rem_msb = get_msb(remainder);
@@ -19,27 +20,29 @@ pub fn emit_divrem_dependencies(executor: &mut Executor, event: AluEvent) {
     }

     if c_neg == 1 {
+        let ids = executor.record.create_lookup_ids();
         executor.record.add_events.push(AluEvent {
             lookup_id: event.sub_lookups[4],
-            shard: event.shard,
+            shard,
             clk: event.clk,
             opcode: Opcode::ADD,
             a: 0,
             b: event.c,
             c: (event.c as i32).unsigned_abs(),
-            sub_lookups: create_alu_lookups(),
+            sub_lookups: ids,
         });
     }
     if rem_neg == 1 {
+        let ids = executor.record.create_lookup_ids();
         executor.record.add_events.push(AluEvent {
             lookup_id: event.sub_lookups[5],
-            shard: event.shard,
+            shard,
             clk: event.clk,
             opcode: Opcode::ADD,
             a: 0,
             b: remainder,
             c: (remainder as i32).unsigned_abs(),
-            sub_lookups: create_alu_lookups(),
+            sub_lookups: ids,
         });
     }

@@ -55,19 +58,19 @@ pub fn emit_divrem_dependencies(executor: &mut Executor, event: AluEvent) {

     let lower_multiplication = AluEvent {
         lookup_id: event.sub_lookups[0],
-        shard: event.shard,
+        shard,
         clk: event.clk,
         opcode: Opcode::MUL,
         a: lower_word,
         c: event.c,
         b: quotient,
-        sub_lookups: create_alu_lookups(),
+        sub_lookups: executor.record.create_lookup_ids(),
     };
     executor.record.mul_events.push(lower_multiplication);

     let upper_multiplication = AluEvent {
         lookup_id: event.sub_lookups[1],
-        shard: event.shard,
+        shard,
         clk: event.clk,
         opcode: {
             if is_signed_operation {
@@ -79,31 +82,31 @@ pub fn emit_divrem_dependencies(executor: &mut Executor, event: AluEvent) {
         a: upper_word,
         c: event.c,
         b: quotient,
-        sub_lookups: create_alu_lookups(),
+        sub_lookups: executor.record.create_lookup_ids(),
     };
     executor.record.mul_events.push(upper_multiplication);

     let lt_event = if is_signed_operation {
         AluEvent {
             lookup_id: event.sub_lookups[2],
-            shard: event.shard,
+            shard,
             opcode: Opcode::SLTU,
             a: 1,
             b: (remainder as i32).unsigned_abs(),
             c: u32::max(1, (event.c as i32).unsigned_abs()),
             clk: event.clk,
-            sub_lookups: create_alu_lookups(),
+            sub_lookups: executor.record.create_lookup_ids(),
         }
     } else {
         AluEvent {
             lookup_id: event.sub_lookups[3],
-            shard: event.shard,
+            shard,
             opcode: Opcode::SLTU,
             a: 1,
             b: remainder,
             c: u32::max(1, event.c),
             clk: event.clk,
-            sub_lookups: create_alu_lookups(),
+            sub_lookups: executor.record.create_lookup_ids(),
         }
     };

@@ -114,9 +117,12 @@ pub fn emit_divrem_dependencies(executor: &mut Executor, event: AluEvent) {

 /// Emit the dependencies for CPU events.
 #[allow(clippy::too_many_lines)]
-pub fn emit_cpu_dependencies(executor: &mut Executor, event: &CpuEvent) {
+pub fn emit_cpu_dependencies(executor: &mut Executor, index: usize) {
+    let event = executor.record.cpu_events[index];
+    let shard = executor.shard();
+    let instruction = &executor.program.fetch(event.pc);
     if matches!(
-        event.instruction.opcode,
+        instruction.opcode,
         Opcode::LB
             | Opcode::LH
             | Opcode::LW
@@ -130,58 +136,57 @@ pub fn emit_cpu_dependencies(executor: &mut Executor, event: &CpuEvent) {
         // Add event to ALU check to check that addr == b + c
         let add_event = AluEvent {
             lookup_id: event.memory_add_lookup_id,
-            shard: event.shard,
+            shard,
             clk: event.clk,
             opcode: Opcode::ADD,
             a: memory_addr,
             b: event.b,
             c: event.c,
-            sub_lookups: create_alu_lookups(),
+            sub_lookups: executor.record.create_lookup_ids(),
         };
         executor.record.add_events.push(add_event);
         let addr_offset = (memory_addr % 4_u32) as u8;
         let mem_value = event.memory_record.unwrap().value();

-        if matches!(event.instruction.opcode, Opcode::LB | Opcode::LH) {
-            let (unsigned_mem_val, most_sig_mem_value_byte, sign_value) =
-                match event.instruction.opcode {
-                    Opcode::LB => {
-                        let most_sig_mem_value_byte = mem_value.to_le_bytes()[addr_offset as usize];
-                        let sign_value = 256;
-                        (most_sig_mem_value_byte as u32, most_sig_mem_value_byte, sign_value)
-                    }
-                    Opcode::LH => {
-                        let sign_value = 65536;
-                        let unsigned_mem_val = match (addr_offset >> 1) % 2 {
-                            0 => mem_value & 0x0000FFFF,
-                            1 => (mem_value & 0xFFFF0000) >> 16,
-                            _ => unreachable!(),
-                        };
-                        let most_sig_mem_value_byte = unsigned_mem_val.to_le_bytes()[1];
-                        (unsigned_mem_val, most_sig_mem_value_byte, sign_value)
-                    }
-                    _ => unreachable!(),
-                };
+        if matches!(instruction.opcode, Opcode::LB | Opcode::LH) {
+            let (unsigned_mem_val, most_sig_mem_value_byte, sign_value) = match instruction.opcode {
+                Opcode::LB => {
+                    let most_sig_mem_value_byte = mem_value.to_le_bytes()[addr_offset as usize];
+                    let sign_value = 256;
+                    (most_sig_mem_value_byte as u32, most_sig_mem_value_byte, sign_value)
+                }
+                Opcode::LH => {
+                    let sign_value = 65536;
+                    let unsigned_mem_val = match (addr_offset >> 1) % 2 {
+                        0 => mem_value & 0x0000FFFF,
+                        1 => (mem_value & 0xFFFF0000) >> 16,
+                        _ => unreachable!(),
+                    };
+                    let most_sig_mem_value_byte = unsigned_mem_val.to_le_bytes()[1];
+                    (unsigned_mem_val, most_sig_mem_value_byte, sign_value)
+                }
+                _ => unreachable!(),
+            };

             if most_sig_mem_value_byte >> 7 & 0x01 == 1 {
                 let sub_event = AluEvent {
                     lookup_id: event.memory_sub_lookup_id,
-                    shard: event.shard,
+                    shard,
                     clk: event.clk,
                     opcode: Opcode::SUB,
                     a: event.a,
                     b: unsigned_mem_val,
                     c: sign_value,
-                    sub_lookups: create_alu_lookups(),
+                    sub_lookups: executor.record.create_lookup_ids(),
                 };
                 executor.record.add_events.push(sub_event);
             }
         }
     }

-    if event.instruction.is_branch_instruction() {
+    if instruction.is_branch_instruction() {
         let a_eq_b = event.a == event.b;
-        let use_signed_comparison = matches!(event.instruction.opcode, Opcode::BLT | Opcode::BGE);
+        let use_signed_comparison = matches!(instruction.opcode, Opcode::BLT | Opcode::BGE);
         let a_lt_b = if use_signed_comparison {
             (event.a as i32) < (event.b as i32)
         } else {
@@ -197,27 +202,27 @@ pub fn emit_cpu_dependencies(executor: &mut Executor, event: &CpuEvent) {
         // Add the ALU events for the comparisons
         let lt_comp_event = AluEvent {
             lookup_id: event.branch_lt_lookup_id,
-            shard: event.shard,
+            shard,
             clk: event.clk,
             opcode: alu_op_code,
             a: a_lt_b as u32,
             b: event.a,
             c: event.b,
-            sub_lookups: create_alu_lookups(),
+            sub_lookups: executor.record.create_lookup_ids(),
         };
         let gt_comp_event = AluEvent {
             lookup_id: event.branch_gt_lookup_id,
-            shard: event.shard,
+            shard,
             clk: event.clk,
             opcode: alu_op_code,
             a: a_gt_b as u32,
             b: event.b,
             c: event.a,
-            sub_lookups: create_alu_lookups(),
+            sub_lookups: executor.record.create_lookup_ids(),
         };
         executor.record.lt_events.push(lt_comp_event);
         executor.record.lt_events.push(gt_comp_event);
-        let branching = match event.instruction.opcode {
+        let branching = match instruction.opcode {
             Opcode::BEQ => a_eq_b,
             Opcode::BNE => !a_eq_b,
             Opcode::BLT | Opcode::BLTU => a_lt_b,
@@ -228,31 +233,31 @@ pub fn emit_cpu_dependencies(executor: &mut Executor, event: &CpuEvent) {
             let next_pc = event.pc.wrapping_add(event.c);
             let add_event = AluEvent {
                 lookup_id: event.branch_add_lookup_id,
-                shard: event.shard,
+                shard,
                 clk: event.clk,
                 opcode: Opcode::ADD,
                 a: next_pc,
                 b: event.pc,
                 c: event.c,
-                sub_lookups: create_alu_lookups(),
+                sub_lookups: executor.record.create_lookup_ids(),
             };
             executor.record.add_events.push(add_event);
         }
     }

-    if event.instruction.is_jump_instruction() {
-        match event.instruction.opcode {
+    if instruction.is_jump_instruction() {
+        match instruction.opcode {
             Opcode::JAL => {
                 let next_pc = event.pc.wrapping_add(event.b);
                 let add_event = AluEvent {
                     lookup_id: event.jump_jal_lookup_id,
-                    shard: event.shard,
+                    shard,
                     clk: event.clk,
                     opcode: Opcode::ADD,
                     a: next_pc,
                     b: event.pc,
                     c: event.b,
-                    sub_lookups: create_alu_lookups(),
+                    sub_lookups: executor.record.create_lookup_ids(),
                 };
                 executor.record.add_events.push(add_event);
             }
@@ -260,13 +265,13 @@ pub fn emit_cpu_dependencies(executor: &mut Executor, event: &CpuEvent) {
                 let next_pc = event.b.wrapping_add(event.c);
                 let add_event = AluEvent {
                     lookup_id: event.jump_jalr_lookup_id,
-                    shard: event.shard,
+                    shard,
                     clk: event.clk,
                     opcode: Opcode::ADD,
                     a: next_pc,
                     b: event.b,
                     c: event.c,
-                    sub_lookups: create_alu_lookups(),
+                    sub_lookups: executor.record.create_lookup_ids(),
                 };
                 executor.record.add_events.push(add_event);
             }
@@ -274,16 +279,16 @@ pub fn emit_cpu_dependencies(executor: &mut Executor, event: &CpuEvent) {
         }
     }

-    if matches!(event.instruction.opcode, Opcode::AUIPC) {
+    if matches!(instruction.opcode, Opcode::AUIPC) {
         let add_event = AluEvent {
             lookup_id: event.auipc_lookup_id,
-            shard: event.shard,
+            shard,
             clk: event.clk,
             opcode: Opcode::ADD,
             a: event.a,
             b: event.pc,
             c: event.b,
-            sub_lookups: create_alu_lookups(),
+            sub_lookups: executor.record.create_lookup_ids(),
         };
         executor.record.add_events.push(add_event);
     }
diff --git a/crates/core/executor/src/disassembler/rrs.rs b/crates/core/executor/src/disassembler/rrs.rs
index 711dbe4e4a..a105e10a81 100644
--- a/crates/core/executor/src/disassembler/rrs.rs
+++ b/crates/core/executor/src/disassembler/rrs.rs
@@ -9,52 +9,31 @@ impl Instruction {
     /// Create a new [`Instruction`] from an R-type instruction.
     #[must_use]
     pub const fn from_r_type(opcode: Opcode, dec_insn: &RType) -> Self {
-        Self::new(
-            opcode,
-            dec_insn.rd as u32,
-            dec_insn.rs1 as u32,
-            dec_insn.rs2 as u32,
-            false,
-            false,
-        )
+        Self::new(opcode, dec_insn.rd as u8, dec_insn.rs1 as u32, dec_insn.rs2 as u32, false, false)
     }

     /// Create a new [`Instruction`] from an I-type instruction.
     #[must_use]
     pub const fn from_i_type(opcode: Opcode, dec_insn: &IType) -> Self {
-        Self::new(opcode, dec_insn.rd as u32, dec_insn.rs1 as u32, dec_insn.imm as u32, false, true)
+        Self::new(opcode, dec_insn.rd as u8, dec_insn.rs1 as u32, dec_insn.imm as u32, false, true)
     }

     /// Create a new [`Instruction`] from an I-type instruction with a shamt.
     #[must_use]
     pub const fn from_i_type_shamt(opcode: Opcode, dec_insn: &ITypeShamt) -> Self {
-        Self::new(opcode, dec_insn.rd as u32, dec_insn.rs1 as u32, dec_insn.shamt, false, true)
+        Self::new(opcode, dec_insn.rd as u8, dec_insn.rs1 as u32, dec_insn.shamt, false, true)
     }

     /// Create a new [`Instruction`] from an S-type instruction.
     #[must_use]
     pub const fn from_s_type(opcode: Opcode, dec_insn: &SType) -> Self {
-        Self::new(
-            opcode,
-            dec_insn.rs2 as u32,
-            dec_insn.rs1 as u32,
-            dec_insn.imm as u32,
-            false,
-            true,
-        )
+        Self::new(opcode, dec_insn.rs2 as u8, dec_insn.rs1 as u32, dec_insn.imm as u32, false, true)
     }

     /// Create a new [`Instruction`] from a B-type instruction.
     #[must_use]
     pub const fn from_b_type(opcode: Opcode, dec_insn: &BType) -> Self {
-        Self::new(
-            opcode,
-            dec_insn.rs1 as u32,
-            dec_insn.rs2 as u32,
-            dec_insn.imm as u32,
-            false,
-            true,
-        )
+        Self::new(opcode, dec_insn.rs1 as u8, dec_insn.rs2 as u32, dec_insn.imm as u32, false, true)
     }

     /// Create a new [`Instruction`] that is not implemented.
@@ -82,9 +61,9 @@ impl Instruction {
     #[must_use]
     pub fn r_type(&self) -> (Register, Register, Register) {
         (
-            Register::from_u32(self.op_a),
-            Register::from_u32(self.op_b),
-            Register::from_u32(self.op_c),
+            Register::from_u8(self.op_a),
+            Register::from_u8(self.op_b as u8),
+            Register::from_u8(self.op_c as u8),
         )
     }

@@ -92,35 +71,35 @@ impl Instruction {
     #[inline]
     #[must_use]
     pub fn i_type(&self) -> (Register, Register, u32) {
-        (Register::from_u32(self.op_a), Register::from_u32(self.op_b), self.op_c)
+        (Register::from_u8(self.op_a), Register::from_u8(self.op_b as u8), self.op_c)
     }

     /// Decode the [`Instruction`] in the S-type format.
     #[inline]
     #[must_use]
     pub fn s_type(&self) -> (Register, Register, u32) {
-        (Register::from_u32(self.op_a), Register::from_u32(self.op_b), self.op_c)
+        (Register::from_u8(self.op_a), Register::from_u8(self.op_b as u8), self.op_c)
     }

     /// Decode the [`Instruction`] in the B-type format.
     #[inline]
     #[must_use]
     pub fn b_type(&self) -> (Register, Register, u32) {
-        (Register::from_u32(self.op_a), Register::from_u32(self.op_b), self.op_c)
+        (Register::from_u8(self.op_a), Register::from_u8(self.op_b as u8), self.op_c)
     }

     /// Decode the [`Instruction`] in the J-type format.
     #[inline]
     #[must_use]
     pub fn j_type(&self) -> (Register, u32) {
-        (Register::from_u32(self.op_a), self.op_b)
+        (Register::from_u8(self.op_a), self.op_b)
     }

     /// Decode the [`Instruction`] in the U-type format.
     #[inline]
     #[must_use]
     pub fn u_type(&self) -> (Register, u32) {
-        (Register::from_u32(self.op_a), self.op_b)
+        (Register::from_u8(self.op_a), self.op_b)
     }
 }

@@ -263,13 +242,13 @@ impl InstructionProcessor for InstructionTranspiler {
     }

     fn process_jal(&mut self, dec_insn: JType) -> Self::InstructionResult {
-        Instruction::new(Opcode::JAL, dec_insn.rd as u32, dec_insn.imm as u32, 0, true, true)
+        Instruction::new(Opcode::JAL, dec_insn.rd as u8, dec_insn.imm as u32, 0, true, true)
     }

     fn process_jalr(&mut self, dec_insn: IType) -> Self::InstructionResult {
         Instruction::new(
             Opcode::JALR,
-            dec_insn.rd as u32,
+            dec_insn.rd as u8,
             dec_insn.rs1 as u32,
             dec_insn.imm as u32,
             false,
@@ -282,14 +261,14 @@ impl InstructionProcessor for InstructionTranspiler {
         //
         // Notably, LUI instructions are converted to an SLL instruction with `imm_b` and `imm_c`
         // turned on. Additionally the `op_c` should be set to 12.
-        Instruction::new(Opcode::ADD, dec_insn.rd as u32, 0, dec_insn.imm as u32, true, true)
+        Instruction::new(Opcode::ADD, dec_insn.rd as u8, 0, dec_insn.imm as u32, true, true)
     }

     /// AUIPC instructions have the third operand set to imm << 12.
     fn process_auipc(&mut self, dec_insn: UType) -> Self::InstructionResult {
         Instruction::new(
             Opcode::AUIPC,
-            dec_insn.rd as u32,
+            dec_insn.rd as u8,
             dec_insn.imm as u32,
             dec_insn.imm as u32,
             true,
@@ -300,7 +279,7 @@ impl InstructionProcessor for InstructionTranspiler {
     fn process_ecall(&mut self) -> Self::InstructionResult {
         Instruction::new(
             Opcode::ECALL,
-            Register::X5 as u32,
+            Register::X5 as u8,
             Register::X10 as u32,
             Register::X11 as u32,
             false,
diff --git a/crates/core/executor/src/events/alu.rs b/crates/core/executor/src/events/alu.rs
index bf79a65e4c..2d2b14fe03 100644
--- a/crates/core/executor/src/events/alu.rs
+++ b/crates/core/executor/src/events/alu.rs
@@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};

 use crate::Opcode;

-use super::{create_alu_lookups, LookupId};
+use super::{create_random_lookup_ids, LookupId};

 /// Arithmetic Logic Unit (ALU) Event.
 ///
@@ -40,7 +40,7 @@ impl AluEvent {
             a,
             b,
             c,
-            sub_lookups: create_alu_lookups(),
+            sub_lookups: create_random_lookup_ids(),
         }
     }
 }
diff --git a/crates/core/executor/src/events/byte.rs b/crates/core/executor/src/events/byte.rs
index 4e5f254373..3db5c7647b 100644
--- a/crates/core/executor/src/events/byte.rs
+++ b/crates/core/executor/src/events/byte.rs
@@ -233,7 +233,7 @@ impl ByteOpcode {
             ByteOpcode::MSB,
             ByteOpcode::U16Range,
         ];
-        assert_eq!(opcodes.len(), NUM_BYTE_OPS);
+        debug_assert_eq!(opcodes.len(), NUM_BYTE_OPS);
         opcodes
     }
diff --git a/crates/core/executor/src/events/cpu.rs b/crates/core/executor/src/events/cpu.rs
index b2d775cf12..f609e941c0 100644
--- a/crates/core/executor/src/events/cpu.rs
+++ b/crates/core/executor/src/events/cpu.rs
@@ -1,7 +1,5 @@
 use serde::{Deserialize, Serialize};

-use crate::Instruction;
-
 use super::{memory::MemoryRecordEnum, LookupId};

 /// CPU Event.
@@ -10,16 +8,12 @@ use super::{memory::MemoryRecordEnum, LookupId};
 /// shard, opcode, operands, and other relevant information.
 #[derive(Debug, Copy, Clone, Serialize, Deserialize)]
 pub struct CpuEvent {
-    /// The shard number.
-    pub shard: u32,
     /// The clock cycle.
     pub clk: u32,
     /// The program counter.
     pub pc: u32,
     /// The next program counter.
     pub next_pc: u32,
-    /// The instruction.
-    pub instruction: Instruction,
     /// The first operand.
     pub a: u32,
     /// The first operand memory record.
     pub a_record: Option<MemoryRecordEnum>,
     /// The second operand.
     pub b: u32,
     /// The second operand memory record.
     pub b_record: Option<MemoryRecordEnum>,
     /// The third operand.
     pub c: u32,
     /// The third operand memory record.
     pub c_record: Option<MemoryRecordEnum>,
-    /// The memory value.
-    pub memory: Option<u32>,
     /// The memory record.
     pub memory_record: Option<MemoryRecordEnum>,
     /// The exit code.
diff --git a/crates/core/executor/src/events/memory.rs b/crates/core/executor/src/events/memory.rs
index 4372f21267..655e0fc21d 100644
--- a/crates/core/executor/src/events/memory.rs
+++ b/crates/core/executor/src/events/memory.rs
@@ -150,7 +150,9 @@ impl MemoryReadRecord {
         prev_shard: u32,
         prev_timestamp: u32,
     ) -> Self {
-        assert!(shard > prev_shard || ((shard == prev_shard) && (timestamp > prev_timestamp)));
+        debug_assert!(
+            shard > prev_shard || ((shard == prev_shard) && (timestamp > prev_timestamp))
+        );
         Self { value, shard, timestamp, prev_shard, prev_timestamp }
     }
 }
@@ -166,7 +168,9 @@ impl MemoryWriteRecord {
         prev_shard: u32,
         prev_timestamp: u32,
     ) -> Self {
-        assert!(shard > prev_shard || ((shard == prev_shard) && (timestamp > prev_timestamp)),);
+        debug_assert!(
+            shard > prev_shard || ((shard == prev_shard) && (timestamp > prev_timestamp)),
+        );
         Self { value, shard, timestamp, prev_value, prev_shard, prev_timestamp }
     }
 }
diff --git a/crates/core/executor/src/events/utils.rs b/crates/core/executor/src/events/utils.rs
index 681bc6cc78..d4b38df745 100644
--- a/crates/core/executor/src/events/utils.rs
+++ b/crates/core/executor/src/events/utils.rs
@@ -5,43 +5,16 @@ use std::{
     iter::{Map, Peekable},
 };

-use rand::{thread_rng, Rng};
-
 /// A unique identifier for lookups.
-///
-/// We use 4 u32s instead of a u128 to make it compatible with C.
 #[derive(Deserialize, Serialize, Debug, Clone, Copy, Default, Eq, Hash, PartialEq)]
-pub struct LookupId {
-    /// First part of the id.
-    pub a: u32,
-    /// Second part of the id.
-    pub b: u32,
-    /// Third part of the id.
-    pub c: u32,
-    /// Fourth part of the id.
-    pub d: u32,
-}
-
-/// Creates a new ALU lookup id with ``LookupId``
-#[must_use]
-pub fn create_alu_lookup_id() -> LookupId {
-    let mut rng = thread_rng();
-    LookupId { a: rng.gen(), b: rng.gen(), c: rng.gen(), d: rng.gen() }
-}
+pub struct LookupId(pub u64);

-/// Creates a new ALU lookup id with ``LookupId``
+/// Create a random lookup id. This is slower than `record.create_lookup_id()` but is useful for
+/// testing.
 #[must_use]
-pub fn create_alu_lookups() -> [LookupId; 6] {
-    let mut rng = thread_rng();
-    [
-        LookupId { a: rng.gen(), b: rng.gen(), c: rng.gen(), d: rng.gen() },
-        LookupId { a: rng.gen(), b: rng.gen(), c: rng.gen(), d: rng.gen() },
-        LookupId { a: rng.gen(), b: rng.gen(), c: rng.gen(), d: rng.gen() },
-        LookupId { a: rng.gen(), b: rng.gen(), c: rng.gen(), d: rng.gen() },
-        LookupId { a: rng.gen(), b: rng.gen(), c: rng.gen(), d: rng.gen() },
-        LookupId { a: rng.gen(), b: rng.gen(), c: rng.gen(), d: rng.gen() },
-    ]
+pub(crate) fn create_random_lookup_ids() -> [LookupId; 6] {
+    std::array::from_fn(|_| LookupId(rand::random()))
 }

 /// Returns sorted and formatted rows of a table of counts (e.g. `opcode_counts`).
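The hunk above replaces the random 4-u32 `LookupId` with a sequential `u64` handed out by the execution record. A minimal sketch of the new allocation scheme, assuming a stripped-down `Record` type (`Record` and `register_nonce` are illustrative stand-ins for the `ExecutionRecord` and `register_nonces` shown later in this patch):

```rust
/// A unique identifier for lookups, as in the patch.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct LookupId(pub u64);

#[derive(Default)]
struct Record {
    /// Ids are allocated sequentially, so they double as indices into `nonce_lookup`.
    next_nonce: u64,
    /// Preallocated id -> nonce table (replaces the old `HashMap<LookupId, u32>`).
    nonce_lookup: Vec<u32>,
}

impl Record {
    /// O(1), deterministic, and allocation-free, unlike the old `thread_rng`-based ids.
    fn create_lookup_id(&mut self) -> LookupId {
        let id = self.next_nonce;
        self.next_nonce += 1;
        LookupId(id)
    }

    /// Record the nonce (e.g. a trace row index) for an event's lookup id.
    fn register_nonce(&mut self, id: LookupId, nonce: u32) {
        self.nonce_lookup[id.0 as usize] = nonce;
    }
}

fn main() {
    // The executor sizes the table up front (`shard_size * 32` in the patch);
    // the capacity here is arbitrary for the sketch.
    let mut record = Record { next_nonce: 0, nonce_lookup: vec![0; 1 << 20] };
    let id = record.create_lookup_id();
    record.register_nonce(id, 42);
    assert_eq!(record.nonce_lookup[id.0 as usize], 42);
}
```

This trades the old scheme's collision-resistant random ids for direct vector indexing: chips can now resolve a nonce with a bounds-checked `get` instead of a hash lookup.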
diff --git a/crates/core/executor/src/executor.rs b/crates/core/executor/src/executor.rs
index cfbc8cacd5..7244955a37 100644
--- a/crates/core/executor/src/executor.rs
+++ b/crates/core/executor/src/executor.rs
@@ -13,9 +13,8 @@ use crate::{
     context::SP1Context,
     dependencies::{emit_cpu_dependencies, emit_divrem_dependencies},
     events::{
-        create_alu_lookup_id, create_alu_lookups, AluEvent, CpuEvent, LookupId,
-        MemoryAccessPosition, MemoryInitializeFinalizeEvent, MemoryLocalEvent, MemoryReadRecord,
-        MemoryRecord, MemoryWriteRecord, SyscallEvent,
+        AluEvent, CpuEvent, LookupId, MemoryAccessPosition, MemoryInitializeFinalizeEvent,
+        MemoryLocalEvent, MemoryReadRecord, MemoryRecord, MemoryWriteRecord, SyscallEvent,
     },
     hook::{HookEnv, HookRegistry},
     memory::{Entry, PagedMemory},
@@ -269,7 +268,7 @@ impl<'a> Executor<'a> {
     pub fn registers(&mut self) -> [u32; 32] {
         let mut registers = [0; 32];
         for i in 0..32 {
-            let addr = Register::from_u32(i as u32) as u32;
+            let addr = Register::from_u8(i as u8) as u32;
             let record = self.state.memory.get(addr);

             // Only add the previous memory state to checkpoint map if we're in checkpoint mode,
@@ -407,7 +406,7 @@ impl<'a> Executor<'a> {
         record.shard = shard;
         record.timestamp = timestamp;

-        if !self.unconstrained {
+        if !self.unconstrained && self.executor_mode == ExecutorMode::Trace {
             let local_memory_access = if let Some(local_memory_access) = local_memory_access {
                 local_memory_access
             } else {
@@ -486,7 +485,7 @@ impl<'a> Executor<'a> {
         record.shard = shard;
         record.timestamp = timestamp;

-        if !self.unconstrained {
+        if !self.unconstrained && self.executor_mode == ExecutorMode::Trace {
             let local_memory_access = if let Some(local_memory_access) = local_memory_access {
                 local_memory_access
             } else {
@@ -553,19 +552,19 @@ impl<'a> Executor<'a> {
         if !self.unconstrained && self.executor_mode == ExecutorMode::Trace {
             match position {
                 MemoryAccessPosition::A => {
-                    assert!(self.memory_accesses.a.is_none());
+                    debug_assert!(self.memory_accesses.a.is_none());
                     self.memory_accesses.a = Some(record.into());
                 }
                 MemoryAccessPosition::B => {
-                    assert!(self.memory_accesses.b.is_none());
+                    debug_assert!(self.memory_accesses.b.is_none());
                     self.memory_accesses.b = Some(record.into());
                 }
                 MemoryAccessPosition::C => {
-                    assert!(self.memory_accesses.c.is_none());
+                    debug_assert!(self.memory_accesses.c.is_none());
                     self.memory_accesses.c = Some(record.into());
                 }
                 MemoryAccessPosition::Memory => {
-                    assert!(self.memory_accesses.memory.is_none());
+                    debug_assert!(self.memory_accesses.memory.is_none());
                     self.memory_accesses.memory = Some(record.into());
                 }
             }
@@ -593,49 +592,50 @@ impl<'a> Executor<'a> {
     #[allow(clippy::too_many_arguments)]
     fn emit_cpu(
         &mut self,
-        shard: u32,
         clk: u32,
         pc: u32,
         next_pc: u32,
-        instruction: Instruction,
         a: u32,
         b: u32,
         c: u32,
-        memory_store_value: Option<u32>,
         record: MemoryAccessRecord,
         exit_code: u32,
         lookup_id: LookupId,
         syscall_lookup_id: LookupId,
     ) {
-        let cpu_event = CpuEvent {
-            shard,
+        let memory_add_lookup_id = self.record.create_lookup_id();
+        let memory_sub_lookup_id = self.record.create_lookup_id();
+        let branch_lt_lookup_id = self.record.create_lookup_id();
+        let branch_gt_lookup_id = self.record.create_lookup_id();
+        let branch_add_lookup_id = self.record.create_lookup_id();
+        let jump_jal_lookup_id = self.record.create_lookup_id();
+        let jump_jalr_lookup_id = self.record.create_lookup_id();
+        let auipc_lookup_id = self.record.create_lookup_id();
+        self.record.cpu_events.push(CpuEvent {
             clk,
             pc,
             next_pc,
-            instruction,
             a,
             a_record: record.a,
             b,
             b_record: record.b,
             c,
             c_record: record.c,
-            memory: memory_store_value,
             memory_record: record.memory,
             exit_code,
             alu_lookup_id: lookup_id,
             syscall_lookup_id,
-            memory_add_lookup_id: create_alu_lookup_id(),
-            memory_sub_lookup_id: create_alu_lookup_id(),
-            branch_lt_lookup_id: create_alu_lookup_id(),
-            branch_gt_lookup_id: create_alu_lookup_id(),
-            branch_add_lookup_id: create_alu_lookup_id(),
-            jump_jal_lookup_id: create_alu_lookup_id(),
-            jump_jalr_lookup_id: create_alu_lookup_id(),
-            auipc_lookup_id: create_alu_lookup_id(),
-        };
+            memory_add_lookup_id,
+            memory_sub_lookup_id,
+            branch_lt_lookup_id,
+            branch_gt_lookup_id,
+            branch_add_lookup_id,
+            jump_jal_lookup_id,
+            jump_jalr_lookup_id,
+            auipc_lookup_id,
+        });

-        self.record.cpu_events.push(cpu_event);
-        emit_cpu_dependencies(self, &cpu_event);
+        emit_cpu_dependencies(self, self.record.cpu_events.len() - 1);
     }

     /// Emit an ALU event.
@@ -648,7 +648,7 @@ impl<'a> Executor<'a> {
             a,
             b,
             c,
-            sub_lookups: create_alu_lookups(),
+            sub_lookups: self.record.create_lookup_ids(),
         };
         match opcode {
             Opcode::ADD => {
@@ -696,7 +696,7 @@ impl<'a> Executor<'a> {
             arg1,
             arg2,
             lookup_id,
-            nonce: self.record.nonce_lookup[&lookup_id],
+            nonce: self.record.nonce_lookup[lookup_id.0 as usize],
         }
     }

@@ -725,9 +725,9 @@ impl<'a> Executor<'a> {
             let (rd, b, c) = (rd, self.rr(rs1, MemoryAccessPosition::B), imm);
             (rd, b, c)
         } else {
-            assert!(instruction.imm_b && instruction.imm_c);
+            debug_assert!(instruction.imm_b && instruction.imm_c);
             let (rd, b, c) =
-                (Register::from_u32(instruction.op_a), instruction.op_b, instruction.op_c);
+                (Register::from_u8(instruction.op_a), instruction.op_b, instruction.op_c);
             (rd, b, c)
         }
     }
@@ -780,8 +780,7 @@ impl<'a> Executor<'a> {
     /// Fetch the instruction at the current program counter.
     #[inline]
     fn fetch(&self) -> Instruction {
-        let idx = ((self.state.pc - self.program.pc_base) / 4) as usize;
-        self.program.instructions[idx]
+        *self.program.fetch(self.state.pc)
     }

     /// Execute the given instruction over the current state of the runtime.
@@ -793,21 +792,18 @@ impl<'a> Executor<'a> {

         let mut next_pc = self.state.pc.wrapping_add(4);

-        let rd: Register;
         let (a, b, c): (u32, u32, u32);
-        let (addr, memory_read_value): (u32, u32);
-        let mut memory_store_value: Option<u32> = None;

         if self.executor_mode == ExecutorMode::Trace {
             self.memory_accesses = MemoryAccessRecord::default();
         }
         let lookup_id = if self.executor_mode == ExecutorMode::Trace {
-            create_alu_lookup_id()
+            self.record.create_lookup_id()
         } else {
             LookupId::default()
         };
         let syscall_lookup_id = if self.executor_mode == ExecutorMode::Trace {
-            create_alu_lookup_id()
+            self.record.create_lookup_id()
         } else {
             LookupId::default()
         };
@@ -842,182 +838,40 @@ impl<'a> Executor<'a> {
         match instruction.opcode {
             // Arithmetic instructions.
-            Opcode::ADD => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = b.wrapping_add(c);
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::SUB => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = b.wrapping_sub(c);
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::XOR => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = b ^ c;
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::OR => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = b | c;
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::AND => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = b & c;
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::SLL => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = b.wrapping_shl(c);
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::SRL => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = b.wrapping_shr(c);
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::SRA => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = (b as i32).wrapping_shr(c) as u32;
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::SLT => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = if (b as i32) < (c as i32) { 1 } else { 0 };
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::SLTU => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = if b < c { 1 } else { 0 };
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
+            Opcode::ADD
+            | Opcode::SUB
+            | Opcode::XOR
+            | Opcode::OR
+            | Opcode::AND
+            | Opcode::SLL
+            | Opcode::SRL
+            | Opcode::SRA
+            | Opcode::SLT
+            | Opcode::SLTU
+            | Opcode::MUL
+            | Opcode::MULH
+            | Opcode::MULHU
+            | Opcode::MULHSU
+            | Opcode::DIV
+            | Opcode::DIVU
+            | Opcode::REM
+            | Opcode::REMU => {
+                (a, b, c) = self.execute_alu(instruction, lookup_id);
             }

             // Load instructions.
-            Opcode::LB => {
-                (rd, b, c, addr, memory_read_value) = self.load_rr(instruction);
-                let value = (memory_read_value).to_le_bytes()[(addr % 4) as usize];
-                a = ((value as i8) as i32) as u32;
-                memory_store_value = Some(memory_read_value);
-                self.rw(rd, a);
-            }
-            Opcode::LH => {
-                (rd, b, c, addr, memory_read_value) = self.load_rr(instruction);
-                if addr % 2 != 0 {
-                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LH, addr));
-                }
-                let value = match (addr >> 1) % 2 {
-                    0 => memory_read_value & 0x0000_FFFF,
-                    1 => (memory_read_value & 0xFFFF_0000) >> 16,
-                    _ => unreachable!(),
-                };
-                a = ((value as i16) as i32) as u32;
-                memory_store_value = Some(memory_read_value);
-                self.rw(rd, a);
-            }
-            Opcode::LW => {
-                (rd, b, c, addr, memory_read_value) = self.load_rr(instruction);
-                if addr % 4 != 0 {
-                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LW, addr));
-                }
-                a = memory_read_value;
-                memory_store_value = Some(memory_read_value);
-                self.rw(rd, a);
-            }
-            Opcode::LBU => {
-                (rd, b, c, addr, memory_read_value) = self.load_rr(instruction);
-                let value = (memory_read_value).to_le_bytes()[(addr % 4) as usize];
-                a = value as u32;
-                memory_store_value = Some(memory_read_value);
-                self.rw(rd, a);
-            }
-            Opcode::LHU => {
-                (rd, b, c, addr, memory_read_value) = self.load_rr(instruction);
-                if addr % 2 != 0 {
-                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LHU, addr));
-                }
-                let value = match (addr >> 1) % 2 {
-                    0 => memory_read_value & 0x0000_FFFF,
-                    1 => (memory_read_value & 0xFFFF_0000) >> 16,
-                    _ => unreachable!(),
-                };
-                a = (value as u16) as u32;
-                memory_store_value = Some(memory_read_value);
-                self.rw(rd, a);
+            Opcode::LB | Opcode::LH | Opcode::LW | Opcode::LBU | Opcode::LHU => {
+                (a, b, c) = self.execute_load(instruction)?;
             }

             // Store instructions.
-            Opcode::SB => {
-                (a, b, c, addr, memory_read_value) = self.store_rr(instruction);
-                let value = match addr % 4 {
-                    0 => (a & 0x0000_00FF) + (memory_read_value & 0xFFFF_FF00),
-                    1 => ((a & 0x0000_00FF) << 8) + (memory_read_value & 0xFFFF_00FF),
-                    2 => ((a & 0x0000_00FF) << 16) + (memory_read_value & 0xFF00_FFFF),
-                    3 => ((a & 0x0000_00FF) << 24) + (memory_read_value & 0x00FF_FFFF),
-                    _ => unreachable!(),
-                };
-                memory_store_value = Some(value);
-                self.mw_cpu(align(addr), value, MemoryAccessPosition::Memory);
-            }
-            Opcode::SH => {
-                (a, b, c, addr, memory_read_value) = self.store_rr(instruction);
-                if addr % 2 != 0 {
-                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SH, addr));
-                }
-                let value = match (addr >> 1) % 2 {
-                    0 => (a & 0x0000_FFFF) + (memory_read_value & 0xFFFF_0000),
-                    1 => ((a & 0x0000_FFFF) << 16) + (memory_read_value & 0x0000_FFFF),
-                    _ => unreachable!(),
-                };
-                memory_store_value = Some(value);
-                self.mw_cpu(align(addr), value, MemoryAccessPosition::Memory);
-            }
-            Opcode::SW => {
-                (a, b, c, addr, _) = self.store_rr(instruction);
-                if addr % 4 != 0 {
-                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SW, addr));
-                }
-                let value = a;
-                memory_store_value = Some(value);
-                self.mw_cpu(align(addr), value, MemoryAccessPosition::Memory);
+            Opcode::SB | Opcode::SH | Opcode::SW => {
+                (a, b, c) = self.execute_store(instruction)?;
             }

-            // B-type instructions.
-            Opcode::BEQ => {
-                (a, b, c) = self.branch_rr(instruction);
-                if a == b {
-                    next_pc = self.state.pc.wrapping_add(c);
-                }
-            }
-            Opcode::BNE => {
-                (a, b, c) = self.branch_rr(instruction);
-                if a != b {
-                    next_pc = self.state.pc.wrapping_add(c);
-                }
-            }
-            Opcode::BLT => {
-                (a, b, c) = self.branch_rr(instruction);
-                if (a as i32) < (b as i32) {
-                    next_pc = self.state.pc.wrapping_add(c);
-                }
-            }
-            Opcode::BGE => {
-                (a, b, c) = self.branch_rr(instruction);
-                if (a as i32) >= (b as i32) {
-                    next_pc = self.state.pc.wrapping_add(c);
-                }
-            }
-            Opcode::BLTU => {
-                (a, b, c) = self.branch_rr(instruction);
-                if a < b {
-                    next_pc = self.state.pc.wrapping_add(c);
-                }
-            }
-            Opcode::BGEU => {
-                (a, b, c) = self.branch_rr(instruction);
-                if a >= b {
-                    next_pc = self.state.pc.wrapping_add(c);
-                }
+            // Branch instructions.
+            Opcode::BEQ | Opcode::BNE | Opcode::BLT | Opcode::BGE | Opcode::BLTU | Opcode::BGEU => {
+                (a, b, c, next_pc) = self.execute_branch(instruction, next_pc);
             }

             // Jump instructions.
@@ -1080,7 +934,7 @@ impl<'a> Executor<'a> {
                     _ => (self.opts.split_opts.deferred, 1),
                 };
                 let nonce = (((*syscall_count as usize) % threshold) * multiplier) as u32;
-                self.record.nonce_lookup.insert(syscall_lookup_id, nonce);
+                self.record.nonce_lookup[syscall_lookup_id.0 as usize] = nonce;
                 *syscall_count += 1;

                 let syscall_impl = self.get_syscall(syscall).cloned();
@@ -1130,64 +984,6 @@ impl<'a> Executor<'a> {
                 return Err(ExecutionError::Breakpoint());
             }

-            // Multiply instructions.
-            Opcode::MUL => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = b.wrapping_mul(c);
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::MULH => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = (((b as i32) as i64).wrapping_mul((c as i32) as i64) >> 32) as u32;
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::MULHU => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = ((b as u64).wrapping_mul(c as u64) >> 32) as u32;
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::MULHSU => {
-                (rd, b, c) = self.alu_rr(instruction);
-                a = (((b as i32) as i64).wrapping_mul(c as i64) >> 32) as u32;
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::DIV => {
-                (rd, b, c) = self.alu_rr(instruction);
-                if c == 0 {
-                    a = u32::MAX;
-                } else {
-                    a = (b as i32).wrapping_div(c as i32) as u32;
-                }
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::DIVU => {
-                (rd, b, c) = self.alu_rr(instruction);
-                if c == 0 {
-                    a = u32::MAX;
-                } else {
-                    a = b.wrapping_div(c);
-                }
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::REM => {
-                (rd, b, c) = self.alu_rr(instruction);
-                if c == 0 {
-                    a = b;
-                } else {
-                    a = (b as i32).wrapping_rem(c as i32) as u32;
-                }
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-            Opcode::REMU => {
-                (rd, b, c) = self.alu_rr(instruction);
-                if c == 0 {
-                    a = b;
-                } else {
-                    a = b.wrapping_rem(c);
-                }
-                self.alu_rw(instruction, rd, a, b, c, lookup_id);
-            }
-
             // See https://github.com/riscv-non-isa/riscv-asm-manual/blob/master/riscv-asm.md#instruction-aliases
             Opcode::UNIMP => {
                 return Err(ExecutionError::Unimplemented());
@@ -1203,15 +999,12 @@ impl<'a> Executor<'a> {
         // Emit the CPU event for this cycle.
         if self.executor_mode == ExecutorMode::Trace {
             self.emit_cpu(
-                self.shard(),
                 clk,
                 pc,
                 next_pc,
-                *instruction,
                 a,
                 b,
                 c,
-                memory_store_value,
                 self.memory_accesses,
                 exit_code,
                 lookup_id,
@@ -1221,6 +1014,153 @@ impl<'a> Executor<'a> {
         Ok(())
     }

+    fn execute_alu(&mut self, instruction: &Instruction, lookup_id: LookupId) -> (u32, u32, u32) {
+        let (rd, b, c) = self.alu_rr(instruction);
+        let a = match instruction.opcode {
+            Opcode::ADD => b.wrapping_add(c),
+            Opcode::SUB => b.wrapping_sub(c),
+            Opcode::XOR => b ^ c,
+            Opcode::OR => b | c,
+            Opcode::AND => b & c,
+            Opcode::SLL => b.wrapping_shl(c),
+            Opcode::SRL => b.wrapping_shr(c),
+            Opcode::SRA => (b as i32).wrapping_shr(c) as u32,
+            Opcode::SLT => {
+                if (b as i32) < (c as i32) {
+                    1
+                } else {
+                    0
+                }
+            }
+            Opcode::SLTU => {
+                if b < c {
+                    1
+                } else {
+                    0
+                }
+            }
+            Opcode::MUL => b.wrapping_mul(c),
+            Opcode::MULH => (((b as i32) as i64).wrapping_mul((c as i32) as i64) >> 32) as u32,
+            Opcode::MULHU => ((b as u64).wrapping_mul(c as u64) >> 32) as u32,
+            Opcode::MULHSU => (((b as i32) as i64).wrapping_mul(c as i64) >> 32) as u32,
+            Opcode::DIV => {
+                if c == 0 {
+                    u32::MAX
+                } else {
+                    (b as i32).wrapping_div(c as i32) as u32
+                }
+            }
+            Opcode::DIVU => {
+                if c == 0 {
+                    u32::MAX
+                } else {
+                    b.wrapping_div(c)
+                }
+            }
+            Opcode::REM => {
+                if c == 0 {
+                    b
+                } else {
+                    (b as i32).wrapping_rem(c as i32) as u32
+                }
+            }
+            Opcode::REMU => {
+                if c == 0 {
+                    b
+                } else {
+                    b.wrapping_rem(c)
+                }
+            }
+            _ => unreachable!(),
+        };
+        self.alu_rw(instruction, rd, a, b, c, lookup_id);
+        (a, b, c)
+    }
+
+    fn execute_load(
+        &mut self,
+        instruction: &Instruction,
+    ) -> Result<(u32, u32, u32), ExecutionError> {
+        let (rd, b, c, addr, memory_read_value) = self.load_rr(instruction);
+        let a = match instruction.opcode {
+            Opcode::LB => ((memory_read_value >> ((addr % 4) * 8)) & 0xFF) as i8 as i32 as u32,
+            Opcode::LH => {
+                if addr % 2 != 0 {
+                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LH, addr));
+                }
+                ((memory_read_value >> (((addr / 2) % 2) * 16)) & 0xFFFF) as i16 as i32 as u32
+            }
+            Opcode::LW => {
+                if addr % 4 != 0 {
+                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LW, addr));
+                }
+                memory_read_value
+            }
+            Opcode::LBU => (memory_read_value >> ((addr % 4) * 8)) & 0xFF,
+            Opcode::LHU => {
+                if addr % 2 != 0 {
+                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::LHU, addr));
+                }
+                (memory_read_value >> (((addr / 2) % 2) * 16)) & 0xFFFF
+            }
+            _ => unreachable!(),
+        };
+        self.rw(rd, a);
+        Ok((a, b, c))
+    }
+
+    fn execute_store(
+        &mut self,
+        instruction: &Instruction,
+    ) -> Result<(u32, u32, u32), ExecutionError> {
+        let (a, b, c, addr, memory_read_value) = self.store_rr(instruction);
+        let memory_store_value = match instruction.opcode {
+            Opcode::SB => {
+                let shift = (addr % 4) * 8;
+                ((a & 0xFF) << shift) | (memory_read_value & !(0xFF << shift))
+            }
+            Opcode::SH => {
+                if addr % 2 != 0 {
+                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SH, addr));
+                }
+                let shift = ((addr / 2) % 2) * 16;
+                ((a & 0xFFFF) << shift) | (memory_read_value & !(0xFFFF << shift))
+            }
+            Opcode::SW => {
+                if addr % 4 != 0 {
+                    return Err(ExecutionError::InvalidMemoryAccess(Opcode::SW, addr));
+                }
+                a
+            }
+            _ => unreachable!(),
+        };
+        self.mw_cpu(align(addr), memory_store_value, MemoryAccessPosition::Memory);
+        Ok((a, b, c))
+    }
+
+    fn execute_branch(
+        &mut self,
+        instruction: &Instruction,
+        mut next_pc: u32,
+    ) -> (u32, u32, u32, u32) {
+        let (a, b, c) = self.branch_rr(instruction);
+        let branch = match instruction.opcode {
+            Opcode::BEQ => a == b,
+            Opcode::BNE => a != b,
+            Opcode::BLT => (a as i32) < (b as i32),
+            Opcode::BGE => (a as i32) >= (b as i32),
+            Opcode::BLTU => a < b,
+            Opcode::BGEU => a >= b,
+            _ => {
+                unreachable!()
+            }
+        };
+        if branch {
+            next_pc = self.state.pc.wrapping_add(c);
+        }
+        (a, b, c, next_pc)
+    }
+
     /// Executes one cycle of the program, returning whether the program has finished.
     #[inline]
     #[allow(clippy::too_many_lines)]
@@ -1397,6 +1337,7 @@ impl<'a> Executor<'a> {
             std::mem::replace(&mut self.record, ExecutionRecord::new(self.program.clone()));
         let public_values = removed_record.public_values;
         self.record.public_values = public_values;
+        self.record.nonce_lookup = vec![0; self.opts.shard_size * 32];
         self.records.push(removed_record);
     }

@@ -1471,6 +1412,8 @@ impl<'a> Executor<'a> {
     }

     fn initialize(&mut self) {
+        self.record.nonce_lookup = vec![0; self.opts.shard_size * 32];
+
         self.state.clk = 0;

         tracing::debug!("loading memory image");
@@ -1506,6 +1449,11 @@ impl<'a> Executor<'a> {
     /// Executes up to `self.shard_batch_size` cycles of the program, returning whether the program
     /// has finished.
     pub fn execute(&mut self) -> Result<bool, ExecutionError> {
+        // Initialize the nonce lookup table if it's uninitialized.
+        if self.record.nonce_lookup.len() <= 2 {
+            self.record.nonce_lookup = vec![0; self.opts.shard_size * 32];
+        }
+
         // Get the program.
         let program = self.program.clone();

diff --git a/crates/core/executor/src/instruction.rs b/crates/core/executor/src/instruction.rs
index bc1df27ba0..10dfa5476d 100644
--- a/crates/core/executor/src/instruction.rs
+++ b/crates/core/executor/src/instruction.rs
@@ -15,7 +15,7 @@ pub struct Instruction {
     /// The operation to execute.
     pub opcode: Opcode,
     /// The first operand.
-    pub op_a: u32,
+    pub op_a: u8,
     /// The second operand.
     pub op_b: u32,
     /// The third operand.
@@ -31,7 +31,7 @@ impl Instruction {
     #[must_use]
     pub const fn new(
         opcode: Opcode,
-        op_a: u32,
+        op_a: u8,
         op_b: u32,
         op_c: u32,
         imm_b: bool,
diff --git a/crates/core/executor/src/memory.rs b/crates/core/executor/src/memory.rs
index 6e375753d4..a036bbf5ca 100644
--- a/crates/core/executor/src/memory.rs
+++ b/crates/core/executor/src/memory.rs
@@ -11,10 +11,10 @@ impl Default for Page {
     }
 }

-const LOG_PAGE_LEN: usize = 15;
+const LOG_PAGE_LEN: usize = 14;
 const PAGE_LEN: usize = 1 << LOG_PAGE_LEN;
 const MAX_PAGE_COUNT: usize = ((1 << 31) - (1 << 27)) / 4 / PAGE_LEN + 1;
-const NO_PAGE: usize = usize::MAX;
+const NO_PAGE: u16 = u16::MAX;
 const PAGE_MASK: usize = PAGE_LEN - 1;

 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -41,7 +41,7 @@ impl Default for NewPage {
 pub struct PagedMemory<V: Copy> {
     /// The internal page table.
     pub page_table: Vec<NewPage<V>>,
-    pub index: Vec<usize>,
+    pub index: Vec<u16>,
 }

 impl<V: Copy> PagedMemory<V> {
@@ -50,8 +50,7 @@ impl<V: Copy> PagedMemory<V> {
     /// The number of registers in the virtual machine.
     const NUM_REGISTERS: usize = 32;
     /// The offset subtracted from the main address space to make it contiguous.
-    const ADDR_COMPRESS_OFFSET: usize =
-        Self::NUM_REGISTERS - (Self::NUM_REGISTERS >> Self::NUM_IGNORED_LOWER_BITS);
+    const ADDR_COMPRESS_OFFSET: usize = Self::NUM_REGISTERS;

     /// Create a `PagedMemory` with capacity `MAX_PAGE_COUNT`.
     pub fn new_preallocated() -> Self {
@@ -65,7 +64,7 @@ impl<V: Copy> PagedMemory<V> {
         if index == NO_PAGE {
             None
         } else {
-            self.page_table[index].0[lower].as_ref()
+            self.page_table[index as usize].0[lower].as_ref()
         }
     }

@@ -76,7 +75,7 @@ impl<V: Copy> PagedMemory<V> {
         if index == NO_PAGE {
             None
         } else {
-            self.page_table[index].0[lower].as_mut()
+            self.page_table[index as usize].0[lower].as_mut()
         }
     }

@@ -85,11 +84,11 @@ impl<V: Copy> PagedMemory<V> {
         let (upper, lower) = Self::indices(addr);
         let mut index = self.index[upper];
         if index == NO_PAGE {
-            index = self.page_table.len();
+            index = self.page_table.len() as u16;
             self.index[upper] = index;
             self.page_table.push(NewPage::new());
         }
-        self.page_table[index].0[lower].replace(value)
+        self.page_table[index as usize].0[lower].replace(value)
     }

     /// Remove the value at the given address if it exists, returning it.
@@ -99,7 +98,7 @@ impl<V: Copy> PagedMemory<V> {
         if index == NO_PAGE {
             None
         } else {
-            self.page_table[index].0[lower].take()
+            self.page_table[index as usize].0[lower].take()
         }
     }

@@ -109,11 +108,11 @@ impl<V: Copy> PagedMemory<V> {
         let index = self.index[upper];
         if index == NO_PAGE {
             let index = self.page_table.len();
-            self.index[upper] = index;
+            self.index[upper] = index as u16;
             self.page_table.push(NewPage::new());
             Entry::Vacant(VacantEntry { entry: &mut self.page_table[index].0[lower] })
         } else {
-            let option = &mut self.page_table[index].0[lower];
+            let option = &mut self.page_table[index as usize].0[lower];
             match option {
                 Some(_) => Entry::Occupied(OccupiedEntry { entry: option }),
                 None => Entry::Vacant(VacantEntry { entry: option }),
@@ -125,7 +124,7 @@ impl<V: Copy> PagedMemory<V> {
     pub fn keys(&self) -> impl Iterator<Item = u32> + '_ {
         self.index.iter().enumerate().filter(|(_, &i)| i != NO_PAGE).flat_map(|(i, index)| {
             let upper = i << LOG_PAGE_LEN;
-            self.page_table[*index]
+            self.page_table[*index as usize]
                 .0
                 .iter()
                 .enumerate()
@@ -275,7 +274,7 @@ impl<V: Copy> IntoIterator for PagedMemory<V> {
             move |(i, index)| {
                 let upper = i << LOG_PAGE_LEN;
                 let replacement = NewPage::new();
                 std::mem::replace(&mut self.page_table[index as usize], replacement)
                     .0
                     .into_iter()
                     .enumerate()
diff --git a/crates/core/executor/src/opcode.rs b/crates/core/executor/src/opcode.rs
index 6d0589ca91..818b5b1f2b 100644
--- a/crates/core/executor/src/opcode.rs
+++ b/crates/core/executor/src/opcode.rs
@@ -100,7 +100,7 @@ pub enum Opcode {
     /// rd ← rs1 % rs2 (unsigned), pc ← pc + 4
     REMU = 37,
     /// Unimplemented instruction.
-    UNIMP = 39,
+    UNIMP = 38,
 }

 /// Byte Opcode.
diff --git a/crates/core/executor/src/program.rs b/crates/core/executor/src/program.rs
index 29743a4c8f..09bb70cac4 100644
--- a/crates/core/executor/src/program.rs
+++ b/crates/core/executor/src/program.rs
@@ -89,6 +89,13 @@ impl Program {
             })
             .copied()
     }
+
+    #[must_use]
+    /// Fetch the instruction at the given program counter.
+    pub fn fetch(&self, pc: u32) -> &Instruction {
+        let idx = ((pc - self.pc_base) / 4) as usize;
+        &self.instructions[idx]
+    }
 }

 impl<F: Field> MachineProgram<F> for Program {
diff --git a/crates/core/executor/src/record.rs b/crates/core/executor/src/record.rs
index b6c23c45f8..f9e89acb4c 100644
--- a/crates/core/executor/src/record.rs
+++ b/crates/core/executor/src/record.rs
@@ -23,7 +23,7 @@ use crate::{
 /// A record of the execution of a program.
 ///
 /// The trace of the execution is represented as a list of "events" that occur every cycle.
-#[derive(Default, Clone, Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct ExecutionRecord {
     /// The program.
     pub program: Arc<Program>,
@@ -60,16 +60,63 @@ pub struct ExecutionRecord {
     /// The public values.
     pub public_values: PublicValues<u32, u32>,
     /// The nonce lookup.
-    pub nonce_lookup: HashMap<LookupId, u32>,
+    pub nonce_lookup: Vec<u32>,
+    /// The next nonce to use for a new lookup.
+    pub next_nonce: u64,
     /// The shape of the proof.
     pub shape: Option<CoreShape>,
 }

+impl Default for ExecutionRecord {
+    fn default() -> Self {
+        let mut res = Self {
+            program: Arc::default(),
+            cpu_events: Vec::default(),
+            add_events: Vec::default(),
+            mul_events: Vec::default(),
+            sub_events: Vec::default(),
+            bitwise_events: Vec::default(),
+            shift_left_events: Vec::default(),
+            shift_right_events: Vec::default(),
+            divrem_events: Vec::default(),
+            lt_events: Vec::default(),
+            byte_lookups: HashMap::default(),
+            precompile_events: PrecompileEvents::default(),
+            global_memory_initialize_events: Vec::default(),
+            global_memory_finalize_events: Vec::default(),
+            cpu_local_memory_access: Vec::default(),
+            syscall_events: Vec::default(),
+            public_values: PublicValues::default(),
+            nonce_lookup: Vec::default(),
+            next_nonce: 0,
+            shape: None,
+        };
+        res.nonce_lookup.insert(0, 0);
+        res
+    }
+}
+
 impl ExecutionRecord {
     /// Create a new [`ExecutionRecord`].
     #[must_use]
     pub fn new(program: Arc<Program>) -> Self {
-        Self { program, ..Default::default() }
+        let mut res = Self { program, ..Default::default() };
+        res.nonce_lookup.insert(0, 0);
+        res
+    }
+
+    /// Create a lookup id for an event.
+    pub fn create_lookup_id(&mut self) -> LookupId {
+        // let id = self.nonce_lookup.len() as u64;
+        let id = self.next_nonce;
+        self.next_nonce += 1;
+        // self.nonce_lookup.insert(id as usize, 0);
+        LookupId(id)
+    }
+
+    /// Create 6 lookup ids for an ALU event.
+    pub fn create_lookup_ids(&mut self) -> [LookupId; 6] {
+        std::array::from_fn(|_| self.create_lookup_id())
     }

     /// Add a mul event to the execution record.
@@ -299,7 +346,7 @@ impl MachineRecord for ExecutionRecord {
         );
         stats.insert("local_memory_access_events".to_string(), self.cpu_local_memory_access.len());
         if !self.cpu_events.is_empty() {
-            let shard = self.cpu_events[0].shard;
+            let shard = self.public_values.shard;
             stats.insert(
                 "byte_lookups".to_string(),
                 self.byte_lookups.get(&shard).map_or(0, hashbrown::HashMap::len),
@@ -337,35 +384,35 @@ impl MachineRecord for ExecutionRecord {

     fn register_nonces(&mut self, _opts: &Self::Config) {
         self.add_events.iter().enumerate().for_each(|(i, event)| {
-            self.nonce_lookup.insert(event.lookup_id, i as u32);
+            self.nonce_lookup[event.lookup_id.0 as usize] = i as u32;
         });

         self.sub_events.iter().enumerate().for_each(|(i, event)| {
-            self.nonce_lookup.insert(event.lookup_id, (self.add_events.len() + i) as u32);
+            self.nonce_lookup[event.lookup_id.0 as usize] = (self.add_events.len() + i) as u32;
         });

         self.mul_events.iter().enumerate().for_each(|(i, event)| {
-            self.nonce_lookup.insert(event.lookup_id, i as u32);
+            self.nonce_lookup[event.lookup_id.0 as usize] = i as u32;
         });

         self.bitwise_events.iter().enumerate().for_each(|(i, event)| {
-            self.nonce_lookup.insert(event.lookup_id, i as u32);
+            self.nonce_lookup[event.lookup_id.0 as usize] = i as u32;
         });

         self.shift_left_events.iter().enumerate().for_each(|(i, event)| {
-            self.nonce_lookup.insert(event.lookup_id, i as u32);
+            self.nonce_lookup[event.lookup_id.0 as usize] = i as u32;
         });

         self.shift_right_events.iter().enumerate().for_each(|(i, event)| {
-            self.nonce_lookup.insert(event.lookup_id, i as u32);
+            self.nonce_lookup[event.lookup_id.0 as usize] = i as u32;
         });

         self.divrem_events.iter().enumerate().for_each(|(i, event)| {
-            self.nonce_lookup.insert(event.lookup_id, i as u32);
+            self.nonce_lookup[event.lookup_id.0 as usize] = i as u32;
         });

         self.lt_events.iter().enumerate().for_each(|(i, event)| {
-            self.nonce_lookup.insert(event.lookup_id, i as u32);
+            self.nonce_lookup[event.lookup_id.0 as usize] = i as u32;
         });
     }
diff --git a/crates/core/executor/src/register.rs b/crates/core/executor/src/register.rs
index 176ef1c951..c24b75decc 100644
--- a/crates/core/executor/src/register.rs
+++ b/crates/core/executor/src/register.rs
@@ -70,14 +70,14 @@ pub enum Register {
 }

 impl Register {
-    /// Create a new register from a u32.
+    /// Create a new register from a u8.
     ///
     /// # Panics
     ///
     /// This function will panic if the register is invalid.
     #[inline]
     #[must_use]
-    pub fn from_u32(value: u32) -> Self {
+    pub fn from_u8(value: u8) -> Self {
         match value {
             0 => Register::X0,
             1 => Register::X1,
diff --git a/crates/core/executor/src/syscalls/precompiles/sha256/extend.rs b/crates/core/executor/src/syscalls/precompiles/sha256/extend.rs
index 36f1b56542..1d4a2a769e 100644
--- a/crates/core/executor/src/syscalls/precompiles/sha256/extend.rs
+++ b/crates/core/executor/src/syscalls/precompiles/sha256/extend.rs
@@ -22,11 +22,11 @@ impl Syscall for Sha256ExtendSyscall {
         assert!(arg2 == 0, "arg2 must be 0");

         let w_ptr_init = w_ptr;
-        let mut w_i_minus_15_reads = Vec::new();
-        let mut w_i_minus_2_reads = Vec::new();
-        let mut w_i_minus_16_reads = Vec::new();
-        let mut w_i_minus_7_reads = Vec::new();
-        let mut w_i_writes = Vec::new();
+        let mut w_i_minus_15_reads = Vec::with_capacity(48);
+        let mut w_i_minus_2_reads = Vec::with_capacity(48);
+        let mut w_i_minus_16_reads = Vec::with_capacity(48);
+        let mut w_i_minus_7_reads = Vec::with_capacity(48);
+        let mut w_i_writes = Vec::with_capacity(48);
         for i in 16..64 {
             // Read w[i-15].
             let (record, w_i_minus_15) = rt.mr(w_ptr + (i - 15) * 4);
diff --git a/crates/core/machine/Cargo.toml b/crates/core/machine/Cargo.toml
index 7e2bd5a50a..3b0b0f34f9 100644
--- a/crates/core/machine/Cargo.toml
+++ b/crates/core/machine/Cargo.toml
@@ -54,6 +54,7 @@ static_assertions = "1.1.0"
 sp1-stark = { workspace = true }
 sp1-core-executor = { workspace = true }
 sp1-curves = { workspace = true }
+vec_map = "0.8.2"

 [dev-dependencies]
 tiny-keccak = { version = "2.0.2", features = ["keccak"] }
diff --git a/crates/core/machine/src/alu/add_sub/mod.rs b/crates/core/machine/src/alu/add_sub/mod.rs
index bf6dabef42..d276820755 100644
--- a/crates/core/machine/src/alu/add_sub/mod.rs
+++ b/crates/core/machine/src/alu/add_sub/mod.rs
@@ -8,7 +8,7 @@
 use itertools::Itertools;
 use p3_air::{Air, AirBuilder, BaseAir};
 use p3_field::{AbstractField, PrimeField};
 use p3_matrix::{dense::RowMajorMatrix, Matrix};
-use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator, ParallelSlice};
+use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator};
 use sp1_core_executor::{
     events::{AluEvent, ByteLookupEvent, ByteRecord},
     ExecutionRecord, Opcode, Program,
@@ -19,7 +19,10 @@ use sp1_stark::{
     Word,
 };

-use crate::{operations::AddOperation, utils::pad_rows_fixed};
+use crate::{
+    operations::AddOperation,
+    utils::{next_power_of_two, zeroed_f_vec},
+};

 /// The number of main trace columns for `AddSubChip`.
 pub const NUM_ADD_SUB_COLS: usize = size_of::<AddSubCols<u8>>();
@@ -79,46 +82,29 @@ impl<F: PrimeField> MachineAir<F> for AddSubChip {
             std::cmp::max((input.add_events.len() + input.sub_events.len()) / num_cpus::get(), 1);
         let merged_events =
             input.add_events.iter().chain(input.sub_events.iter()).collect::<Vec<_>>();
-
-        let row_batches = merged_events
-            .par_chunks(chunk_size)
-            .map(|events| {
-                let rows = events
-                    .iter()
-                    .map(|event| {
-                        let mut row = [F::zero(); NUM_ADD_SUB_COLS];
-                        let cols: &mut AddSubCols<F> = row.as_mut_slice().borrow_mut();
-                        let mut blu = Vec::new();
-                        self.event_to_row(event, cols, &mut blu);
-                        row
-                    })
-                    .collect::<Vec<_>>();
-                rows
-            })
-            .collect::<Vec<_>>();
-
-        let mut rows: Vec<[F; NUM_ADD_SUB_COLS]> = vec![];
-        for row_batch in row_batches {
-            rows.extend(row_batch);
-        }
-
-        pad_rows_fixed(
-            &mut rows,
-            || [F::zero(); NUM_ADD_SUB_COLS],
-            input.fixed_log2_rows::<F, _>(self),
+        let nb_rows = merged_events.len();
+        let size_log2 = input.fixed_log2_rows::<F, _>(self);
+        let padded_nb_rows = next_power_of_two(nb_rows, size_log2);
+        let mut values = zeroed_f_vec(padded_nb_rows * NUM_ADD_SUB_COLS);
+
+        values.chunks_mut(chunk_size * NUM_ADD_SUB_COLS).enumerate().par_bridge().for_each(
+            |(i, rows)| {
+                rows.chunks_mut(NUM_ADD_SUB_COLS).enumerate().for_each(|(j, row)| {
+                    let idx = i * chunk_size + j;
+                    let cols: &mut AddSubCols<F> = row.borrow_mut();
+
+                    if idx < merged_events.len() {
+                        let mut byte_lookup_events = Vec::new();
+                        let event = &merged_events[idx];
+                        self.event_to_row(event, cols, &mut byte_lookup_events);
+                    }
+                    cols.nonce = F::from_canonical_usize(idx);
+                });
+            },
         );

-        // Convert the trace to a row major matrix.
-        let mut trace =
-            RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_ADD_SUB_COLS);
-
-        // Write the nonces to the trace.
-        for i in 0..trace.height() {
-            let cols: &mut AddSubCols<F> =
-                trace.values[i * NUM_ADD_SUB_COLS..(i + 1) * NUM_ADD_SUB_COLS].borrow_mut();
-            cols.nonce = F::from_canonical_usize(i);
-        }
-
-        trace
+        // Convert the trace to a row major matrix.
+        RowMajorMatrix::new(values, NUM_ADD_SUB_COLS)
     }

     fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
diff --git a/crates/core/machine/src/alu/divrem/mod.rs b/crates/core/machine/src/alu/divrem/mod.rs
index 9170b87b5f..1d4d539fc4 100644
--- a/crates/core/machine/src/alu/divrem/mod.rs
+++ b/crates/core/machine/src/alu/divrem/mod.rs
@@ -279,11 +279,19 @@ impl<F: PrimeField32> MachineAir<F> for DivRemChip {
                 // Set the `alu_event` flags.
                 cols.abs_c_alu_event = cols.c_neg * cols.is_real;
                 cols.abs_c_alu_event_nonce = F::from_canonical_u32(
-                    input.nonce_lookup.get(&event.sub_lookups[4]).copied().unwrap_or_default(),
+                    input
+                        .nonce_lookup
+                        .get(event.sub_lookups[4].0 as usize)
+                        .copied()
+                        .unwrap_or_default(),
                 );
                 cols.abs_rem_alu_event = cols.rem_neg * cols.is_real;
                 cols.abs_rem_alu_event_nonce = F::from_canonical_u32(
-                    input.nonce_lookup.get(&event.sub_lookups[5]).copied().unwrap_or_default(),
+                    input
+                        .nonce_lookup
+                        .get(event.sub_lookups[5].0 as usize)
+                        .copied()
+                        .unwrap_or_default(),
                 );

                 // Insert the MSB lookup events.
@@ -344,16 +352,24 @@ impl<F: PrimeField32> MachineAir<F> for DivRemChip {
                 // Insert the necessary multiplication & LT events.
                 {
                     cols.lower_nonce = F::from_canonical_u32(
-                        input.nonce_lookup.get(&event.sub_lookups[0]).copied().unwrap_or_default(),
+                        input
+                            .nonce_lookup
+                            .get(event.sub_lookups[0].0 as usize)
+                            .copied()
+                            .unwrap_or_default(),
                    );
                     cols.upper_nonce = F::from_canonical_u32(
-                        input.nonce_lookup.get(&event.sub_lookups[1]).copied().unwrap_or_default(),
+                        input
+                            .nonce_lookup
+                            .get(event.sub_lookups[1].0 as usize)
+                            .copied()
+                            .unwrap_or_default(),
                     );
                     if is_signed_operation(event.opcode) {
                         cols.abs_nonce = F::from_canonical_u32(
                             input
                                 .nonce_lookup
-                                .get(&event.sub_lookups[2])
+                                .get(event.sub_lookups[2].0 as usize)
                                 .copied()
                                 .unwrap_or_default(),
                         );
                     } else {
                         cols.abs_nonce = F::from_canonical_u32(
                             input
                                 .nonce_lookup
-                                .get(&event.sub_lookups[3])
+                                .get(event.sub_lookups[3].0 as usize)
                                 .copied()
                                 .unwrap_or_default(),
                         );
                     }
diff --git a/crates/core/machine/src/alu/lt/mod.rs b/crates/core/machine/src/alu/lt/mod.rs
index 211cf5d912..876fdaaf8f 100644
--- a/crates/core/machine/src/alu/lt/mod.rs
+++ b/crates/core/machine/src/alu/lt/mod.rs
@@ -19,7 +19,7 @@ use sp1_stark::{
     Word,
 };

-use crate::utils::pad_rows_fixed;
+use crate::utils::{next_power_of_two, zeroed_f_vec};

 /// The number of main trace columns for `LtChip`.
 pub const NUM_LT_COLS: usize = size_of::<LtCols<u8>>();
@@ -107,38 +107,31 @@ impl<F: PrimeField32> MachineAir<F> for LtChip {
         _: &mut ExecutionRecord,
     ) -> RowMajorMatrix<F> {
         // Generate the trace rows for each event.
-        let mut rows = input
-            .lt_events
-            .par_iter()
-            .map(|event| {
-                let mut row = [F::zero(); NUM_LT_COLS];
-                let mut new_byte_lookup_events: Vec<ByteLookupEvent> = Vec::new();
-                let cols: &mut LtCols<F> = row.as_mut_slice().borrow_mut();
-                self.event_to_row(event, cols, &mut new_byte_lookup_events);
-
-                row
-            })
-            .collect::<Vec<_>>();
-
-        // Pad the trace to a power of two depending on the proof shape in `input`.
- pad_rows_fixed( - &mut rows, - || [F::zero(); NUM_LT_COLS], - input.fixed_log2_rows::(self), + let nb_rows = input.lt_events.len(); + let size_log2 = input.fixed_log2_rows::(self); + let padded_nb_rows = next_power_of_two(nb_rows, size_log2); + let mut values = zeroed_f_vec(padded_nb_rows * NUM_LT_COLS); + let chunk_size = std::cmp::max((nb_rows + 1) / num_cpus::get(), 1); + + values.chunks_mut(chunk_size * NUM_LT_COLS).enumerate().par_bridge().for_each( + |(i, rows)| { + rows.chunks_mut(NUM_LT_COLS).enumerate().for_each(|(j, row)| { + let idx = i * chunk_size + j; + let cols: &mut LtCols = row.borrow_mut(); + + if idx < nb_rows { + let mut byte_lookup_events = Vec::new(); + let event = &input.lt_events[idx]; + self.event_to_row(event, cols, &mut byte_lookup_events); + } + cols.nonce = F::from_canonical_usize(idx); + }); + }, ); // Convert the trace to a row major matrix. - let mut trace = - RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_LT_COLS); - - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut LtCols = - trace.values[i * NUM_LT_COLS..(i + 1) * NUM_LT_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - trace + RowMajorMatrix::new(values, NUM_LT_COLS) } fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) { diff --git a/crates/core/machine/src/alu/mul/mod.rs b/crates/core/machine/src/alu/mul/mod.rs index 0453cb5f87..6a1ce272fe 100644 --- a/crates/core/machine/src/alu/mul/mod.rs +++ b/crates/core/machine/src/alu/mul/mod.rs @@ -35,19 +35,24 @@ use core::{ mem::size_of, }; +use hashbrown::HashMap; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{AbstractField, PrimeField}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; -use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}; +use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator, ParallelSlice}; use sp1_core_executor::{ - events::{ByteLookupEvent, ByteRecord}, + events::{AluEvent, ByteLookupEvent, ByteRecord}, ByteOpcode, ExecutionRecord, Opcode, Program, }; use sp1_derive::AlignedBorrow; use sp1_primitives::consts::WORD_SIZE; -use sp1_stark::{air::MachineAir, MachineRecord, Word}; +use sp1_stark::{air::MachineAir, Word}; -use crate::{air::SP1CoreAirBuilder, alu::mul::utils::get_msb, utils::pad_rows_fixed}; +use crate::{ + air::SP1CoreAirBuilder, + alu::mul::utils::get_msb, + utils::{next_power_of_two, zeroed_f_vec}, +}; /// The number of main trace columns for `MulChip`. pub const NUM_MUL_COLS: usize = size_of::>(); @@ -131,148 +136,54 @@ impl MachineAir for MulChip { fn generate_trace( &self, input: &ExecutionRecord, - output: &mut ExecutionRecord, + _: &mut ExecutionRecord, ) -> RowMajorMatrix { - let mul_events = input.mul_events.clone(); - // Compute the chunk size based on the number of events and the number of CPUs. - let chunk_size = std::cmp::max(mul_events.len() / num_cpus::get(), 1); + // Generate the trace rows for each event. 
+ let nb_rows = input.mul_events.len(); + let size_log2 = input.fixed_log2_rows::(self); + let padded_nb_rows = next_power_of_two(nb_rows, size_log2); + let mut values = zeroed_f_vec(padded_nb_rows * NUM_MUL_COLS); + let chunk_size = std::cmp::max((nb_rows + 1) / num_cpus::get(), 1); + + values.chunks_mut(chunk_size * NUM_MUL_COLS).enumerate().par_bridge().for_each( + |(i, rows)| { + rows.chunks_mut(NUM_MUL_COLS).enumerate().for_each(|(j, row)| { + let idx = i * chunk_size + j; + let cols: &mut MulCols = row.borrow_mut(); + + if idx < nb_rows { + let mut byte_lookup_events = Vec::new(); + let event = &input.mul_events[idx]; + self.event_to_row(event, cols, &mut byte_lookup_events); + } + cols.nonce = F::from_canonical_usize(idx); + }); + }, + ); - // Generate the trace rows & corresponding records for each chunk of events in parallel. - let rows_and_records = mul_events + // Convert the trace to a row major matrix. + + RowMajorMatrix::new(values, NUM_MUL_COLS) + } + + fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) { + let chunk_size = std::cmp::max(input.mul_events.len() / num_cpus::get(), 1); + + let blu_batches = input + .mul_events .par_chunks(chunk_size) .map(|events| { - let mut record = ExecutionRecord::default(); - let rows = events - .iter() - .map(|event| { - // Ensure that the opcode is MUL, MULHU, MULH, or MULHSU. - assert!( - event.opcode == Opcode::MUL - || event.opcode == Opcode::MULHU - || event.opcode == Opcode::MULH - || event.opcode == Opcode::MULHSU - ); - let mut row = [F::zero(); NUM_MUL_COLS]; - let cols: &mut MulCols = row.as_mut_slice().borrow_mut(); - - let a_word = event.a.to_le_bytes(); - let b_word = event.b.to_le_bytes(); - let c_word = event.c.to_le_bytes(); - - let mut b = b_word.to_vec(); - let mut c = c_word.to_vec(); - - // Handle b and c's signs. - { - let b_msb = get_msb(b_word); - cols.b_msb = F::from_canonical_u8(b_msb); - let c_msb = get_msb(c_word); - cols.c_msb = F::from_canonical_u8(c_msb); - - // If b is signed and it is negative, sign extend b. - if (event.opcode == Opcode::MULH || event.opcode == Opcode::MULHSU) - && b_msb == 1 - { - cols.b_sign_extend = F::one(); - b.resize(PRODUCT_SIZE, BYTE_MASK); - } - - // If c is signed and it is negative, sign extend c. - if event.opcode == Opcode::MULH && c_msb == 1 { - cols.c_sign_extend = F::one(); - c.resize(PRODUCT_SIZE, BYTE_MASK); - } - - // Insert the MSB lookup events. - { - let words = [b_word, c_word]; - let mut blu_events: Vec = vec![]; - for word in words.iter() { - let most_significant_byte = word[WORD_SIZE - 1]; - blu_events.push(ByteLookupEvent { - shard: event.shard, - opcode: ByteOpcode::MSB, - a1: get_msb(*word) as u16, - a2: 0, - b: most_significant_byte, - c: 0, - }); - } - record.add_byte_lookup_events(blu_events); - } - } - - let mut product = [0u32; PRODUCT_SIZE]; - for i in 0..b.len() { - for j in 0..c.len() { - if i + j < PRODUCT_SIZE { - product[i + j] += (b[i] as u32) * (c[j] as u32); - } - } - } - - // Calculate the correct product using the `product` array. We store the - // correct carry value for verification. 
- let base = (1 << BYTE_SIZE) as u32; - let mut carry = [0u32; PRODUCT_SIZE]; - for i in 0..PRODUCT_SIZE { - carry[i] = product[i] / base; - product[i] %= base; - if i + 1 < PRODUCT_SIZE { - product[i + 1] += carry[i]; - } - cols.carry[i] = F::from_canonical_u32(carry[i]); - } - - cols.product = product.map(F::from_canonical_u32); - cols.a = Word(a_word.map(F::from_canonical_u8)); - cols.b = Word(b_word.map(F::from_canonical_u8)); - cols.c = Word(c_word.map(F::from_canonical_u8)); - cols.is_real = F::one(); - cols.is_mul = F::from_bool(event.opcode == Opcode::MUL); - cols.is_mulh = F::from_bool(event.opcode == Opcode::MULH); - cols.is_mulhu = F::from_bool(event.opcode == Opcode::MULHU); - cols.is_mulhsu = F::from_bool(event.opcode == Opcode::MULHSU); - cols.shard = F::from_canonical_u32(event.shard); - - // Range check. - { - record.add_u16_range_checks(event.shard, &carry.map(|x| x as u16)); - record.add_u8_range_checks(event.shard, &product.map(|x| x as u8)); - } - row - }) - .collect::>(); - (rows, record) + let mut blu: HashMap> = HashMap::new(); + events.iter().for_each(|event| { + let mut row = [F::zero(); NUM_MUL_COLS]; + let cols: &mut MulCols = row.as_mut_slice().borrow_mut(); + self.event_to_row(event, cols, &mut blu); + }); + blu }) .collect::>(); - // Generate the trace rows for each event. - let mut rows: Vec<[F; NUM_MUL_COLS]> = vec![]; - for mut row_and_record in rows_and_records { - rows.extend(row_and_record.0); - output.append(&mut row_and_record.1); - } - - // Pad the trace to a power of two depending on the proof shape in `input`. - pad_rows_fixed( - &mut rows, - || [F::zero(); NUM_MUL_COLS], - input.fixed_log2_rows::(self), - ); - - // Convert the trace to a row major matrix. - let mut trace = - RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_MUL_COLS); - - // Write the nonces to the trace. - for i in 0..trace.height() { - let cols: &mut MulCols = - trace.values[i * NUM_MUL_COLS..(i + 1) * NUM_MUL_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - - trace + output.add_sharded_byte_lookup_events(blu_batches.iter().collect::>()); } fn included(&self, shard: &Self::Record) -> bool { @@ -284,6 +195,100 @@ impl MachineAir for MulChip { } } +impl MulChip { + /// Create a row from an event. + fn event_to_row( + &self, + event: &AluEvent, + cols: &mut MulCols, + blu: &mut impl ByteRecord, + ) { + let a_word = event.a.to_le_bytes(); + let b_word = event.b.to_le_bytes(); + let c_word = event.c.to_le_bytes(); + + let mut b = b_word.to_vec(); + let mut c = c_word.to_vec(); + + // Handle b and c's signs. + { + let b_msb = get_msb(b_word); + cols.b_msb = F::from_canonical_u8(b_msb); + let c_msb = get_msb(c_word); + cols.c_msb = F::from_canonical_u8(c_msb); + + // If b is signed and it is negative, sign extend b. + if (event.opcode == Opcode::MULH || event.opcode == Opcode::MULHSU) && b_msb == 1 { + cols.b_sign_extend = F::one(); + b.resize(PRODUCT_SIZE, BYTE_MASK); + } + + // If c is signed and it is negative, sign extend c. + if event.opcode == Opcode::MULH && c_msb == 1 { + cols.c_sign_extend = F::one(); + c.resize(PRODUCT_SIZE, BYTE_MASK); + } + + // Insert the MSB lookup events. 
+ { + let words = [b_word, c_word]; + let mut blu_events: Vec = vec![]; + for word in words.iter() { + let most_significant_byte = word[WORD_SIZE - 1]; + blu_events.push(ByteLookupEvent { + shard: event.shard, + opcode: ByteOpcode::MSB, + a1: get_msb(*word) as u16, + a2: 0, + b: most_significant_byte, + c: 0, + }); + } + blu.add_byte_lookup_events(blu_events); + } + } + + let mut product = [0u32; PRODUCT_SIZE]; + for i in 0..b.len() { + for j in 0..c.len() { + if i + j < PRODUCT_SIZE { + product[i + j] += (b[i] as u32) * (c[j] as u32); + } + } + } + + // Calculate the correct product using the `product` array. We store the + // correct carry value for verification. + let base = (1 << BYTE_SIZE) as u32; + let mut carry = [0u32; PRODUCT_SIZE]; + for i in 0..PRODUCT_SIZE { + carry[i] = product[i] / base; + product[i] %= base; + if i + 1 < PRODUCT_SIZE { + product[i + 1] += carry[i]; + } + cols.carry[i] = F::from_canonical_u32(carry[i]); + } + + cols.product = product.map(F::from_canonical_u32); + cols.a = Word(a_word.map(F::from_canonical_u8)); + cols.b = Word(b_word.map(F::from_canonical_u8)); + cols.c = Word(c_word.map(F::from_canonical_u8)); + cols.is_real = F::one(); + cols.is_mul = F::from_bool(event.opcode == Opcode::MUL); + cols.is_mulh = F::from_bool(event.opcode == Opcode::MULH); + cols.is_mulhu = F::from_bool(event.opcode == Opcode::MULHU); + cols.is_mulhsu = F::from_bool(event.opcode == Opcode::MULHSU); + cols.shard = F::from_canonical_u32(event.shard); + + // Range check. + { + blu.add_u16_range_checks(event.shard, &carry.map(|x| x as u16)); + blu.add_u8_range_checks(event.shard, &product.map(|x| x as u8)); + } + } +} + impl BaseAir for MulChip { fn width(&self) -> usize { NUM_MUL_COLS diff --git a/crates/core/machine/src/alu/sr/mod.rs b/crates/core/machine/src/alu/sr/mod.rs index 9c19b4491c..b26c949945 100644 --- a/crates/core/machine/src/alu/sr/mod.rs +++ b/crates/core/machine/src/alu/sr/mod.rs @@ -52,7 +52,7 @@ use itertools::Itertools; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::{AbstractField, PrimeField}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; -use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}; +use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator, ParallelSlice}; use sp1_core_executor::{ events::{AluEvent, ByteLookupEvent, ByteRecord}, ByteOpcode, ExecutionRecord, Opcode, Program, @@ -65,7 +65,7 @@ use crate::{ air::SP1CoreAirBuilder, alu::sr::utils::{nb_bits_to_shift, nb_bytes_to_shift}, bytes::utils::shr_carry, - utils::pad_rows_fixed, + utils::{next_power_of_two, zeroed_f_vec}, }; /// The number of main trace columns for `ShiftRightChip`. @@ -149,54 +149,33 @@ impl MachineAir for ShiftRightChip { _: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. - let mut rows: Vec<[F; NUM_SHIFT_RIGHT_COLS]> = Vec::new(); - let sr_events = input.shift_right_events.clone(); - for event in sr_events.iter() { - assert!(event.opcode == Opcode::SRL || event.opcode == Opcode::SRA); - let mut row = [F::zero(); NUM_SHIFT_RIGHT_COLS]; - let cols: &mut ShiftRightCols = row.as_mut_slice().borrow_mut(); - let mut blu = Vec::new(); - self.event_to_row(event, cols, &mut blu); - rows.push(row); - } - - // Pad the trace to a power of two depending on the proof shape in `input`. 
- pad_rows_fixed( - &mut rows, - || [F::zero(); NUM_SHIFT_RIGHT_COLS], - input.fixed_log2_rows::(self), + let nb_rows = input.shift_right_events.len(); + let size_log2 = input.fixed_log2_rows::(self); + let padded_nb_rows = next_power_of_two(nb_rows, size_log2); + let mut values = zeroed_f_vec(padded_nb_rows * NUM_SHIFT_RIGHT_COLS); + let chunk_size = std::cmp::max((nb_rows + 1) / num_cpus::get(), 1); + + values.chunks_mut(chunk_size * NUM_SHIFT_RIGHT_COLS).enumerate().par_bridge().for_each( + |(i, rows)| { + rows.chunks_mut(NUM_SHIFT_RIGHT_COLS).enumerate().for_each(|(j, row)| { + let idx = i * chunk_size + j; + let cols: &mut ShiftRightCols = row.borrow_mut(); + + if idx < nb_rows { + let mut byte_lookup_events = Vec::new(); + let event = &input.shift_right_events[idx]; + self.event_to_row(event, cols, &mut byte_lookup_events); + } else { + cols.shift_by_n_bits[0] = F::one(); + cols.shift_by_n_bytes[0] = F::one(); + } + cols.nonce = F::from_canonical_usize(idx); + }); + }, ); // Convert the trace to a row major matrix. - let mut trace = RowMajorMatrix::new( - rows.into_iter().flatten().collect::>(), - NUM_SHIFT_RIGHT_COLS, - ); - - // Create the template for the padded rows. These are fake rows that don't fail on some - // sanity checks. - let padded_row_template = { - let mut row = [F::zero(); NUM_SHIFT_RIGHT_COLS]; - let cols: &mut ShiftRightCols = row.as_mut_slice().borrow_mut(); - // Shift 0 by 0 bits and 0 bytes. - // cols.is_srl = F::one(); - cols.shift_by_n_bits[0] = F::one(); - cols.shift_by_n_bytes[0] = F::one(); - row - }; - debug_assert!(padded_row_template.len() == NUM_SHIFT_RIGHT_COLS); - for i in input.shift_right_events.len() * NUM_SHIFT_RIGHT_COLS..trace.values.len() { - trace.values[i] = padded_row_template[i % NUM_SHIFT_RIGHT_COLS]; - } - - // Write the nonces to the trace. 
- for i in 0..trace.height() { - let cols: &mut ShiftRightCols = - trace.values[i * NUM_SHIFT_RIGHT_COLS..(i + 1) * NUM_SHIFT_RIGHT_COLS].borrow_mut(); - cols.nonce = F::from_canonical_usize(i); - } - - trace + RowMajorMatrix::new(values, NUM_SHIFT_RIGHT_COLS) } fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) { diff --git a/crates/core/machine/src/cpu/columns/instruction.rs b/crates/core/machine/src/cpu/columns/instruction.rs index a16de4fb08..27dd1f6a91 100644 --- a/crates/core/machine/src/cpu/columns/instruction.rs +++ b/crates/core/machine/src/cpu/columns/instruction.rs @@ -27,13 +27,13 @@ pub struct InstructionCols { } impl InstructionCols { - pub fn populate(&mut self, instruction: Instruction) { + pub fn populate(&mut self, instruction: &Instruction) { self.opcode = instruction.opcode.as_field::(); - self.op_a = instruction.op_a.into(); + self.op_a = (instruction.op_a as u32).into(); self.op_b = instruction.op_b.into(); self.op_c = instruction.op_c.into(); - self.op_a_0 = F::from_bool(instruction.op_a == Register::X0 as u32); + self.op_a_0 = F::from_bool(instruction.op_a == Register::X0 as u8); } } diff --git a/crates/core/machine/src/cpu/columns/opcode.rs b/crates/core/machine/src/cpu/columns/opcode.rs index 9b4344d036..4de8f11ba7 100644 --- a/crates/core/machine/src/cpu/columns/opcode.rs +++ b/crates/core/machine/src/cpu/columns/opcode.rs @@ -63,7 +63,7 @@ pub struct OpcodeSelectorCols { } impl OpcodeSelectorCols { - pub fn populate(&mut self, instruction: Instruction) { + pub fn populate(&mut self, instruction: &Instruction) { self.imm_b = F::from_bool(instruction.imm_b); self.imm_c = F::from_bool(instruction.imm_c); diff --git a/crates/core/machine/src/cpu/trace.rs b/crates/core/machine/src/cpu/trace.rs index 01b1489127..c9831c66c0 100644 --- a/crates/core/machine/src/cpu/trace.rs +++ b/crates/core/machine/src/cpu/trace.rs @@ -1,10 +1,10 @@ use hashbrown::HashMap; use itertools::Itertools; use sp1_core_executor::{ - events::{ByteLookupEvent, ByteRecord, CpuEvent, LookupId, MemoryRecordEnum}, + events::{ByteLookupEvent, ByteRecord, CpuEvent, MemoryRecordEnum}, syscalls::SyscallCode, ByteOpcode::{self, U16Range}, - CoreShape, ExecutionRecord, Opcode, Program, + ExecutionRecord, Instruction, Opcode, Program, Register::X0, }; use sp1_primitives::consts::WORD_SIZE; @@ -13,15 +13,10 @@ use std::{array, borrow::BorrowMut}; use p3_field::{PrimeField, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; -use p3_maybe_rayon::prelude::{ - IntoParallelRefMutIterator, ParallelBridge, ParallelIterator, ParallelSlice, -}; +use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator, ParallelSlice}; use tracing::instrument; -use super::{ - columns::{CPU_COL_MAP, NUM_CPU_COLS}, - CpuChip, -}; +use super::{columns::NUM_CPU_COLS, CpuChip}; use crate::{cpu::columns::CpuCols, memory::MemoryCols, utils::zeroed_f_vec}; impl MachineAir for CpuChip { @@ -38,7 +33,16 @@ impl MachineAir for CpuChip { input: &ExecutionRecord, _: &mut ExecutionRecord, ) -> RowMajorMatrix { - let mut values = zeroed_f_vec(input.cpu_events.len() * NUM_CPU_COLS); + let n_real_rows = input.cpu_events.len(); + let padded_nb_rows = if let Some(shape) = &input.shape { + 1 << shape.inner[&MachineAir::::name(self)] + } else if n_real_rows < 16 { + 16 + } else { + n_real_rows.next_power_of_two() + }; + let mut values = zeroed_f_vec(padded_nb_rows * NUM_CPU_COLS); + let shard = input.public_values.shard; let chunk_size = std::cmp::max(input.cpu_events.len() / num_cpus::get(), 1); 
values.chunks_mut(chunk_size * NUM_CPU_COLS).enumerate().par_bridge().for_each( @@ -46,30 +50,36 @@ impl MachineAir for CpuChip { rows.chunks_mut(NUM_CPU_COLS).enumerate().for_each(|(j, row)| { let idx = i * chunk_size + j; let cols: &mut CpuCols = row.borrow_mut(); - let mut byte_lookup_events = Vec::new(); - self.event_to_row( - &input.cpu_events[idx], - &input.nonce_lookup, - cols, - &mut byte_lookup_events, - ); + + if idx >= input.cpu_events.len() { + cols.selectors.imm_b = F::one(); + cols.selectors.imm_c = F::one(); + } else { + let mut byte_lookup_events = Vec::new(); + let event = &input.cpu_events[idx]; + let instruction = &input.program.fetch(event.pc); + self.event_to_row( + event, + &input.nonce_lookup, + cols, + &mut byte_lookup_events, + shard, + instruction, + ); + } }); }, ); // Convert the trace to a row major matrix. - let mut trace = RowMajorMatrix::new(values, NUM_CPU_COLS); - - // Pad the trace to a power of two. - Self::pad_to_power_of_two::(self, &input.shape, &mut trace.values); - - trace + RowMajorMatrix::new(values, NUM_CPU_COLS) } #[instrument(name = "generate cpu dependencies", level = "debug", skip_all)] fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) { // Generate the trace rows for each event. let chunk_size = std::cmp::max(input.cpu_events.len() / num_cpus::get(), 1); + let shard = input.public_values.shard; let blu_events: Vec<_> = input .cpu_events @@ -80,7 +90,15 @@ impl MachineAir for CpuChip { ops.iter().for_each(|op| { let mut row = [F::zero(); NUM_CPU_COLS]; let cols: &mut CpuCols = row.as_mut_slice().borrow_mut(); - self.event_to_row::(op, &HashMap::new(), cols, &mut blu); + let instruction = &input.program.fetch(op.pc); + self.event_to_row::( + op, + &input.nonce_lookup, + cols, + &mut blu, + shard, + instruction, + ); }); blu }) @@ -103,23 +121,25 @@ impl CpuChip { fn event_to_row( &self, event: &CpuEvent, - nonce_lookup: &HashMap, + nonce_lookup: &[u32], cols: &mut CpuCols, blu_events: &mut impl ByteRecord, + shard: u32, + instruction: &Instruction, ) { // Populate shard and clk columns. - self.populate_shard_clk(cols, event, blu_events); + self.populate_shard_clk(cols, event, blu_events, shard); // Populate the nonce. cols.nonce = F::from_canonical_u32( - nonce_lookup.get(&event.alu_lookup_id).copied().unwrap_or_default(), + nonce_lookup.get(event.alu_lookup_id.0 as usize).copied().unwrap_or_default(), ); // Populate basic fields. cols.pc = F::from_canonical_u32(event.pc); cols.next_pc = F::from_canonical_u32(event.next_pc); - cols.instruction.populate(event.instruction); - cols.selectors.populate(event.instruction); + cols.instruction.populate(instruction); + cols.selectors.populate(instruction); *cols.op_a_access.value_mut() = event.a.into(); *cols.op_b_access.value_mut() = event.b.into(); *cols.op_c_access.value_mut() = event.c.into(); @@ -145,7 +165,7 @@ impl CpuChip { .map(|x| x.as_canonical_u32()) .collect::>(); blu_events.add_byte_lookup_event(ByteLookupEvent { - shard: event.shard, + shard, opcode: ByteOpcode::U8Range, a1: 0, a2: 0, @@ -153,7 +173,7 @@ impl CpuChip { c: a_bytes[1] as u8, }); blu_events.add_byte_lookup_event(ByteLookupEvent { - shard: event.shard, + shard, opcode: ByteOpcode::U8Range, a1: 0, a2: 0, @@ -162,23 +182,20 @@ impl CpuChip { }); // Populate memory accesses for reading from memory. 
- assert_eq!(event.memory_record.is_some(), event.memory.is_some()); let memory_columns = cols.opcode_specific_columns.memory_mut(); if let Some(record) = event.memory_record { memory_columns.memory_access.populate(record, blu_events) } // Populate memory, branch, jump, and auipc specific fields. - self.populate_memory(cols, event, blu_events, nonce_lookup); - self.populate_branch(cols, event, nonce_lookup); - self.populate_jump(cols, event, nonce_lookup); - self.populate_auipc(cols, event, nonce_lookup); + self.populate_memory(cols, event, blu_events, nonce_lookup, shard, instruction); + self.populate_branch(cols, event, nonce_lookup, instruction); + self.populate_jump(cols, event, nonce_lookup, instruction); + self.populate_auipc(cols, event, nonce_lookup, instruction); let is_halt = self.populate_ecall(cols, event, nonce_lookup); cols.is_sequential_instr = F::from_bool( - !event.instruction.is_branch_instruction() - && !event.instruction.is_jump_instruction() - && !is_halt, + !instruction.is_branch_instruction() && !instruction.is_jump_instruction() && !is_halt, ); // Assert that the instruction is not a no-op. @@ -191,8 +208,9 @@ impl CpuChip { cols: &mut CpuCols, event: &CpuEvent, blu_events: &mut impl ByteRecord, + shard: u32, ) { - cols.shard = F::from_canonical_u32(event.shard); + cols.shard = F::from_canonical_u32(shard); cols.clk = F::from_canonical_u32(event.clk); let clk_16bit_limb = (event.clk & 0xffff) as u16; @@ -201,15 +219,15 @@ impl CpuChip { cols.clk_8bit_limb = F::from_canonical_u8(clk_8bit_limb); blu_events.add_byte_lookup_event(ByteLookupEvent::new( - event.shard, + shard, U16Range, - event.shard as u16, + shard as u16, 0, 0, 0, )); blu_events.add_byte_lookup_event(ByteLookupEvent::new( - event.shard, + shard, U16Range, clk_16bit_limb, 0, @@ -217,7 +235,7 @@ impl CpuChip { 0, )); blu_events.add_byte_lookup_event(ByteLookupEvent::new( - event.shard, + shard, ByteOpcode::U8Range, 0, 0, @@ -232,10 +250,12 @@ impl CpuChip { cols: &mut CpuCols, event: &CpuEvent, blu_events: &mut impl ByteRecord, - nonce_lookup: &HashMap, + nonce_lookup: &[u32], + shard: u32, + instruction: &Instruction, ) { if !matches!( - event.instruction.opcode, + instruction.opcode, Opcode::LB | Opcode::LH | Opcode::LW @@ -262,7 +282,7 @@ impl CpuChip { let bits: [bool; 8] = array::from_fn(|i| aligned_addr_ls_byte & (1 << i) != 0); memory_columns.aa_least_sig_byte_decomp = array::from_fn(|i| F::from_bool(bits[i + 2])); memory_columns.addr_word_nonce = F::from_canonical_u32( - nonce_lookup.get(&event.memory_add_lookup_id).copied().unwrap_or_default(), + nonce_lookup.get(event.memory_add_lookup_id.0 as usize).copied().unwrap_or_default(), ); // Populate memory offsets. @@ -275,10 +295,10 @@ impl CpuChip { // If it is a load instruction, set the unsigned_mem_val column. let mem_value = event.memory_record.unwrap().value(); if matches!( - event.instruction.opcode, + instruction.opcode, Opcode::LB | Opcode::LBU | Opcode::LH | Opcode::LHU | Opcode::LW ) { - match event.instruction.opcode { + match instruction.opcode { Opcode::LB | Opcode::LBU => { cols.unsigned_mem_val = (mem_value.to_le_bytes()[addr_offset as usize] as u32).into(); @@ -298,8 +318,8 @@ impl CpuChip { } // For the signed load instructions, we need to check if the loaded value is negative. 
- if matches!(event.instruction.opcode, Opcode::LB | Opcode::LH) { - let most_sig_mem_value_byte = if matches!(event.instruction.opcode, Opcode::LB) { + if matches!(instruction.opcode, Opcode::LB | Opcode::LH) { + let most_sig_mem_value_byte = if matches!(instruction.opcode, Opcode::LB) { cols.unsigned_mem_val.to_u32().to_le_bytes()[0] } else { cols.unsigned_mem_val.to_u32().to_le_bytes()[1] @@ -310,20 +330,22 @@ impl CpuChip { F::from_canonical_u8(most_sig_mem_value_byte >> i & 0x01); } if memory_columns.most_sig_byte_decomp[7] == F::one() { - cols.mem_value_is_neg_not_x0 = - F::from_bool(event.instruction.op_a != (X0 as u32)); + cols.mem_value_is_neg_not_x0 = F::from_bool(instruction.op_a != (X0 as u8)); cols.unsigned_mem_val_nonce = F::from_canonical_u32( - nonce_lookup.get(&event.memory_sub_lookup_id).copied().unwrap_or_default(), + nonce_lookup + .get(event.memory_sub_lookup_id.0 as usize) + .copied() + .unwrap_or_default(), ); } } // Set the `mem_value_is_pos_not_x0` composite flag. cols.mem_value_is_pos_not_x0 = F::from_bool( - ((matches!(event.instruction.opcode, Opcode::LB | Opcode::LH) + ((matches!(instruction.opcode, Opcode::LB | Opcode::LH) && (memory_columns.most_sig_byte_decomp[7] == F::zero())) - || matches!(event.instruction.opcode, Opcode::LBU | Opcode::LHU | Opcode::LW)) - && event.instruction.op_a != (X0 as u32), + || matches!(instruction.opcode, Opcode::LBU | Opcode::LHU | Opcode::LW)) + && instruction.op_a != (X0 as u8), ); } @@ -331,7 +353,7 @@ impl CpuChip { let addr_bytes = memory_addr.to_le_bytes(); for byte_pair in addr_bytes.chunks_exact(2) { blu_events.add_byte_lookup_event(ByteLookupEvent { - shard: event.shard, + shard, opcode: ByteOpcode::U8Range, a1: 0, a2: 0, @@ -346,15 +368,15 @@ impl CpuChip { &self, cols: &mut CpuCols, event: &CpuEvent, - nonce_lookup: &HashMap, + nonce_lookup: &[u32], + instruction: &Instruction, ) { - if event.instruction.is_branch_instruction() { + if instruction.is_branch_instruction() { let branch_columns = cols.opcode_specific_columns.branch_mut(); let a_eq_b = event.a == event.b; - let use_signed_comparison = - matches!(event.instruction.opcode, Opcode::BLT | Opcode::BGE); + let use_signed_comparison = matches!(instruction.opcode, Opcode::BLT | Opcode::BGE); let a_lt_b = if use_signed_comparison { (event.a as i32) < (event.b as i32) @@ -368,18 +390,18 @@ impl CpuChip { }; branch_columns.a_lt_b_nonce = F::from_canonical_u32( - nonce_lookup.get(&event.branch_lt_lookup_id).copied().unwrap_or_default(), + nonce_lookup.get(event.branch_lt_lookup_id.0 as usize).copied().unwrap_or_default(), ); branch_columns.a_gt_b_nonce = F::from_canonical_u32( - nonce_lookup.get(&event.branch_gt_lookup_id).copied().unwrap_or_default(), + nonce_lookup.get(event.branch_gt_lookup_id.0 as usize).copied().unwrap_or_default(), ); branch_columns.a_eq_b = F::from_bool(a_eq_b); branch_columns.a_lt_b = F::from_bool(a_lt_b); branch_columns.a_gt_b = F::from_bool(a_gt_b); - let branching = match event.instruction.opcode { + let branching = match instruction.opcode { Opcode::BEQ => a_eq_b, Opcode::BNE => !a_eq_b, Opcode::BLT | Opcode::BLTU => a_lt_b, @@ -396,7 +418,10 @@ impl CpuChip { if branching { cols.branching = F::one(); branch_columns.next_pc_nonce = F::from_canonical_u32( - nonce_lookup.get(&event.branch_add_lookup_id).copied().unwrap_or_default(), + nonce_lookup + .get(event.branch_add_lookup_id.0 as usize) + .copied() + .unwrap_or_default(), ); } else { cols.not_branching = F::one(); @@ -409,12 +434,13 @@ impl CpuChip { &self, cols: &mut CpuCols, event: 
&CpuEvent, - nonce_lookup: &HashMap, + nonce_lookup: &[u32], + instruction: &Instruction, ) { - if event.instruction.is_jump_instruction() { + if instruction.is_jump_instruction() { let jump_columns = cols.opcode_specific_columns.jump_mut(); - match event.instruction.opcode { + match instruction.opcode { Opcode::JAL => { let next_pc = event.pc.wrapping_add(event.b); jump_columns.op_a_range_checker.populate(event.a); @@ -423,7 +449,10 @@ impl CpuChip { jump_columns.next_pc = Word::from(next_pc); jump_columns.next_pc_range_checker.populate(next_pc); jump_columns.jal_nonce = F::from_canonical_u32( - nonce_lookup.get(&event.jump_jal_lookup_id).copied().unwrap_or_default(), + nonce_lookup + .get(event.jump_jal_lookup_id.0 as usize) + .copied() + .unwrap_or_default(), ); } Opcode::JALR => { @@ -432,7 +461,10 @@ impl CpuChip { jump_columns.next_pc = Word::from(next_pc); jump_columns.next_pc_range_checker.populate(next_pc); jump_columns.jalr_nonce = F::from_canonical_u32( - nonce_lookup.get(&event.jump_jalr_lookup_id).copied().unwrap_or_default(), + nonce_lookup + .get(event.jump_jalr_lookup_id.0 as usize) + .copied() + .unwrap_or_default(), ); } _ => unreachable!(), @@ -445,15 +477,16 @@ impl CpuChip { &self, cols: &mut CpuCols, event: &CpuEvent, - nonce_lookup: &HashMap, + nonce_lookup: &[u32], + instruction: &Instruction, ) { - if matches!(event.instruction.opcode, Opcode::AUIPC) { + if matches!(instruction.opcode, Opcode::AUIPC) { let auipc_columns = cols.opcode_specific_columns.auipc_mut(); auipc_columns.pc = Word::from(event.pc); auipc_columns.pc_range_checker.populate(event.pc); auipc_columns.auipc_nonce = F::from_canonical_u32( - nonce_lookup.get(&event.auipc_lookup_id).copied().unwrap_or_default(), + nonce_lookup.get(event.auipc_lookup_id.0 as usize).copied().unwrap_or_default(), ); } } @@ -463,7 +496,7 @@ impl CpuChip { &self, cols: &mut CpuCols, event: &CpuEvent, - nonce_lookup: &HashMap, + nonce_lookup: &[u32], ) -> bool { let mut is_halt = false; @@ -516,9 +549,8 @@ impl CpuChip { } // Write the syscall nonce. 
- ecall_cols.syscall_nonce = F::from_canonical_u32( - nonce_lookup.get(&event.syscall_lookup_id).copied().unwrap_or_default(), - ); + ecall_cols.syscall_nonce = + F::from_canonical_u32(nonce_lookup[event.syscall_lookup_id.0 as usize]); is_halt = syscall_id == F::from_canonical_u32(SyscallCode::HALT.syscall_id()); @@ -540,29 +572,4 @@ impl CpuChip { is_halt } - - fn pad_to_power_of_two(&self, shape: &Option, values: &mut Vec) { - let n_real_rows = values.len() / NUM_CPU_COLS; - let padded_nb_rows = if let Some(shape) = shape { - 1 << shape.inner[&MachineAir::::name(self)] - } else if n_real_rows < 16 { - 16 - } else { - n_real_rows.next_power_of_two() - }; - values.resize(padded_nb_rows * NUM_CPU_COLS, F::zero()); - - // Interpret values as a slice of arrays of length `NUM_CPU_COLS` - let rows = unsafe { - core::slice::from_raw_parts_mut( - values.as_mut_ptr() as *mut [F; NUM_CPU_COLS], - values.len() / NUM_CPU_COLS, - ) - }; - - rows[n_real_rows..].par_iter_mut().for_each(|padded_row| { - padded_row[CPU_COL_MAP.selectors.imm_b] = F::one(); - padded_row[CPU_COL_MAP.selectors.imm_c] = F::one(); - }); - } } diff --git a/crates/core/machine/src/memory/local.rs b/crates/core/machine/src/memory/local.rs index bf109d028e..8be4377031 100644 --- a/crates/core/machine/src/memory/local.rs +++ b/crates/core/machine/src/memory/local.rs @@ -3,11 +3,11 @@ use std::{ mem::size_of, }; -use crate::utils::pad_rows_fixed; -use itertools::Itertools; +use crate::utils::{next_power_of_two, zeroed_f_vec}; use p3_air::{Air, BaseAir}; use p3_field::PrimeField32; use p3_matrix::{dense::RowMajorMatrix, Matrix}; +use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator}; use sp1_core_executor::{ExecutionRecord, Program}; use sp1_derive::AlignedBorrow; use sp1_stark::{ @@ -86,39 +86,45 @@ impl MachineAir for MemoryLocalChip { input: &ExecutionRecord, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let mut rows = Vec::<[F; NUM_MEMORY_LOCAL_INIT_COLS]>::new(); - - for local_mem_events in - &input.get_local_mem_events().chunks(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW) - { - let mut row = [F::zero(); NUM_MEMORY_LOCAL_INIT_COLS]; - let cols: &mut MemoryLocalCols = row.as_mut_slice().borrow_mut(); - - for (cols, event) in cols.memory_local_entries.iter_mut().zip(local_mem_events) { - cols.addr = F::from_canonical_u32(event.addr); - cols.initial_shard = F::from_canonical_u32(event.initial_mem_access.shard); - cols.final_shard = F::from_canonical_u32(event.final_mem_access.shard); - cols.initial_clk = F::from_canonical_u32(event.initial_mem_access.timestamp); - cols.final_clk = F::from_canonical_u32(event.final_mem_access.timestamp); - cols.initial_value = event.initial_mem_access.value.into(); - cols.final_value = event.final_mem_access.value.into(); - cols.is_real = F::one(); - } - - rows.push(row); - } - - // Pad the trace to a power of two depending on the proof shape in `input`. - pad_rows_fixed( - &mut rows, - || [F::zero(); NUM_MEMORY_LOCAL_INIT_COLS], - input.fixed_log2_rows::(self), - ); - - RowMajorMatrix::new( - rows.into_iter().flatten().collect::>(), - NUM_MEMORY_LOCAL_INIT_COLS, - ) + // Generate the trace rows for each event. 
+ let events = input.get_local_mem_events().collect::>(); + let nb_rows = (events.len() + 3) / 4; + let size_log2 = input.fixed_log2_rows::(self); + let padded_nb_rows = next_power_of_two(nb_rows, size_log2); + let mut values = zeroed_f_vec(padded_nb_rows * NUM_MEMORY_LOCAL_INIT_COLS); + let chunk_size = std::cmp::max((nb_rows + 1) / num_cpus::get(), 1); + + values + .chunks_mut(chunk_size * NUM_MEMORY_LOCAL_INIT_COLS) + .enumerate() + .par_bridge() + .for_each(|(i, rows)| { + rows.chunks_mut(NUM_MEMORY_LOCAL_INIT_COLS).enumerate().for_each(|(j, row)| { + let idx = (i * chunk_size + j) * NUM_LOCAL_MEMORY_ENTRIES_PER_ROW; + + let cols: &mut MemoryLocalCols = row.borrow_mut(); + for k in 0..NUM_LOCAL_MEMORY_ENTRIES_PER_ROW { + let cols = &mut cols.memory_local_entries[k]; + if idx + k < events.len() { + let event = &events[idx + k]; + cols.addr = F::from_canonical_u32(event.addr); + cols.initial_shard = + F::from_canonical_u32(event.initial_mem_access.shard); + cols.final_shard = F::from_canonical_u32(event.final_mem_access.shard); + cols.initial_clk = + F::from_canonical_u32(event.initial_mem_access.timestamp); + cols.final_clk = + F::from_canonical_u32(event.final_mem_access.timestamp); + cols.initial_value = event.initial_mem_access.value.into(); + cols.final_value = event.final_mem_access.value.into(); + cols.is_real = F::one(); + } + } + }); + }); + + // Convert the trace to a row major matrix. + RowMajorMatrix::new(values, NUM_MEMORY_LOCAL_INIT_COLS) } fn included(&self, shard: &Self::Record) -> bool { diff --git a/crates/core/machine/src/memory/program.rs b/crates/core/machine/src/memory/program.rs index 7b68661624..699e052c0d 100644 --- a/crates/core/machine/src/memory/program.rs +++ b/crates/core/machine/src/memory/program.rs @@ -7,6 +7,7 @@ use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder}; use p3_field::{AbstractField, PrimeField}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; +use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator}; use sp1_core_executor::{ExecutionRecord, Program}; use sp1_derive::AlignedBorrow; use sp1_stark::{ @@ -17,7 +18,10 @@ use sp1_stark::{ InteractionKind, Word, }; -use crate::{operations::IsZeroOperation, utils::pad_rows_fixed}; +use crate::{ + operations::IsZeroOperation, + utils::{next_power_of_two, pad_rows_fixed, zeroed_f_vec}, +}; pub const NUM_MEMORY_PROGRAM_PREPROCESSED_COLS: usize = size_of::>(); @@ -71,35 +75,36 @@ impl MachineAir for MemoryProgramChip { } fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option> { - let program_memory = &program.memory_image; - // Note that BTreeMap is guaranteed to be sorted by key. This makes the row order - // deterministic. - let mut rows = program_memory - .iter() - .sorted() - .map(|(&addr, &word)| { - let mut row = [F::zero(); NUM_MEMORY_PROGRAM_PREPROCESSED_COLS]; - let cols: &mut MemoryProgramPreprocessedCols = row.as_mut_slice().borrow_mut(); - cols.addr = F::from_canonical_u32(addr); - cols.value = Word::from(word); - cols.is_real = F::one(); - row - }) - .collect::>(); - - // Pad the trace to a power of two depending on the proof shape in `input`. - pad_rows_fixed( - &mut rows, - || [F::zero(); NUM_MEMORY_PROGRAM_PREPROCESSED_COLS], - program.fixed_log2_rows::(self), - ); + // Generate the trace rows for each event. 
+ let nb_rows = program.memory_image.len(); + let size_log2 = program.fixed_log2_rows::(self); + let padded_nb_rows = next_power_of_two(nb_rows, size_log2); + let mut values = zeroed_f_vec(padded_nb_rows * NUM_MEMORY_PROGRAM_PREPROCESSED_COLS); + let chunk_size = std::cmp::max((nb_rows + 1) / num_cpus::get(), 1); + + let memory = program.memory_image.iter().collect::>(); + values + .chunks_mut(chunk_size * NUM_MEMORY_PROGRAM_PREPROCESSED_COLS) + .enumerate() + .par_bridge() + .for_each(|(i, rows)| { + rows.chunks_mut(NUM_MEMORY_PROGRAM_PREPROCESSED_COLS).enumerate().for_each( + |(j, row)| { + let idx = i * chunk_size + j; + + if idx < nb_rows { + let (addr, word) = memory[idx]; + let cols: &mut MemoryProgramPreprocessedCols = row.borrow_mut(); + cols.addr = F::from_canonical_u32(*addr); + cols.value = Word::from(*word); + cols.is_real = F::one(); + } + }, + ); + }); // Convert the trace to a row major matrix. - let trace = RowMajorMatrix::new( - rows.into_iter().flatten().collect::>(), - NUM_MEMORY_PROGRAM_PREPROCESSED_COLS, - ); - Some(trace) + Some(RowMajorMatrix::new(values, NUM_MEMORY_PROGRAM_PREPROCESSED_COLS)) } fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) { diff --git a/crates/core/machine/src/program/mod.rs b/crates/core/machine/src/program/mod.rs index 39dfb6a765..f6c1f3bc0c 100644 --- a/crates/core/machine/src/program/mod.rs +++ b/crates/core/machine/src/program/mod.rs @@ -4,10 +4,14 @@ use core::{ }; use std::collections::HashMap; -use crate::{air::ProgramAirBuilder, utils::pad_rows_fixed}; +use crate::{ + air::ProgramAirBuilder, + utils::{next_power_of_two, pad_rows_fixed, zeroed_f_vec}, +}; use p3_air::{Air, BaseAir, PairBuilder}; use p3_field::PrimeField; use p3_matrix::{dense::RowMajorMatrix, Matrix}; +use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator}; use sp1_core_executor::{ExecutionRecord, Program}; use sp1_derive::AlignedBorrow; use sp1_stark::air::{MachineAir, SP1AirBuilder}; @@ -65,36 +69,34 @@ impl MachineAir for ProgramChip { !program.instructions.is_empty() || program.preprocessed_shape.is_some(), "empty program" ); - let mut rows = program - .instructions - .iter() + // Generate the trace rows for each event. + let nb_rows = program.instructions.len(); + let size_log2 = program.fixed_log2_rows::(self); + let padded_nb_rows = next_power_of_two(nb_rows, size_log2); + let mut values = zeroed_f_vec(padded_nb_rows * NUM_PROGRAM_PREPROCESSED_COLS); + let chunk_size = std::cmp::max((nb_rows + 1) / num_cpus::get(), 1); + + values + .chunks_mut(chunk_size * NUM_PROGRAM_PREPROCESSED_COLS) .enumerate() - .map(|(i, &instruction)| { - let pc = program.pc_base + (i as u32 * 4); - let mut row = [F::zero(); NUM_PROGRAM_PREPROCESSED_COLS]; - let cols: &mut ProgramPreprocessedCols = row.as_mut_slice().borrow_mut(); - cols.pc = F::from_canonical_u32(pc); - cols.instruction.populate(instruction); - cols.selectors.populate(instruction); - - row - }) - .collect::>(); - - // Pad the trace to a power of two depending on the proof shape in `input`. 
- pad_rows_fixed( - &mut rows, - || [F::zero(); NUM_PROGRAM_PREPROCESSED_COLS], - program.fixed_log2_rows::(self), - ); + .par_bridge() + .for_each(|(i, rows)| { + rows.chunks_mut(NUM_PROGRAM_PREPROCESSED_COLS).enumerate().for_each(|(j, row)| { + let idx = i * chunk_size + j; + + if idx < nb_rows { + let cols: &mut ProgramPreprocessedCols = row.borrow_mut(); + let instruction = &program.instructions[idx]; + let pc = program.pc_base + (idx as u32 * 4); + cols.pc = F::from_canonical_u32(pc); + cols.instruction.populate(instruction); + cols.selectors.populate(instruction); + } + }); + }); // Convert the trace to a row major matrix. - let trace = RowMajorMatrix::new( - rows.into_iter().flatten().collect::>(), - NUM_PROGRAM_PREPROCESSED_COLS, - ); - - Some(trace) + Some(RowMajorMatrix::new(values, NUM_PROGRAM_PREPROCESSED_COLS)) } fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) { diff --git a/crates/core/machine/src/runtime/utils.rs b/crates/core/machine/src/runtime/utils.rs index 7c0ad541e7..483400e294 100644 --- a/crates/core/machine/src/runtime/utils.rs +++ b/crates/core/machine/src/runtime/utils.rs @@ -19,7 +19,7 @@ macro_rules! assert_valid_memory_access { assert!($addr > 40); } _ => { - Register::from_u32($addr); + Register::from_u8($addr); } }; } @@ -69,11 +69,7 @@ impl<'a> Runtime<'a> { ); if !self.unconstrained && self.state.global_clk % 10_000_000 == 0 { - log::info!( - "clk = {} pc = 0x{:x?}", - self.state.global_clk, - self.state.pc - ); + log::info!("clk = {} pc = 0x{:x?}", self.state.global_clk, self.state.pc); } } } diff --git a/crates/core/machine/src/syscall/precompiles/edwards/ed_decompress.rs b/crates/core/machine/src/syscall/precompiles/edwards/ed_decompress.rs index c9c5ae93c8..97d0e526fe 100644 --- a/crates/core/machine/src/syscall/precompiles/edwards/ed_decompress.rs +++ b/crates/core/machine/src/syscall/precompiles/edwards/ed_decompress.rs @@ -72,7 +72,7 @@ impl EdDecompressCols { self.clk = F::from_canonical_u32(event.clk); self.ptr = F::from_canonical_u32(event.ptr); self.nonce = F::from_canonical_u32( - record.nonce_lookup.get(&event.lookup_id).copied().unwrap_or_default(), + record.nonce_lookup.get(event.lookup_id.0 as usize).copied().unwrap_or_default(), ); self.sign = F::from_bool(event.sign); for i in 0..8 { diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index be4d816264..fb46afd4ed 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -33,11 +33,7 @@ use std::{ }; use lru::LruCache; - -use tracing::instrument; - use p3_baby_bear::BabyBear; - use p3_challenger::CanObserve; use p3_field::{AbstractField, PrimeField, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; @@ -80,6 +76,7 @@ use sp1_stark::{ MachineProver, SP1CoreOpts, SP1ProverOpts, ShardProof, StarkGenericConfig, StarkVerifyingKey, Val, Word, DIGEST_SIZE, }; +use tracing::instrument; pub use types::*; use utils::{sp1_committed_values_digest_bn254, sp1_vkey_digest_bn254, words_to_bytes}; @@ -356,8 +353,9 @@ impl SP1Prover { input: &SP1CompressWithVKeyWitnessValues, ) -> Arc> { let mut cache = self.compress_programs.lock().unwrap_or_else(|e| e.into_inner()); + let shape = input.shape(); cache - .get_or_insert(input.shape(), || { + .get_or_insert(shape.clone(), || { let misses = self.compress_cache_misses.fetch_add(1, Ordering::Relaxed); tracing::debug!("compress cache miss, misses: {}", misses); // Get the operations. 
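
Note on the `nonce_lookup` change visible in the DivRem, CPU, and `ed_decompress` hunks above: lookups like `record.nonce_lookup.get(event.lookup_id.0 as usize)` replace a `HashMap` keyed by `LookupId` with a dense vector indexed by the numeric id. Below is a minimal, self-contained sketch of that access pattern; `LookupId` and `NonceTable` are illustrative stand-ins, not the executor's exact definitions.

    // Sketch: dense Vec indexed by a numeric lookup id instead of a HashMap key.
    // `LookupId` here is a stand-in newtype; the real type lives in sp1-core-executor.
    #[derive(Clone, Copy, Default)]
    struct LookupId(u64);

    struct NonceTable {
        nonces: Vec<u32>,
    }

    impl NonceTable {
        /// Missing entries read as zero, mirroring `unwrap_or_default` in the patch.
        fn nonce(&self, id: LookupId) -> u32 {
            self.nonces.get(id.0 as usize).copied().unwrap_or_default()
        }
    }

    fn main() {
        let table = NonceTable { nonces: vec![7, 11, 13] };
        assert_eq!(table.nonce(LookupId(1)), 11);
        assert_eq!(table.nonce(LookupId(9)), 0); // out of range reads as default
    }

The `unwrap_or_default` fallback preserves the behavior of the old `HashMap` lookups, where a missing id produced nonce 0.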
diff --git a/crates/prover/src/shapes.rs b/crates/prover/src/shapes.rs
index b7adddc0e5..74f7ba177e 100644
--- a/crates/prover/src/shapes.rs
+++ b/crates/prover/src/shapes.rs
@@ -1,11 +1,13 @@
 use std::{
     collections::{BTreeMap, BTreeSet, HashSet},
     fs::File,
+    hash::{DefaultHasher, Hash, Hasher},
     panic::{catch_unwind, AssertUnwindSafe},
     path::PathBuf,
     sync::{Arc, Mutex},
 };
 
+use eyre::Result;
 use thiserror::Error;
 
 use p3_baby_bear::BabyBear;
@@ -29,7 +31,7 @@ pub enum SP1ProofShape {
     Shrink(ProofShape),
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Hash)]
 pub enum SP1CompressProgramShape {
     Recursion(SP1RecursionShape),
     Compress(SP1CompressWithVkeyShape),
@@ -37,6 +39,14 @@ pub enum SP1CompressProgramShape {
     Shrink(SP1CompressWithVkeyShape),
 }
 
+impl SP1CompressProgramShape {
+    pub fn hash_u64(&self) -> u64 {
+        let mut hasher = DefaultHasher::new();
+        Hash::hash(&self, &mut hasher);
+        hasher.finish()
+    }
+}
+
 #[derive(Debug, Error)]
 pub enum VkBuildError {
     #[error("IO error: {0}")]
@@ -231,6 +241,15 @@ impl SP1ProofShape {
         )
     }
 
+    pub fn generate_compress_shapes(
+        recursion_shape_config: &'_ RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
+        reduce_batch_size: usize,
+    ) -> impl Iterator<Item = Self> + '_ {
+        (1..=reduce_batch_size).flat_map(|batch_size| {
+            recursion_shape_config.get_all_shape_combinations(batch_size).map(Self::Compress)
+        })
+    }
+
     pub fn dummy_vk_map<'a>(
         core_shape_config: &'a CoreShapeConfig<BabyBear>,
        recursion_shape_config: &'a RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
diff --git a/crates/recursion/circuit/src/machine/deferred.rs b/crates/recursion/circuit/src/machine/deferred.rs
index 0f38620a7d..d5ab720973 100644
--- a/crates/recursion/circuit/src/machine/deferred.rs
+++ b/crates/recursion/circuit/src/machine/deferred.rs
@@ -43,7 +43,7 @@ pub struct SP1DeferredVerifier<C, SC, A> {
     _phantom: std::marker::PhantomData<(C, SC, A)>,
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Hash)]
 pub struct SP1DeferredShape {
     inner: SP1CompressShape,
     height: usize,
diff --git a/crates/recursion/compiler/src/circuit/compiler.rs b/crates/recursion/compiler/src/circuit/compiler.rs
index d75c0d3d6e..b44b38a1dd 100644
--- a/crates/recursion/compiler/src/circuit/compiler.rs
+++ b/crates/recursion/compiler/src/circuit/compiler.rs
@@ -17,6 +17,9 @@ use sp1_recursion_core::*;
 
 use crate::prelude::*;
 
+/// The number of instructions to preallocate in a recursion program
+const PREALLOC_INSTRUCTIONS: usize = 10000000;
+
 /// The backend for the circuit compiler.
 #[derive(Debug, Clone, Default)]
 pub struct AsmCompiler<C: Config> {
@@ -511,7 +514,7 @@ where
         // Compile each IR instruction into a list of ASM instructions, then combine them.
         // This step also counts the number of times each address is read from.
         let (mut instrs, traces) = tracing::debug_span!("compile_one loop").in_scope(|| {
-            let mut instrs = Vec::with_capacity(operations.vec.len());
+            let mut instrs = Vec::with_capacity(PREALLOC_INSTRUCTIONS);
             let mut traces = vec![];
             if debug_mode {
                 let mut span_builder =
diff --git a/crates/recursion/core/src/chips/poseidon2_skinny/trace.rs b/crates/recursion/core/src/chips/poseidon2_skinny/trace.rs
index ecd9c57550..7e67f54362 100644
--- a/crates/recursion/core/src/chips/poseidon2_skinny/trace.rs
+++ b/crates/recursion/core/src/chips/poseidon2_skinny/trace.rs
@@ -41,6 +41,10 @@ impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for Poseidon2SkinnyChip<DEGREE> {
         format!("Poseidon2SkinnyDeg{}", DEGREE)
     }
 
+    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
+        // This is a no-op.
+    }
+
     #[instrument(name = "generate poseidon2 skinny trace", level = "debug", skip_all, fields(rows = input.poseidon2_events.len()))]
     fn generate_trace(
         &self,
diff --git a/crates/recursion/core/src/chips/poseidon2_wide/trace.rs b/crates/recursion/core/src/chips/poseidon2_wide/trace.rs
index 0d5c666265..19a3b6f287 100644
--- a/crates/recursion/core/src/chips/poseidon2_wide/trace.rs
+++ b/crates/recursion/core/src/chips/poseidon2_wide/trace.rs
@@ -37,6 +37,10 @@ impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for Poseidon2WideChip<DEGREE> {
+    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
+        // This is a no-op.
+    }
+
diff --git a/crates/recursion/core/src/machine.rs b/crates/recursion/core/src/machine.rs
--- a/crates/recursion/core/src/machine.rs
+++ b/crates/recursion/core/src/machine.rs
 #[derive(Debug, Clone, Copy, Default)]
 pub struct RecursionAirEventCount {
-    mem_const_events: usize,
-    mem_var_events: usize,
-    base_alu_events: usize,
-    ext_alu_events: usize,
-    poseidon2_wide_events: usize,
-    fri_fold_events: usize,
-    select_events: usize,
-    exp_reverse_bits_len_events: usize,
+    pub mem_const_events: usize,
+    pub mem_var_events: usize,
+    pub base_alu_events: usize,
+    pub ext_alu_events: usize,
+    pub poseidon2_wide_events: usize,
+    pub fri_fold_events: usize,
+    pub select_events: usize,
+    pub exp_reverse_bits_len_events: usize,
 }
 
 impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> RecursionAir<F, DEGREE> {
diff --git a/crates/recursion/core/src/runtime/memory.rs b/crates/recursion/core/src/runtime/memory.rs
index a82337b684..d0b7f0f229 100644
--- a/crates/recursion/core/src/runtime/memory.rs
+++ b/crates/recursion/core/src/runtime/memory.rs
@@ -5,7 +5,7 @@ use vec_map::{Entry, VecMap};
 
 use crate::{air::Block, Address};
 
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Clone, Default, Copy)]
 pub struct MemoryEntry<F> {
     pub val: Block<F>,
     pub mult: F,
diff --git a/crates/recursion/core/src/runtime/mod.rs b/crates/recursion/core/src/runtime/mod.rs
index f95ef87a14..2478a45ddb 100644
--- a/crates/recursion/core/src/runtime/mod.rs
+++ b/crates/recursion/core/src/runtime/mod.rs
@@ -8,6 +8,7 @@ mod record;
 use backtrace::Backtrace as Trace;
 pub use instruction::Instruction;
 use instruction::{FieldEltType, HintBitsInstr, HintExt2FeltsInstr, HintInstr, PrintInstr};
+use machine::RecursionAirEventCount;
 use memory::*;
 pub use opcode::*;
 pub use program::*;
@@ -249,6 +250,7 @@ where
     pub fn run(&mut self) -> Result<(), RuntimeError> {
         let early_exit_ts = std::env::var("RECURSION_EARLY_EXIT_TS")
             .map_or(usize::MAX, |ts: String| ts.parse().unwrap());
+        self.preallocate_record();
         while self.pc < F::from_canonical_u32(self.program.instructions.len() as u32) {
             let idx = self.pc.as_canonical_u32() as usize;
             let instruction = self.program.instructions[idx].clone();
@@ -544,4 +546,18 @@ where
         }
         Ok(())
     }
+
+    pub fn preallocate_record(&mut self) {
+        let event_counts = self
+            .program
+            .instructions
+            .iter()
+            .fold(RecursionAirEventCount::default(), |heights, instruction| heights + instruction);
+        self.record.poseidon2_events.reserve(event_counts.poseidon2_wide_events);
+        self.record.mem_var_events.reserve(event_counts.mem_var_events);
+        self.record.base_alu_events.reserve(event_counts.base_alu_events);
+        self.record.ext_alu_events.reserve(event_counts.ext_alu_events);
+        self.record.exp_reverse_bits_len_events.reserve(event_counts.exp_reverse_bits_len_events);
+        self.record.select_events.reserve(event_counts.select_events);
+    }
 }
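
Note: the rewritten `generate_trace` implementations in this patch all follow one pattern: allocate a single zeroed buffer at the padded height, split it into disjoint fixed-width row chunks, and let worker threads populate rows and nonces in the same pass, instead of building per-chunk `Vec`s of rows, flattening, padding, and then writing nonces serially. The following is a minimal sketch of that pattern under stated assumptions: the `rayon` and `num_cpus` crates stand in for `p3_maybe_rayon`, `u64` stands in for the field element type, and `next_power_of_two` mirrors the padding rule visible in the CPU hunk (fixed shape height when present, otherwise the next power of two with a floor of 16).

    // Sketch of the chunked parallel trace-fill pattern used by the ALU chips above.
    use rayon::iter::{ParallelBridge, ParallelIterator};

    const NUM_COLS: usize = 4;

    fn next_power_of_two(n: usize, fixed_log2: Option<usize>) -> usize {
        // Fixed height when a proof shape dictates one; otherwise pad to a power of two.
        match fixed_log2 {
            Some(log2) => 1 << log2,
            None => n.next_power_of_two().max(16),
        }
    }

    fn generate_trace(events: &[u64], fixed_log2: Option<usize>) -> Vec<u64> {
        let nb_rows = events.len();
        let padded_nb_rows = next_power_of_two(nb_rows, fixed_log2);
        // One flat, zero-initialized buffer: padding rows come for free.
        let mut values = vec![0u64; padded_nb_rows * NUM_COLS];
        let chunk_size = std::cmp::max((nb_rows + 1) / num_cpus::get(), 1);

        values.chunks_mut(chunk_size * NUM_COLS).enumerate().par_bridge().for_each(|(i, rows)| {
            rows.chunks_mut(NUM_COLS).enumerate().for_each(|(j, row)| {
                let idx = i * chunk_size + j;
                if idx < nb_rows {
                    row[0] = events[idx]; // real chips populate full column structs here
                }
                row[1] = idx as u64; // nonce column, written for padded rows too
            });
        });
        values
    }

    fn main() {
        let trace = generate_trace(&[3, 5, 8], None);
        assert_eq!(trace.len(), 16 * NUM_COLS);
    }

Writing the nonces for real and padded rows during the same parallel pass is what lets the old post-hoc "write the nonces" loop and the `pad_rows_fixed` calls disappear from every chip touched above.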