diff --git a/Cargo.lock b/Cargo.lock index d5baa3132c..5ffacb584b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1937,6 +1937,27 @@ dependencies = [ "zeroize", ] +[[package]] +name = "enum-map" +version = "2.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6866f3bfdf8207509a033af1a75a7b08abda06bbaaeae6669323fd5a097df2e9" +dependencies = [ + "enum-map-derive", + "serde", +] + +[[package]] +name = "enum-map-derive" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -5964,6 +5985,7 @@ dependencies = [ "bincode", "bytemuck", "elf", + "enum-map", "eyre", "generic-array 1.1.0", "hashbrown 0.14.5", diff --git a/crates/core/executor/Cargo.toml b/crates/core/executor/Cargo.toml index 50bd5d750d..0e48407583 100644 --- a/crates/core/executor/Cargo.toml +++ b/crates/core/executor/Cargo.toml @@ -44,6 +44,7 @@ hex = "0.4.3" bytemuck = "1.16.3" tiny-keccak = { version = "2.0.2", features = ["keccak"] } vec_map = { version = "0.8.2", features = ["serde"] } +enum-map = { version = "2.7.3", features = ["serde"] } [dev-dependencies] sp1-zkvm = { workspace = true } diff --git a/crates/core/executor/src/executor.rs b/crates/core/executor/src/executor.rs index e1c26b7111..991ba3cb0c 100644 --- a/crates/core/executor/src/executor.rs +++ b/crates/core/executor/src/executor.rs @@ -696,23 +696,19 @@ impl<'a> Executor<'a> { if self.executor_mode == ExecutorMode::Trace { self.memory_accesses = MemoryAccessRecord::default(); } - let lookup_id = if self.executor_mode == ExecutorMode::Simple { - LookupId::default() - } else { + let lookup_id = if self.executor_mode == ExecutorMode::Trace { create_alu_lookup_id() - }; - let syscall_lookup_id = if self.executor_mode == ExecutorMode::Simple { - LookupId::default() } else { + LookupId::default() + }; + let syscall_lookup_id = if self.executor_mode == ExecutorMode::Trace { create_alu_lookup_id() + } else { + LookupId::default() }; if self.print_report && !self.unconstrained { - self.report - .opcode_counts - .entry(instruction.opcode) - .and_modify(|c| *c += 1) - .or_insert(1); + self.report.opcode_counts[instruction.opcode] += 1; } match instruction.opcode { @@ -930,7 +926,7 @@ impl<'a> Executor<'a> { let syscall = SyscallCode::from_u32(syscall_id); if self.print_report && !self.unconstrained { - self.report.syscall_counts.entry(syscall).and_modify(|c| *c += 1).or_insert(1); + self.report.syscall_counts[syscall] += 1; } // `hint_slice` is allowed in unconstrained mode since it is used to write the hint. diff --git a/crates/core/executor/src/opcode.rs b/crates/core/executor/src/opcode.rs index 868f516d49..e51d3f05ad 100644 --- a/crates/core/executor/src/opcode.rs +++ b/crates/core/executor/src/opcode.rs @@ -2,6 +2,7 @@ use std::fmt::Display; +use enum_map::Enum; use p3_field::Field; use serde::{Deserialize, Serialize}; @@ -20,7 +21,9 @@ use serde::{Deserialize, Serialize}; /// Refer to the "RV32I Reference Card" [here](https://github.com/johnwinans/rvalp/releases) for /// more details. #[allow(non_camel_case_types)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] +#[derive( + Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Enum, +)] pub enum Opcode { /// rd ← rs1 + rs2, pc ← pc + 4 ADD = 0, diff --git a/crates/core/executor/src/report.rs b/crates/core/executor/src/report.rs index 3e4b29458a..7459d2d91b 100644 --- a/crates/core/executor/src/report.rs +++ b/crates/core/executor/src/report.rs @@ -1,19 +1,20 @@ use std::{ - collections::{hash_map::Entry, HashMap}, fmt::{Display, Formatter, Result as FmtResult}, - hash::Hash, ops::{Add, AddAssign}, }; +use enum_map::{EnumArray, EnumMap}; +use hashbrown::HashMap; + use crate::{events::sorted_table_lines, syscalls::SyscallCode, Opcode}; /// An execution report. #[derive(Default, Debug, Clone, PartialEq, Eq)] pub struct ExecutionReport { /// The opcode counts. - pub opcode_counts: HashMap, + pub opcode_counts: Box>, /// The syscall counts. - pub syscall_counts: HashMap, + pub syscall_counts: Box>, /// The cycle tracker counts. pub cycle_tracker: HashMap, /// The unique memory address counts. @@ -35,24 +36,20 @@ impl ExecutionReport { } /// Combines two `HashMap`s together. If a key is in both maps, the values are added together. -fn hashmap_add_assign(lhs: &mut HashMap, rhs: HashMap) +fn counts_add_assign(lhs: &mut EnumMap, rhs: EnumMap) where - K: Eq + Hash, + K: EnumArray, V: AddAssign, { for (k, v) in rhs { - // Can't use `.and_modify(...).or_insert(...)` because we want to use `v` in both places. - match lhs.entry(k) { - Entry::Occupied(e) => *e.into_mut() += v, - Entry::Vacant(e) => drop(e.insert(v)), - } + lhs[k] += v; } } impl AddAssign for ExecutionReport { fn add_assign(&mut self, rhs: Self) { - hashmap_add_assign(&mut self.opcode_counts, rhs.opcode_counts); - hashmap_add_assign(&mut self.syscall_counts, rhs.syscall_counts); + counts_add_assign(&mut self.opcode_counts, *rhs.opcode_counts); + counts_add_assign(&mut self.syscall_counts, *rhs.syscall_counts); self.touched_memory_addresses += rhs.touched_memory_addresses; } } @@ -69,12 +66,12 @@ impl Add for ExecutionReport { impl Display for ExecutionReport { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { writeln!(f, "opcode counts ({} total instructions):", self.total_instruction_count())?; - for line in sorted_table_lines(&self.opcode_counts) { + for line in sorted_table_lines(self.opcode_counts.as_ref()) { writeln!(f, " {line}")?; } writeln!(f, "syscall counts ({} total syscall instructions):", self.total_syscall_count())?; - for line in sorted_table_lines(&self.syscall_counts) { + for line in sorted_table_lines(self.syscall_counts.as_ref()) { writeln!(f, " {line}")?; } diff --git a/crates/core/executor/src/syscalls/code.rs b/crates/core/executor/src/syscalls/code.rs index 929d7003c2..ece5a06e99 100644 --- a/crates/core/executor/src/syscalls/code.rs +++ b/crates/core/executor/src/syscalls/code.rs @@ -1,3 +1,4 @@ +use enum_map::Enum; use serde::{Deserialize, Serialize}; use strum_macros::EnumIter; @@ -18,7 +19,7 @@ use strum_macros::EnumIter; /// memory accesses is bounded. /// - Byte 3: Currently unused. #[derive( - Debug, Copy, Clone, PartialEq, Eq, Hash, EnumIter, Ord, PartialOrd, Serialize, Deserialize, + Debug, Copy, Clone, PartialEq, Eq, Hash, EnumIter, Ord, PartialOrd, Serialize, Deserialize, Enum, )] #[allow(non_camel_case_types)] #[allow(clippy::upper_case_acronyms)] diff --git a/crates/core/machine/src/riscv/cost.rs b/crates/core/machine/src/riscv/cost.rs index 685e683711..0da65a34cc 100644 --- a/crates/core/machine/src/riscv/cost.rs +++ b/crates/core/machine/src/riscv/cost.rs @@ -33,151 +33,133 @@ impl CostEstimator for ExecutionReport { total_area += (cpu_events as u64) * costs[&RiscvAirDiscriminants::Cpu]; total_chips += 1; - let sha_extend_events = *self.syscall_counts.get(&SyscallCode::SHA_EXTEND).unwrap_or(&0); + let sha_extend_events = self.syscall_counts[SyscallCode::SHA_EXTEND]; total_area += (sha_extend_events as u64) * costs[&RiscvAirDiscriminants::Sha256Extend]; total_chips += 1; - let sha_compress_events = - *self.syscall_counts.get(&SyscallCode::SHA_COMPRESS).unwrap_or(&0); + let sha_compress_events = self.syscall_counts[SyscallCode::SHA_COMPRESS]; total_area += (sha_compress_events as u64) * costs[&RiscvAirDiscriminants::Sha256Compress]; total_chips += 1; - let ed_add_events = *self.syscall_counts.get(&SyscallCode::ED_ADD).unwrap_or(&0); + let ed_add_events = self.syscall_counts[SyscallCode::ED_ADD]; total_area += (ed_add_events as u64) * costs[&RiscvAirDiscriminants::Ed25519Add]; total_chips += 1; - let ed_decompress_events = - *self.syscall_counts.get(&SyscallCode::ED_DECOMPRESS).unwrap_or(&0); + let ed_decompress_events = self.syscall_counts[SyscallCode::ED_DECOMPRESS]; total_area += (ed_decompress_events as u64) * costs[&RiscvAirDiscriminants::Ed25519Decompress]; total_chips += 1; - let k256_decompress_events = - *self.syscall_counts.get(&SyscallCode::SECP256K1_DECOMPRESS).unwrap_or(&0); + let k256_decompress_events = self.syscall_counts[SyscallCode::SECP256K1_DECOMPRESS]; total_area += (k256_decompress_events as u64) * costs[&RiscvAirDiscriminants::K256Decompress]; total_chips += 1; - let secp256k1_add_events = - *self.syscall_counts.get(&SyscallCode::SECP256K1_ADD).unwrap_or(&0); + let secp256k1_add_events = self.syscall_counts[SyscallCode::SECP256K1_ADD]; total_area += (secp256k1_add_events as u64) * costs[&RiscvAirDiscriminants::Secp256k1Add]; total_chips += 1; - let secp256k1_double_events = - *self.syscall_counts.get(&SyscallCode::SECP256K1_DOUBLE).unwrap_or(&0); + let secp256k1_double_events = self.syscall_counts[SyscallCode::SECP256K1_DOUBLE]; total_area += (secp256k1_double_events as u64) * costs[&RiscvAirDiscriminants::Secp256k1Double]; total_chips += 1; - let keccak256_permute_events = - *self.syscall_counts.get(&SyscallCode::KECCAK_PERMUTE).unwrap_or(&0); + let keccak256_permute_events = self.syscall_counts[SyscallCode::KECCAK_PERMUTE]; total_area += (keccak256_permute_events as u64) * costs[&RiscvAirDiscriminants::KeccakP]; total_chips += 1; - let bn254_add_events = *self.syscall_counts.get(&SyscallCode::BN254_ADD).unwrap_or(&0); + let bn254_add_events = self.syscall_counts[SyscallCode::BN254_ADD]; total_area += (bn254_add_events as u64) * costs[&RiscvAirDiscriminants::Bn254Add]; total_chips += 1; - let bn254_double_events = - *self.syscall_counts.get(&SyscallCode::BN254_DOUBLE).unwrap_or(&0); + let bn254_double_events = self.syscall_counts[SyscallCode::BN254_DOUBLE]; total_area += (bn254_double_events as u64) * costs[&RiscvAirDiscriminants::Bn254Double]; total_chips += 1; - let bls12381_add_events = - *self.syscall_counts.get(&SyscallCode::BLS12381_ADD).unwrap_or(&0); + let bls12381_add_events = self.syscall_counts[SyscallCode::BLS12381_ADD]; total_area += (bls12381_add_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Add]; total_chips += 1; - let bls12381_double_events = - *self.syscall_counts.get(&SyscallCode::BLS12381_DOUBLE).unwrap_or(&0); + let bls12381_double_events = self.syscall_counts[SyscallCode::BLS12381_DOUBLE]; total_area += (bls12381_double_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Double]; total_chips += 1; - let uint256_mul_events = *self.syscall_counts.get(&SyscallCode::UINT256_MUL).unwrap_or(&0); + let uint256_mul_events = self.syscall_counts[SyscallCode::UINT256_MUL]; total_area += (uint256_mul_events as u64) * costs[&RiscvAirDiscriminants::Uint256Mul]; total_chips += 1; - let bls12381_fp_events = - *self.syscall_counts.get(&SyscallCode::BLS12381_FP_ADD).unwrap_or(&0) - + *self.syscall_counts.get(&SyscallCode::BLS12381_FP_SUB).unwrap_or(&0) - + *self.syscall_counts.get(&SyscallCode::BLS12381_FP_MUL).unwrap_or(&0); + let bls12381_fp_events = self.syscall_counts[SyscallCode::BLS12381_FP_ADD] + + self.syscall_counts[SyscallCode::BLS12381_FP_SUB] + + self.syscall_counts[SyscallCode::BLS12381_FP_MUL]; total_area += (bls12381_fp_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Fp]; total_chips += 1; - let bls12381_fp2_addsub_events = - *self.syscall_counts.get(&SyscallCode::BLS12381_FP2_ADD).unwrap_or(&0) - + *self.syscall_counts.get(&SyscallCode::BLS12381_FP2_SUB).unwrap_or(&0); + let bls12381_fp2_addsub_events = self.syscall_counts[SyscallCode::BLS12381_FP2_ADD] + + self.syscall_counts[SyscallCode::BLS12381_FP2_SUB]; total_area += (bls12381_fp2_addsub_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Fp2AddSub]; total_chips += 1; - let bls12381_fp2_mul_events = - *self.syscall_counts.get(&SyscallCode::BLS12381_FP2_MUL).unwrap_or(&0); + let bls12381_fp2_mul_events = self.syscall_counts[SyscallCode::BLS12381_FP2_MUL]; total_area += (bls12381_fp2_mul_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Fp2Mul]; total_chips += 1; - let bn254_fp_events = *self.syscall_counts.get(&SyscallCode::BN254_FP_ADD).unwrap_or(&0) - + *self.syscall_counts.get(&SyscallCode::BN254_FP_SUB).unwrap_or(&0) - + *self.syscall_counts.get(&SyscallCode::BN254_FP_MUL).unwrap_or(&0); + let bn254_fp_events = self.syscall_counts[SyscallCode::BN254_FP_ADD] + + self.syscall_counts[SyscallCode::BN254_FP_SUB] + + self.syscall_counts[SyscallCode::BN254_FP_MUL]; total_area += (bn254_fp_events as u64) * costs[&RiscvAirDiscriminants::Bn254Fp]; total_chips += 1; - let bn254_fp2_addsub_events = - *self.syscall_counts.get(&SyscallCode::BN254_FP2_ADD).unwrap_or(&0) - + *self.syscall_counts.get(&SyscallCode::BN254_FP2_SUB).unwrap_or(&0); + let bn254_fp2_addsub_events = self.syscall_counts[SyscallCode::BN254_FP2_ADD] + + self.syscall_counts[SyscallCode::BN254_FP2_SUB]; total_area += (bn254_fp2_addsub_events as u64) * costs[&RiscvAirDiscriminants::Bn254Fp2AddSub]; total_chips += 1; - let bn254_fp2_mul_events = - *self.syscall_counts.get(&SyscallCode::BN254_FP2_MUL).unwrap_or(&0); + let bn254_fp2_mul_events = self.syscall_counts[SyscallCode::BN254_FP2_MUL]; total_area += (bn254_fp2_mul_events as u64) * costs[&RiscvAirDiscriminants::Bn254Fp2Mul]; total_chips += 1; - let bls12381_decompress_events = - *self.syscall_counts.get(&SyscallCode::BLS12381_DECOMPRESS).unwrap_or(&0); + let bls12381_decompress_events = self.syscall_counts[SyscallCode::BLS12381_DECOMPRESS]; total_area += (bls12381_decompress_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Decompress]; total_chips += 1; - let divrem_events = *self.opcode_counts.get(&Opcode::DIV).unwrap_or(&0) - + *self.opcode_counts.get(&Opcode::REM).unwrap_or(&0) - + *self.opcode_counts.get(&Opcode::DIVU).unwrap_or(&0) - + *self.opcode_counts.get(&Opcode::REMU).unwrap_or(&0); + let divrem_events = self.opcode_counts[Opcode::DIV] + + self.opcode_counts[Opcode::REM] + + self.opcode_counts[Opcode::DIVU] + + self.opcode_counts[Opcode::REMU]; total_area += (divrem_events as u64) * costs[&RiscvAirDiscriminants::DivRem]; total_chips += 1; - let addsub_events = *self.opcode_counts.get(&Opcode::ADD).unwrap_or(&0) - + *self.opcode_counts.get(&Opcode::SUB).unwrap_or(&0); + let addsub_events = self.opcode_counts[Opcode::ADD] + self.opcode_counts[Opcode::SUB]; total_area += (addsub_events as u64) * costs[&RiscvAirDiscriminants::Add]; total_chips += 1; - let bitwise_events = *self.opcode_counts.get(&Opcode::AND).unwrap_or(&0) - + *self.opcode_counts.get(&Opcode::OR).unwrap_or(&0) - + *self.opcode_counts.get(&Opcode::XOR).unwrap_or(&0); + let bitwise_events = self.opcode_counts[Opcode::AND] + + self.opcode_counts[Opcode::OR] + + self.opcode_counts[Opcode::XOR]; total_area += (bitwise_events as u64) * costs[&RiscvAirDiscriminants::Bitwise]; total_chips += 1; - let mul_events = *self.opcode_counts.get(&Opcode::MUL).unwrap_or(&0) - + *self.opcode_counts.get(&Opcode::MULH).unwrap_or(&0) - + *self.opcode_counts.get(&Opcode::MULHU).unwrap_or(&0) - + *self.opcode_counts.get(&Opcode::MULHSU).unwrap_or(&0); + let mul_events = self.opcode_counts[Opcode::MUL] + + self.opcode_counts[Opcode::MULH] + + self.opcode_counts[Opcode::MULHU] + + self.opcode_counts[Opcode::MULHSU]; total_area += (mul_events as u64) * costs[&RiscvAirDiscriminants::Mul]; total_chips += 1; - let shift_right_events = *self.opcode_counts.get(&Opcode::SRL).unwrap_or(&0) - + *self.opcode_counts.get(&Opcode::SRA).unwrap_or(&0); + let shift_right_events = self.opcode_counts[Opcode::SRL] + self.opcode_counts[Opcode::SRA]; total_area += (shift_right_events as u64) * costs[&RiscvAirDiscriminants::ShiftRight]; total_chips += 1; - let shift_left_events = *self.opcode_counts.get(&Opcode::SLL).unwrap_or(&0); + let shift_left_events = self.opcode_counts[Opcode::SLL]; total_area += (shift_left_events as u64) * costs[&RiscvAirDiscriminants::ShiftLeft]; total_chips += 1; - let lt_events = *self.opcode_counts.get(&Opcode::SLT).unwrap_or(&0) - + *self.opcode_counts.get(&Opcode::SLTU).unwrap_or(&0); + let lt_events = self.opcode_counts[Opcode::SLT] + self.opcode_counts[Opcode::SLTU]; total_area += (lt_events as u64) * costs[&RiscvAirDiscriminants::Lt]; total_chips += 1; diff --git a/crates/core/machine/src/utils/prove.rs b/crates/core/machine/src/utils/prove.rs index adedbd6065..5f12c7dd70 100644 --- a/crates/core/machine/src/utils/prove.rs +++ b/crates/core/machine/src/utils/prove.rs @@ -545,11 +545,11 @@ where // Print the opcode and syscall count tables like `du`: sorted by count (descending) and // with the count in the first column. tracing::info!("execution report (opcode counts):"); - for line in sorted_table_lines(&report_aggregate.opcode_counts) { + for line in sorted_table_lines(report_aggregate.opcode_counts.as_ref()) { tracing::info!(" {line}"); } tracing::info!("execution report (syscall counts):"); - for line in sorted_table_lines(&report_aggregate.syscall_counts) { + for line in sorted_table_lines(report_aggregate.syscall_counts.as_ref()) { tracing::info!(" {line}"); } diff --git a/examples/patch-testing/script/src/main.rs b/examples/patch-testing/script/src/main.rs index b3bbcb677f..45f97198ff 100644 --- a/examples/patch-testing/script/src/main.rs +++ b/examples/patch-testing/script/src/main.rs @@ -12,43 +12,25 @@ fn main() { let (_, report) = client.execute(PATCH_TEST_ELF, stdin).run().expect("executing failed"); // Confirm there was at least 1 SHA_COMPUTE syscall. - assert!( - report.syscall_counts.contains_key(&sp1_core_executor::syscalls::SyscallCode::SHA_COMPRESS) - ); - assert!( - report.syscall_counts.contains_key(&sp1_core_executor::syscalls::SyscallCode::SHA_EXTEND) - ); + assert_ne!(report.syscall_counts[sp1_core_executor::syscalls::SyscallCode::SHA_COMPRESS], 0); + assert_ne!(report.syscall_counts[sp1_core_executor::syscalls::SyscallCode::SHA_EXTEND], 0); // Confirm there was at least 1 of each ED25519 syscall. - assert!(report.syscall_counts.contains_key(&sp1_core_executor::syscalls::SyscallCode::ED_ADD)); - assert!( - report - .syscall_counts - .contains_key(&sp1_core_executor::syscalls::SyscallCode::ED_DECOMPRESS) - ); + assert_ne!(report.syscall_counts[sp1_core_executor::syscalls::SyscallCode::ED_ADD], 0); + assert_ne!(report.syscall_counts[sp1_core_executor::syscalls::SyscallCode::ED_DECOMPRESS], 0); // Confirm there was at least 1 KECCAK_PERMUTE syscall. - assert!( - report - .syscall_counts - .contains_key(&sp1_core_executor::syscalls::SyscallCode::KECCAK_PERMUTE) - ); + assert_ne!(report.syscall_counts[sp1_core_executor::syscalls::SyscallCode::KECCAK_PERMUTE], 0); // Confirm there was at least 1 SECP256K1_ADD, SECP256K1_DOUBLE and SECP256K1_DECOMPRESS syscall. - assert!( - report - .syscall_counts - .contains_key(&sp1_core_executor::syscalls::SyscallCode::SECP256K1_ADD) - ); - assert!( - report - .syscall_counts - .contains_key(&sp1_core_executor::syscalls::SyscallCode::SECP256K1_DOUBLE) + assert_ne!(report.syscall_counts[sp1_core_executor::syscalls::SyscallCode::SECP256K1_ADD], 0); + assert_ne!( + report.syscall_counts[sp1_core_executor::syscalls::SyscallCode::SECP256K1_DOUBLE], + 0 ); - assert!( - report - .syscall_counts - .contains_key(&sp1_core_executor::syscalls::SyscallCode::SECP256K1_DECOMPRESS) + assert_ne!( + report.syscall_counts[sp1_core_executor::syscalls::SyscallCode::SECP256K1_DECOMPRESS], + 0 ); println!("Total instructions: {:?}", report.total_instruction_count());