diff --git a/Cargo.lock b/Cargo.lock index 21b1074ae0..226de59286 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1873,6 +1873,15 @@ dependencies = [ "rustversion", ] +[[package]] +name = "debugid" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d" +dependencies = [ + "uuid", +] + [[package]] name = "der" version = "0.7.9" @@ -2430,6 +2439,17 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d758ba1b47b00caf47f24925c0074ecb20d6dfcffe7f6d53395c0465674841a" +[[package]] +name = "gecko_profile" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "890852c7e1e02bc6758e325d6b1e0236e4fbf21b492f585ce4d4715be54b4c6a" +dependencies = [ + "debugid", + "serde", + "serde_json", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -2541,6 +2561,17 @@ dependencies = [ "scroll", ] +[[package]] +name = "goblin" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53ab3f32d1d77146981dea5d6b1e8fe31eedcb7013e5e00d6ccd1259a4b4d923" +dependencies = [ + "log", + "plain", + "scroll", +] + [[package]] name = "group" version = "0.12.1" @@ -5529,7 +5560,7 @@ dependencies = [ "clap", "ctrlc", "dirs", - "goblin", + "goblin 0.8.2", "hex", "indicatif", "prettytable-rs", @@ -5557,8 +5588,11 @@ dependencies = [ "elf", "enum-map", "eyre", + "gecko_profile", + "goblin 0.9.2", "hashbrown 0.14.5", "hex", + "indicatif", "itertools 0.13.0", "log", "nohash-hasher", @@ -5567,7 +5601,9 @@ dependencies = [ "p3-maybe-rayon", "rand 0.8.5", "rrs-succinct", + "rustc-demangle", "serde", + "serde_json", "sp1-curves", "sp1-primitives", "sp1-stark", @@ -7019,6 +7055,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" 
+version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" + [[package]] name = "valuable" version = "0.1.0" diff --git a/book/writing-programs/cycle-tracking.md b/book/writing-programs/cycle-tracking.md index e7168c5723..44c29af4e0 100644 --- a/book/writing-programs/cycle-tracking.md +++ b/book/writing-programs/cycle-tracking.md @@ -56,69 +56,29 @@ fn main() { This will log the cycle count for `block name` and include it in the `ExecutionReport` in the `cycle_tracker` map. -## Tracking Cycles with Tracing +### Profiling the ZKVM -The `cycle-tracker` annotation is a convenient way to track cycles for specific sections of code. However, sometimes it can also be useful to track what functions are taking the most cycles across the entire program, without having to annotate every function individually. +Profiling the VM is a good way to understand what is bottlenecking your program. Note that only one program may be profiled at a time. -First, we need to generate a trace file of the program counter at each cycle while the program is executing. This can be done by simply setting the `TRACE_FILE` environment variable with the path of the file you want to write the trace to. For example, you can run the following command in the `script` directory for any example program: - -```bash -TRACE_FILE=trace.log RUST_LOG=info cargo run --release +To profile a program, you have to set up a script that executes the program. Many examples can be found in the repo, such as this ['fibonacci'](https://github.com/succinctlabs/sp1/blob/12f212e386ae4c2da30cf6a61a7d87615d56bdac/examples/fibonacci/script/src/main.rs#L22) script. +Once you have your script, it should contain the following code: +```rs + // Execute the program using the `ProverClient.execute` method, without generating a proof.
+ let (_, report) = client.execute(ELF, stdin.clone()).run().unwrap(); ``` -When the `TRACE_FILE` environment variable is set, as SP1's RISC-V runtime is executing, it will write a log of the program counter to the file specified by `TRACE_FILE`. +The data captured by the profiler can be quite large; you can reduce it by raising the sample rate with the `TRACE_SAMPLE_RATE` env var. +To enable profiling, build your script with the `profiling` feature of `sp1-sdk` enabled and set the `TRACE_FILE` env var to the path where you want the profile to be saved. -Next, we can use the `cargo prove` CLI with the `trace` command to analyze the trace file and generate a table of instruction counts. This can be done with the following command: +A larger sample rate will give you a smaller profile: it is the number of instructions between consecutive samples. -```bash -cargo prove trace --elf --trace +The full command to profile should look something like this: +```sh + TRACE_FILE=output.json TRACE_SAMPLE_RATE=100 cargo run ... ``` -The `trace` command will generate a table of instruction counts, sorted by the number of cycles spent in each function.
The output will look something like this: - -``` - [00:00:00] [########################################] 17053/17053 (0s) - -Total instructions in trace: 17053 - - - Instruction counts considering call graph -+----------------------------------------+-------------------+ -| Function Name | Instruction Count | -| __start | 17045 | -| main | 12492 | -| sp1_zkvm::syscalls::halt::syscall_halt | 4445 | -| sha2::sha256::compress256 | 4072 | -| sp1_lib::io::commit | 258 | -| sp1_lib::io::SyscallWriter::write | 255 | -| syscall_write | 195 | -| memcpy | 176 | -| memset | 109 | -| sp1_lib::io::read_vec | 71 | -| __rust_alloc | 29 | -| sp1_zkvm::heap::SimpleAlloc::alloc | 22 | -| syscall_hint_len | 3 | -| syscall_hint_read | 2 | -+----------------------------------------+-------------------+ - - - Instruction counts ignoring call graph -+----------------------------------------+-------------------+ -| Function Name | Instruction Count | -| main | 12075 | -| sha2::sha256::compress256 | 4073 | -| sp1_zkvm::syscalls::halt::syscall_halt | 219 | -| memcpy | 180 | -| syscall_write | 123 | -| memset | 111 | -| sp1_lib::io::commit | 88 | -| sp1_lib::io::SyscallWriter::write | 60 | -| __start | 45 | -| sp1_lib::io::read_vec | 35 | -| sp1_zkvm::heap::SimpleAlloc::alloc | 23 | -| anonymous | 7 | -| __rust_alloc | 7 | -| syscall_hint_len | 4 | -| syscall_hint_read | 3 | -+----------------------------------------+-------------------+ +To view these profiles, we recommend [Samply](https://github.com/mstange/samply). 
+```sh + cargo install --locked samply + samply load output.json ``` diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index cf2e8cb4b2..d291ab46ad 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -43,4 +43,4 @@ regex = "1.5.4" prettytable-rs = "0.10" textwrap = "0.16.0" ctrlc = "3.4.2" -cargo_metadata = "0.18.1" \ No newline at end of file +cargo_metadata = "0.18.1" diff --git a/crates/cli/src/bin/cargo-prove.rs b/crates/cli/src/bin/cargo-prove.rs index e2b87d44bd..bb84ea02ca 100644 --- a/crates/cli/src/bin/cargo-prove.rs +++ b/crates/cli/src/bin/cargo-prove.rs @@ -3,7 +3,7 @@ use clap::{Parser, Subcommand}; use sp1_cli::{ commands::{ build::BuildCmd, build_toolchain::BuildToolchainCmd, - install_toolchain::InstallToolchainCmd, new::NewCmd, trace::TraceCmd, vkey::VkeyCmd, + install_toolchain::InstallToolchainCmd, new::NewCmd, vkey::VkeyCmd, }, SP1_VERSION_MESSAGE, }; @@ -27,7 +27,6 @@ pub enum ProveCliCommands { Build(BuildCmd), BuildToolchain(BuildToolchainCmd), InstallToolchain(InstallToolchainCmd), - Trace(TraceCmd), Vkey(VkeyCmd), } @@ -39,7 +38,6 @@ fn main() -> Result<()> { ProveCliCommands::Build(cmd) => cmd.run(), ProveCliCommands::BuildToolchain(cmd) => cmd.run(), ProveCliCommands::InstallToolchain(cmd) => cmd.run(), - ProveCliCommands::Trace(cmd) => cmd.run(), ProveCliCommands::Vkey(cmd) => cmd.run(), } } diff --git a/crates/cli/src/commands/mod.rs b/crates/cli/src/commands/mod.rs index fc6eb6a5ac..e17d443d05 100644 --- a/crates/cli/src/commands/mod.rs +++ b/crates/cli/src/commands/mod.rs @@ -2,5 +2,4 @@ pub mod build; pub mod build_toolchain; pub mod install_toolchain; pub mod new; -pub mod trace; pub mod vkey; diff --git a/crates/cli/src/commands/trace.rs b/crates/cli/src/commands/trace.rs deleted file mode 100644 index 16cd2e219c..0000000000 --- a/crates/cli/src/commands/trace.rs +++ /dev/null @@ -1,428 +0,0 @@ -//! RISC-V tracer for SP1 traces. This tool can be used to analyze function call graphs and -//! 
instruction counts from a trace file from SP1 execution by setting the `TRACE_FILE` env -//! variable. -// -// Adapted from Sovereign's RISC-V tracer tool: https://github.com/Sovereign-Labs/riscv-cycle-tracer. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Modified by Succinct Labs on July 25, 2024. - -use anyhow::Result; -use clap::Parser; -use goblin::elf::{sym::STT_FUNC, Elf}; -use indicatif::{ProgressBar, ProgressStyle}; -use prettytable::{format, Cell, Row, Table}; -use regex::Regex; -use rustc_demangle::demangle; -use std::{ - cmp::Ordering, - collections::HashMap, - io::Read, - process::Command, - str, - sync::{atomic::AtomicBool, Arc}, -}; -use textwrap::wrap; - -#[derive(Parser, Debug)] -#[command(name = "trace", about = "Trace a program execution and analyze cycle counts.")] -pub struct TraceCmd { - /// Include the "top" number of functions. - #[arg(short, long, default_value_t = 30)] - top: usize, - - /// Don't print stack aware instruction counts - #[arg(long)] - no_stack_counts: bool, - - /// Don't print raw (stack un-aware) instruction counts. - #[arg(long)] - no_raw_counts: bool, - - /// Path to the ELF. - #[arg(long, required = true)] - elf: String, - - /// Path to the trace file. Simply run the program with `TRACE_FILE=trace.log` environment - /// variable. File must be one u64 program counter per line - #[arg(long, required = true)] - trace: String, - - /// Strip the hashes from the function name while printing. 
- #[arg(short, long)] - keep_hashes: bool, - - /// Function name to target for getting stack counts. - #[arg(short, long)] - function_name: Option, - - /// Exclude functions matching these patterns from display. - /// - /// Usage: `-e func1 -e func2 -e func3`. - #[arg(short, long)] - exclude_view: Vec, -} - -fn strip_hash(name_with_hash: &str) -> String { - let re = Regex::new(r"::h[0-9a-fA-F]{16}").unwrap(); - let mut result = re.replace(name_with_hash, "").to_string(); - let re2 = Regex::new(r"^<(.+) as .+>").unwrap(); - result = re2.replace(&result, "$1").to_string(); - let re2 = Regex::new(r"^<(.+) as .+>").unwrap(); - result = re2.replace(&result, "$1").to_string(); - let re2 = Regex::new(r"([^\:])<.+>::").unwrap(); - result = re2.replace_all(&result, "$1::").to_string(); - result -} - -fn print_instruction_counts( - first_header: &str, - count_vec: Vec<(String, usize)>, - top_n: usize, - strip_hashes: bool, - exclude_list: Option<&[String]>, -) { - let mut table = Table::new(); - table.set_format(*format::consts::FORMAT_NO_LINESEP); - table.set_titles(Row::new(vec![Cell::new(first_header), Cell::new("Instruction Count")])); - - let wrap_width = 120; - let mut row_count = 0; - for (key, value) in count_vec { - let mut cont = false; - if let Some(ev) = exclude_list { - for e in ev { - if key.contains(e) { - cont = true; - break; - } - } - if cont { - continue; - } - } - let mut stripped_key = key.clone(); - if strip_hashes { - stripped_key = strip_hash(&key); - } - row_count += 1; - if row_count > top_n { - break; - } - let wrapped_key = wrap(&stripped_key, wrap_width); - let key_cell_content = wrapped_key.join("\n"); - table.add_row(Row::new(vec![Cell::new(&key_cell_content), Cell::new(&value.to_string())])); - } - - table.printstd(); -} - -fn focused_stack_counts( - function_stack: &[String], - filtered_stack_counts: &mut HashMap, usize>, - function_name: &str, - num_instructions: usize, -) { - if let Some(index) = function_stack.iter().position(|s| s == 
function_name) { - let truncated_stack = &function_stack[0..=index]; - let count = filtered_stack_counts.entry(truncated_stack.to_vec()).or_insert(0); - *count += num_instructions; - } -} - -fn _build_radare2_lookups( - start_lookup: &mut HashMap, - end_lookup: &mut HashMap, - func_range_lookup: &mut HashMap, - elf_name: &str, -) -> std::io::Result<()> { - let output = Command::new("r2").arg("-q").arg("-c").arg("aa;afl").arg(elf_name).output()?; - - if output.status.success() { - let result_str = str::from_utf8(&output.stdout).unwrap(); - for line in result_str.lines() { - let parts: Vec<&str> = line.split_whitespace().collect(); - let address = u64::from_str_radix(&parts[0][2..], 16).unwrap(); - let size = parts[2].parse::().unwrap(); - let end_address = address + size - 4; - let function_name = parts[3]; - start_lookup.insert(address, function_name.to_string()); - end_lookup.insert(end_address, function_name.to_string()); - func_range_lookup.insert(function_name.to_string(), (address, end_address)); - } - } else { - eprintln!("Error executing command: {}", str::from_utf8(&output.stderr).unwrap()); - } - Ok(()) -} - -fn build_goblin_lookups( - start_lookup: &mut HashMap, - end_lookup: &mut HashMap, - func_range_lookup: &mut HashMap, - elf_name: &str, -) -> std::io::Result<()> { - let buffer = std::fs::read(elf_name).unwrap(); - let elf = Elf::parse(&buffer).unwrap(); - - for sym in &elf.syms { - if sym.st_type() == STT_FUNC { - let name = elf.strtab.get_at(sym.st_name).unwrap_or(""); - let demangled_name = demangle(name); - let size = sym.st_size; - let start_address = sym.st_value; - let end_address = start_address + size - 4; - start_lookup.insert(start_address, demangled_name.to_string()); - end_lookup.insert(end_address, demangled_name.to_string()); - func_range_lookup.insert(demangled_name.to_string(), (start_address, end_address)); - } - } - Ok(()) -} - -fn increment_stack_counts( - instruction_counts: &mut HashMap, - function_stack: &[String], - 
filtered_stack_counts: &mut HashMap, usize>, - function_name: &Option, - num_instructions: usize, -) { - for f in function_stack { - *instruction_counts.entry(f.clone()).or_insert(0) += num_instructions; - } - if let Some(f) = function_name { - focused_stack_counts(function_stack, filtered_stack_counts, f, num_instructions) - } -} - -impl TraceCmd { - pub fn run(&self) -> Result<()> { - let top_n = self.top; - let elf_path = self.elf.clone(); - let trace_path = self.trace.clone(); - let no_stack_counts = self.no_stack_counts; - let no_raw_counts = self.no_raw_counts; - let strip_hashes = !self.keep_hashes; - let function_name = self.function_name.clone(); - let exclude_view = self.exclude_view.clone(); - - let mut start_lookup = HashMap::new(); - let mut end_lookup = HashMap::new(); - let mut func_range_lookup = HashMap::new(); - build_goblin_lookups(&mut start_lookup, &mut end_lookup, &mut func_range_lookup, &elf_path) - .unwrap(); - - let mut function_ranges: Vec<(u64, u64, String)> = - func_range_lookup.iter().map(|(f, &(start, end))| (start, end, f.clone())).collect(); - - function_ranges.sort_by_key(|&(start, _, _)| start); - - let file = std::fs::File::open(trace_path).unwrap(); - let file_size = file.metadata().unwrap().len(); - let mut buf = std::io::BufReader::new(file); - let mut function_stack: Vec = Vec::new(); - let mut instruction_counts: HashMap = HashMap::new(); - let mut counts_without_callgraph: HashMap = HashMap::new(); - let mut filtered_stack_counts: HashMap, usize> = HashMap::new(); - let total_lines = file_size / 4; - let mut current_function_range: (u64, u64) = (0, 0); - - let update_interval = 1000usize; - let pb = ProgressBar::new(total_lines); - pb.set_style( - ProgressStyle::default_bar() - .template( - "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})", - ) - .unwrap() - .progress_chars("#>-"), - ); - - let running = Arc::new(AtomicBool::new(true)); - let r = running.clone(); - - ctrlc::set_handler(move || { 
- r.store(false, std::sync::atomic::Ordering::SeqCst); - }) - .expect("Error setting Ctrl-C handler"); - - for c in 0..total_lines { - if (c as usize) % update_interval == 0 { - pb.inc(update_interval as u64); - if !running.load(std::sync::atomic::Ordering::SeqCst) { - pb.finish_with_message("Interrupted"); - break; - } - } - - // Parse pc from hex. - let mut pc_bytes = [0u8; 4]; - buf.read_exact(&mut pc_bytes).unwrap(); - let pc = u32::from_be_bytes(pc_bytes) as u64; - - // Only 1 instruction per opcode. - let num_instructions = 1; - - // Raw counts without considering the callgraph at all we're just checking if the PC - // belongs to a function if so we're incrementing. This would ignore the call stack - // so for example "main" would only have a hundred instructions or so. - if let Ok(index) = function_ranges.binary_search_by(|&(start, end, _)| { - if pc < start { - Ordering::Greater - } else if pc > end { - Ordering::Less - } else { - Ordering::Equal - } - }) { - let (_, _, fname) = &function_ranges[index]; - *counts_without_callgraph.entry(fname.clone()).or_insert(0) += num_instructions - } else { - *counts_without_callgraph.entry("anonymous".to_string()).or_insert(0) += - num_instructions; - } - - // The next section considers the callstack. We build a callstack and maintain it based - // on some rules. Functions lower in the stack get their counts incremented. - - // We are still in the current function. - if pc > current_function_range.0 && pc <= current_function_range.1 { - increment_stack_counts( - &mut instruction_counts, - &function_stack, - &mut filtered_stack_counts, - &function_name, - num_instructions, - ); - continue; - } - - // Jump to a new function (or the same one). - if let Some(f) = start_lookup.get(&pc) { - increment_stack_counts( - &mut instruction_counts, - &function_stack, - &mut filtered_stack_counts, - &function_name, - num_instructions, - ); - - // Jump to a new function (not recursive). 
- if !function_stack.contains(f) { - function_stack.push(f.clone()); - current_function_range = *func_range_lookup.get(f).unwrap(); - } - } else { - // This means pc now points to an instruction that is - // - // 1. not in the current function's range - // 2. not a new function call - // - // We now account for a new possibility where we're returning to a function in the - // stack this need not be the immediate parent and can be any of the existing - // functions in the stack due to some optimizations that the compiler can make. - let mut unwind_point = 0; - let mut unwind_found = false; - for (c, f) in function_stack.iter().enumerate() { - let (s, e) = func_range_lookup.get(f).unwrap(); - if pc > *s && pc <= *e { - unwind_point = c; - unwind_found = true; - break; - } - } - - // Unwinding until the parent. - if unwind_found { - function_stack.truncate(unwind_point + 1); - increment_stack_counts( - &mut instruction_counts, - &function_stack, - &mut filtered_stack_counts, - &function_name, - num_instructions, - ); - continue; - } - - // If no unwind point has been found, that means we jumped to some random location - // so we'll just increment the counts for everything in the stack. 
- increment_stack_counts( - &mut instruction_counts, - &function_stack, - &mut filtered_stack_counts, - &function_name, - num_instructions, - ); - } - } - - pb.finish_with_message("done"); - - let mut raw_counts: Vec<(String, usize)> = - instruction_counts.iter().map(|(key, value)| (key.clone(), *value)).collect(); - raw_counts.sort_by(|a, b| b.1.cmp(&a.1)); - - println!("\n\nTotal instructions in trace: {}", total_lines); - if !no_stack_counts { - println!("\n\n Instruction counts considering call graph"); - print_instruction_counts( - "Function Name", - raw_counts, - top_n, - strip_hashes, - Some(&exclude_view), - ); - } - - let mut raw_counts: Vec<(String, usize)> = - counts_without_callgraph.iter().map(|(key, value)| (key.clone(), *value)).collect(); - raw_counts.sort_by(|a, b| b.1.cmp(&a.1)); - if !no_raw_counts { - println!("\n\n Instruction counts ignoring call graph"); - print_instruction_counts( - "Function Name", - raw_counts, - top_n, - strip_hashes, - Some(&exclude_view), - ); - } - - let mut raw_counts: Vec<(String, usize)> = filtered_stack_counts - .iter() - .map(|(stack, count)| { - let numbered_stack = stack - .iter() - .rev() - .enumerate() - .map(|(index, line)| { - let modified_line = - if strip_hashes { strip_hash(line) } else { line.clone() }; - format!("({}) {}", index + 1, modified_line) - }) - .collect::>() - .join("\n"); - (numbered_stack, *count) - }) - .collect(); - - raw_counts.sort_by(|a, b| b.1.cmp(&a.1)); - if let Some(f) = function_name { - println!("\n\n Stack patterns for function '{f}' "); - print_instruction_counts("Function Stack", raw_counts, top_n, strip_hashes, None); - } - Ok(()) - } -} diff --git a/crates/core/executor/Cargo.toml b/crates/core/executor/Cargo.toml index 3c09170d6f..e3bf88b8f7 100644 --- a/crates/core/executor/Cargo.toml +++ b/crates/core/executor/Cargo.toml @@ -43,9 +43,23 @@ vec_map = { version = "0.8.2", features = ["serde"] } enum-map = { version = "2.7.3", features = ["serde"] } test-artifacts = { 
workspace = true, optional = true }
 
+# profiling
+goblin = { version = "0.9", optional = true }
+rustc-demangle = { version = "0.1.18", optional = true }
+gecko_profile = { version = "0.4.0", optional = true }
+indicatif = { version = "0.17.8", optional = true }
+serde_json = { version = "1.0.121", optional = true }
+
 [dev-dependencies]
 sp1-zkvm = { workspace = true }
 
 [features]
 programs = ["dep:test-artifacts"]
 bigint-rug = ["sp1-curves/bigint-rug"]
+profiling = [
+  "dep:goblin",
+  "dep:rustc-demangle",
+  "dep:gecko_profile",
+  "dep:indicatif",
+  "dep:serde_json",
+]
diff --git a/crates/core/executor/src/executor.rs b/crates/core/executor/src/executor.rs
index a9e0834a45..fb088799a7 100644
--- a/crates/core/executor/src/executor.rs
+++ b/crates/core/executor/src/executor.rs
@@ -1,9 +1,3 @@
-use std::{
-    fs::File,
-    io::{BufWriter, Write},
-    sync::Arc,
-};
-
 use hashbrown::HashMap;
 use serde::{Deserialize, Serialize};
 use sp1_stark::SP1CoreOpts;
@@ -26,6 +20,13 @@ use crate::{
     Instruction, Opcode, Program, Register,
 };
 
+#[cfg(feature = "profiling")]
+use crate::profiler::Profiler;
+#[cfg(feature = "profiling")]
+use std::{fs::File, io::BufWriter};
+
+use std::sync::Arc;
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 /// Whether to verify deferred proofs during execution.
 pub enum DeferredProofVerification {
@@ -110,8 +111,9 @@ pub struct Executor<'a> {
     /// A buffer for stdout and stderr IO.
     pub io_buf: HashMap<u32, String>,
 
-    /// A buffer for writing trace events to a file.
-    pub trace_buf: Option<BufWriter<File>>,
+    /// The ZKVM profiler.
+    #[cfg(feature = "profiling")]
+    pub profiler: Option<(Profiler, BufWriter<File>)>,
 
     /// The state of the runtime when in unconstrained mode.
     pub unconstrained_state: ForkState,
@@ -190,11 +192,66 @@ impl<'a> Executor<'a> {
         Self::with_context(program, opts, SP1Context::default())
     }
 
+    /// Create a new runtime for the program, and set up the profiler if the `TRACE_FILE` env var
+    /// is set and the feature flag `profiling` is enabled.
+    ///
+    /// The optional `TRACE_SAMPLE_RATE` env var sets the number of instructions between two
+    /// samples. It defaults to 1, and any value that is not a positive integer falls back to 1.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `elf_bytes` is not a valid program, or, when profiling is requested, if the
+    /// `TRACE_FILE` cannot be created or the profiler cannot be built from `elf_bytes`.
+    #[must_use]
+    pub fn with_context_and_elf(
+        opts: SP1CoreOpts,
+        context: SP1Context<'a>,
+        elf_bytes: &[u8],
+    ) -> Self {
+        let program = Program::from(elf_bytes).expect("Failed to create program from ELF bytes");
+
+        #[cfg(not(feature = "profiling"))]
+        return Self::with_context(program, opts, context);
+
+        #[cfg(feature = "profiling")]
+        {
+            let mut this = Self::with_context(program, opts, context);
+
+            // Profiling is opt-in at runtime: without `TRACE_FILE` no profiler is attached.
+            if let Ok(path) = std::env::var("TRACE_FILE") {
+                let file = File::create(&path)
+                    .unwrap_or_else(|err| panic!("Failed to create TRACE_FILE `{path}`: {err}"));
+                println!("Profiling enabled");
+
+                // The rate is used as a divisor by the profiler, so 0 is rejected along with
+                // anything that does not parse.
+                let sample_rate = match std::env::var("TRACE_SAMPLE_RATE") {
+                    Ok(rate) => match rate.parse::<u64>() {
+                        Ok(parsed) if parsed > 0 => parsed,
+                        _ => {
+                            tracing::warn!(
+                                "Invalid TRACE_SAMPLE_RATE `{rate}`, sampling every instruction"
+                            );
+                            1
+                        }
+                    },
+                    Err(_) => 1,
+                };
+                println!("Profiling sample rate: {sample_rate}");
+
+                this.profiler = Some((
+                    Profiler::new(elf_bytes, sample_rate).expect("Failed to create profiler"),
+                    BufWriter::new(file),
+                ));
+            }
+
+            this
+        }
+    }
+
     /// Create a new runtime from a program, options, and a context.
     ///
-    /// # Panics
-    ///
-    /// This function may panic if it fails to create the trace file if `TRACE_FILE` is set.
+    /// Note: This function *will not* set up the profiler.
     #[must_use]
     pub fn with_context(program: Program, opts: SP1CoreOpts, context: SP1Context<'a>) -> Self {
         // Create a shared reference to the program.
@@ -203,14 +260,6 @@ impl<'a> Executor<'a> {
         // Create a default record with the program.
         let record = ExecutionRecord::new(program.clone());
 
-        // If `TRACE_FILE`` is set, initialize the trace buffer.
-        let trace_buf = if let Ok(trace_file) = std::env::var("TRACE_FILE") {
-            let file = File::create(trace_file).unwrap();
-            Some(BufWriter::new(file))
-        } else {
-            None
-        };
-
         // Determine the maximum number of cycles for any syscall.
let syscall_map = default_syscall_map(); let max_syscall_cycles = @@ -230,7 +265,8 @@ impl<'a> Executor<'a> { shard_batch_size: opts.shard_batch_size as u32, cycle_tracker: HashMap::new(), io_buf: HashMap::new(), - trace_buf, + #[cfg(feature = "profiling")] + profiler: None, unconstrained: false, unconstrained_state: ForkState::default(), syscall_map, @@ -1190,7 +1226,6 @@ impl<'a> Executor<'a> { let instruction = self.fetch(); // Log the current state of the runtime. - #[cfg(debug_assertions)] self.log(&instruction); // Execute the instruction. @@ -1466,6 +1501,12 @@ impl<'a> Executor<'a> { self.executor_mode = ExecutorMode::Simple; self.print_report = true; while !self.execute()? {} + + #[cfg(feature = "profiling")] + if let Some((profiler, writer)) = self.profiler.take() { + profiler.write(writer).expect("Failed to write profile to output file"); + } + Ok(()) } @@ -1478,6 +1519,12 @@ impl<'a> Executor<'a> { self.executor_mode = ExecutorMode::Trace; self.print_report = true; while !self.execute()? {} + + #[cfg(feature = "profiling")] + if let Some((profiler, writer)) = self.profiler.take() { + profiler.write(writer).expect("Failed to write profile to output file"); + } + Ok(()) } @@ -1576,11 +1623,6 @@ impl<'a> Executor<'a> { } } - // Flush trace buf - if let Some(ref mut buf) = self.trace_buf { - buf.flush().unwrap(); - } - // Ensure that all proofs and input bytes were read, otherwise warn the user. if self.state.proof_stream_ptr != self.state.proof_stream.len() { tracing::warn!( @@ -1648,12 +1690,11 @@ impl<'a> Executor<'a> { } #[inline] - #[cfg(debug_assertions)] fn log(&mut self, _: &Instruction) { - // Write the current program counter to the trace buffer for the cycle tracer. 
-        if let Some(ref mut buf) = self.trace_buf {
+        #[cfg(feature = "profiling")]
+        if let Some((ref mut profiler, _)) = self.profiler {
             if !self.unconstrained {
-                buf.write_all(&u32::to_be_bytes(self.state.pc)).unwrap();
+                profiler.record(self.state.global_clk, self.state.pc as u64);
             }
         }
 
diff --git a/crates/core/executor/src/lib.rs b/crates/core/executor/src/lib.rs
index a4b6a06ced..a1bccb45cf 100644
--- a/crates/core/executor/src/lib.rs
+++ b/crates/core/executor/src/lib.rs
@@ -29,6 +29,8 @@ mod instruction;
 mod io;
 mod memory;
 mod opcode;
+#[cfg(feature = "profiling")]
+mod profiler;
 mod program;
 #[cfg(any(test, feature = "programs"))]
 pub mod programs;
diff --git a/crates/core/executor/src/profiler.rs b/crates/core/executor/src/profiler.rs
new file mode 100644
index 0000000000..52abc8393f
--- /dev/null
+++ b/crates/core/executor/src/profiler.rs
@@ -0,0 +1,248 @@
+use gecko_profile::{Frame, ProfileBuilder, StringIndex, ThreadBuilder};
+use goblin::elf::{sym::STT_FUNC, Elf};
+use indicatif::{ProgressBar, ProgressStyle};
+use rustc_demangle::demangle;
+use std::collections::HashMap;
+
+/// Errors that can occur while building a [`Profiler`] or writing out its profile.
+#[derive(Debug, thiserror::Error)]
+pub enum ProfilerError {
+    #[error("Failed to write profile output {}", .0)]
+    Io(#[from] std::io::Error),
+    #[error("Failed to parse ELF file {}", .0)]
+    Elf(#[from] goblin::error::Error),
+    #[error("Failed to serialize samples {}", .0)]
+    Serde(#[from] serde_json::Error),
+}
+
+/// The ZKVM Profiler.
+///
+/// During execution, the profiler always keeps track of the callstack
+/// and will occasionally save the stack according to the sample rate.
+pub struct Profiler {
+    /// The number of cycles between two samples, always at least 1.
+    sample_rate: u64,
+    /// `start_address` -> index in `function_ranges`
+    start_lookup: HashMap<u64, usize>,
+    /// The start and end address of every function, together with its interned name.
+    function_ranges: Vec<(u64, u64, Frame)>,
+
+    /// The current known call stack.
+    function_stack: Vec<Frame>,
+    /// Indices into `function_ranges` of the functions on the stack, used for a quick search
+    /// so that recursive calls are not counted.
+    function_stack_indices: Vec<usize>,
+    /// The code ranges of the functions on the stack, used for keeping track of unwinds.
+    function_stack_ranges: Vec<(u64, u64)>,
+    /// The code range of the deepest function on the stack.
+    current_function_range: (u64, u64),
+
+    /// The interned name of `main`, if the ELF contains such a function symbol.
+    main_idx: Option<StringIndex>,
+    builder: ThreadBuilder,
+    samples: Vec<Sample>,
+}
+
+/// The call stack captured at one sampling point.
+struct Sample {
+    stack: Vec<Frame>,
+}
+
+impl Profiler {
+    /// Builds a profiler from the function symbols found in `elf_bytes`.
+    ///
+    /// `sample_rate` is the number of cycles between two samples. A rate of 0 is treated as 1,
+    /// because sampling takes the clock modulo the rate.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`ProfilerError::Elf`] if `elf_bytes` cannot be parsed as an ELF file.
+    pub(super) fn new(elf_bytes: &[u8], sample_rate: u64) -> Result<Self, ProfilerError> {
+        let elf = Elf::parse(elf_bytes)?;
+
+        let mut start_lookup = HashMap::new();
+        let mut function_ranges = Vec::new();
+        let mut builder = ThreadBuilder::new(1, 0, std::time::Instant::now(), false, false);
+
+        // We need to extract all the functions from the elf file
+        // and their corresponding PC ranges.
+        let mut main_idx = None;
+        for sym in &elf.syms {
+            // check if its a function
+            if sym.st_type() == STT_FUNC {
+                let name = elf.strtab.get_at(sym.st_name).unwrap_or("");
+                let demangled_name = demangle(name);
+                let size = sym.st_size;
+                let start_address = sym.st_value;
+                // Saturate rather than overflow: a symbol smaller than 4 bytes at a low address
+                // would otherwise panic in debug builds, or wrap around in release builds to a
+                // range covering almost every address.
+                let end_address = start_address.saturating_add(size).saturating_sub(4);
+
+                // Now that we have the name lets immediately intern it so we only need to copy
+                // around a usize
+                let demangled_name = demangled_name.to_string();
+                let string_idx = builder.intern_string(&demangled_name);
+                if main_idx.is_none() && demangled_name == "main" {
+                    main_idx = Some(string_idx);
+                }
+
+                let start_idx = function_ranges.len();
+                function_ranges.push((start_address, end_address, Frame::Label(string_idx)));
+                start_lookup.insert(start_address, start_idx);
+            }
+        }
+
+        Ok(Self {
+            builder,
+            main_idx,
+            // `record` computes `clk % sample_rate`, so a rate of 0 would panic.
+            sample_rate: sample_rate.max(1),
+            samples: Vec::new(),
+            start_lookup,
+            function_ranges,
+            function_stack: Vec::new(),
+            function_stack_indices: Vec::new(),
+            function_stack_ranges: Vec::new(),
+            current_function_range: (0, 0),
+        })
+    }
+
+    /// Updates the tracked call stack for the instruction at `pc` and, whenever `clk` is a
+    /// multiple of the sample rate, stores a copy of the stack as a sample.
+    pub(super) fn record(&mut self, clk: u64, pc: u64) {
+        // The stack only changes when execution leaves the current function.
+        let (current_start, current_end) = self.current_function_range;
+        if pc <= current_start || pc > current_end {
+            self.update_stack(pc);
+        }
+
+        if clk % self.sample_rate == 0 {
+            self.samples.push(Sample { stack: self.function_stack.clone() });
+        }
+    }
+
+    /// Pushes onto or unwinds the call stack for a `pc` outside of the current function.
+    fn update_stack(&mut self, pc: u64) {
+        // Jump to a new function (or the same one).
+        if let Some(&idx) = self.start_lookup.get(&pc) {
+            // Jump to a new function (not recursive).
+            if !self.function_stack_indices.contains(&idx) {
+                let (start, end, name) = &self.function_ranges[idx];
+                self.current_function_range = (*start, *end);
+                self.function_stack_indices.push(idx);
+                self.function_stack_ranges.push((*start, *end));
+                self.function_stack.push(name.clone());
+            }
+            return;
+        }
+
+        // This means pc now points to an instruction that is
+        //
+        // 1. not in the current function's range
+        // 2. not a new function call
+        //
+        // We now account for a new possibility where we're returning to a function in the
+        // stack this need not be the immediate parent and can be any of the existing
+        // functions in the stack due to some optimizations that the compiler can make.
+        let unwind_point =
+            self.function_stack_ranges.iter().position(|&(start, end)| pc > start && pc <= end);
+
+        // Unwinding until the parent.
+        if let Some(unwind_point) = unwind_point {
+            self.function_stack.truncate(unwind_point + 1);
+            self.function_stack_ranges.truncate(unwind_point + 1);
+            self.function_stack_indices.truncate(unwind_point + 1);
+            // The function we unwound to is now the deepest one; without this the fast path in
+            // `record` would keep testing the range of a function that was just popped.
+            self.current_function_range = self.function_stack_ranges[unwind_point];
+        }
+
+        // If no unwind point has been found, that means we jumped to some random location,
+        // so the stack is left as it is.
+    }
+
+    /// Write the captured samples so far to the `std::io::Write`. This will output a JSON gecko
+    /// profile.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`ProfilerError::Serde`] if the profile cannot be serialized, and
+    /// [`ProfilerError::Io`] if the writer cannot be flushed.
+    pub(super) fn write(mut self, mut writer: impl std::io::Write) -> Result<(), ProfilerError> {
+        self.check_samples();
+
+        let start_time = std::time::Instant::now();
+        let mut profile_builder = ProfileBuilder::new(
+            start_time,
+            std::time::SystemTime::now(),
+            "SP1 ZKVM",
+            0,
+            std::time::Duration::from_micros(1),
+        );
+
+        let pb = ProgressBar::new(self.samples.len() as u64);
+        pb.set_style(
+            ProgressStyle::default_bar()
+                .template(
+                    "{msg} \n {spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
+                )
+                .unwrap()
+                .progress_chars("#>-"),
+        );
+
+        pb.set_message("Creating profile");
+
+        // We don't have a way to know the duration of each sample, so we just use 1us for
+        // all instructions, which makes one sample span `sample_rate` microseconds.
+        let sample_duration = std::time::Duration::from_micros(self.sample_rate);
+        let mut last_known_time = std::time::Instant::now();
+        for sample in self.samples.drain(..) {
+            pb.inc(1);
+
+            self.builder.add_sample(last_known_time, sample.stack.into_iter(), sample_duration);
+
+            last_known_time += sample_duration;
+        }
+
+        profile_builder.add_thread(self.builder);
+
+        pb.finish();
+
+        println!("Writing profile, this can take awhile");
+        serde_json::to_writer(&mut writer, &profile_builder.to_serializable())?;
+        // Flush explicitly: a buffered writer that is merely dropped swallows flush errors,
+        // which would silently leave a truncated profile behind.
+        std::io::Write::flush(&mut writer)?;
+        println!("Profile written successfully");
+
+        Ok(())
+    }
+
+    /// Simple check to make sure we have a valid `main` function that lasts for most of the
+    /// execution time.
+    fn check_samples(&self) {
+        let Some(main_idx) = self.main_idx else {
+            eprintln!("Warning: The `main` function is not present in the Elf file, this is likely caused by using the wrong Elf file");
+            return;
+        };
+
+        // Without samples the ratio below would be 0 / 0, so there is nothing to check.
+        if self.samples.is_empty() {
+            return;
+        }
+
+        let main_count = self
+            .samples
+            .iter()
+            .filter(|s| s.stack.iter().any(|f| matches!(f, Frame::Label(idx) if *idx == main_idx)))
+            .count();
+
+        // The ratio is only compared against a coarse threshold, so precision loss is fine.
+        #[allow(clippy::cast_precision_loss)]
+        let main_ratio = main_count as f64 / self.samples.len() as f64;
+        if main_ratio < 0.9 {
+            eprintln!("Warning: This trace appears to be invalid. The `main` function is present in only {:.2}% of the samples, this is likely caused by the using the wrong Elf file", main_ratio * 100.0);
+        }
+    }
+}
diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs
index f72f7c9df7..9a8d6d6017 100644
--- a/crates/prover/src/lib.rs
+++ b/crates/prover/src/lib.rs
@@ -273,9 +273,9 @@ impl SP1Prover {
         mut context: SP1Context<'a>,
     ) -> Result<(SP1PublicValues, ExecutionReport), ExecutionError> {
         context.subproof_verifier.replace(Arc::new(self));
-        let program = self.get_program(elf).unwrap();
         let opts = SP1CoreOpts::default();
-        let mut runtime = Executor::with_context(program, opts, context);
+        let mut runtime = Executor::with_context_and_elf(opts, context, elf);
+
         runtime.write_vecs(&stdin.buffer);
         for (proof, vkey) in stdin.proofs.iter() {
             runtime.write_proof(proof.clone(), vkey.clone());
diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml
index a55b6caf39..3a8293efc4 100644
--- a/crates/sdk/Cargo.toml
+++ b/crates/sdk/Cargo.toml
@@ -87,6 +87,8 @@ network-v2 = [
 ]
 cuda = ["sp1-cuda"]
 
+profiling = ["sp1-core-executor/profiling"]
+
 [build-dependencies]
 vergen = { version = "8", default-features = false, features = [
   "build",
diff --git a/examples/Cargo.lock b/examples/Cargo.lock
index 53dd96cfd5..05d23d23a0 100644
--- a/examples/Cargo.lock
+++ b/examples/Cargo.lock
@@ -219,7 +219,7 @@ dependencies = [
  "alloy-sol-types",
  "serde",
  "serde_json",
- "thiserror",
+ "thiserror 1.0.68",
  "tracing",
 ]
 
@@ -241,7 +241,7 @@ dependencies = [
  "async-trait",
  "auto_impl",
  "futures-utils-wasm",
- "thiserror",
+ "thiserror 1.0.68",
 ]
 
 [[package]]
@@ -426,7 +426,7 @@ dependencies = [
  "auto_impl",
  "elliptic-curve",
  "k256",
- "thiserror",
+ "thiserror 1.0.68",
 ]
 
 [[package]]
@@ -442,7 +442,7 @@ dependencies = [
  "async-trait",
  "k256",
  "rand 0.8.5",
- "thiserror",
+ "thiserror 1.0.68",
 ]
 
 [[package]]
@@ -1155,7 +1155,7 @@ dependencies = [
  "semver 1.0.23",
  "serde",
  "serde_json",
- "thiserror",
+ "thiserror 1.0.68",
 ]
 
 [[package]]
@@ -1636,6
+1636,15 @@ dependencies = [ "rustversion", ] +[[package]] +name = "debugid" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef552e6f588e446098f6ba40d89ac146c8c7b64aade83c051ee00bb5d2bc18d" +dependencies = [ + "uuid", +] + [[package]] name = "der" version = "0.5.1" @@ -1834,7 +1843,7 @@ dependencies = [ "rand_core 0.6.4", "serde", "sha2 0.9.9", - "thiserror", + "thiserror 1.0.68", "zeroize", ] @@ -2214,6 +2223,17 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d758ba1b47b00caf47f24925c0074ecb20d6dfcffe7f6d53395c0465674841a" +[[package]] +name = "gecko_profile" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "890852c7e1e02bc6758e325d6b1e0236e4fbf21b492f585ce4d4715be54b4c6a" +dependencies = [ + "debugid", + "serde", + "serde_json", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -2271,6 +2291,17 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "goblin" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b363a30c165f666402fe6a3024d3bec7ebc898f96a4a23bd1c99f8dbf3f4f47" +dependencies = [ + "log", + "plain", + "scroll", +] + [[package]] name = "groth16-verifier-program" version = "1.1.0" @@ -3846,7 +3877,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "879952a81a83930934cbf1786752d6dedc3b1f29e8f8fb2ad1d0a36f377cf442" dependencies = [ "memchr", - "thiserror", + "thiserror 1.0.68", "ucd-trie", ] @@ -3911,6 +3942,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum 
= "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + [[package]] name = "portable-atomic" version = "1.9.0" @@ -4106,7 +4143,7 @@ dependencies = [ "rustc-hash 2.0.0", "rustls", "socket2", - "thiserror", + "thiserror 1.0.68", "tokio", "tracing", ] @@ -4123,7 +4160,7 @@ dependencies = [ "rustc-hash 2.0.0", "rustls", "slab", - "thiserror", + "thiserror 1.0.68", "tinyvec", "tracing", ] @@ -4289,7 +4326,7 @@ checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", - "thiserror", + "thiserror 1.0.68", ] [[package]] @@ -4407,7 +4444,7 @@ dependencies = [ "http", "reqwest", "serde", - "thiserror", + "thiserror 1.0.68", "tower-service", ] @@ -4420,7 +4457,7 @@ dependencies = [ "reth-execution-errors", "reth-primitives", "reth-storage-errors", - "thiserror", + "thiserror 1.0.68", ] [[package]] @@ -4514,7 +4551,7 @@ dependencies = [ "reth-execution-errors", "reth-fs-util", "reth-storage-errors", - "thiserror", + "thiserror 1.0.68", ] [[package]] @@ -4598,7 +4635,7 @@ dependencies = [ "reth-revm", "revm", "revm-primitives", - "thiserror", + "thiserror 1.0.68", "tracing", ] @@ -4637,7 +4674,7 @@ source = "git+https://github.com/sp1-patches/reth?tag=rsp-20240830#260c7ed2c9374 dependencies = [ "serde", "serde_json", - "thiserror", + "thiserror 1.0.68", ] [[package]] @@ -4649,7 +4686,7 @@ dependencies = [ "alloy-rlp", "enr", "serde_with", - "thiserror", + "thiserror 1.0.68", "url", ] @@ -4706,7 +4743,7 @@ dependencies = [ "reth-trie-common", "revm-primitives", "serde", - "thiserror", + "thiserror 1.0.68", ] [[package]] @@ -4741,7 +4778,7 @@ dependencies = [ "modular-bitfield", "reth-codecs", "serde", - "thiserror", + "thiserror 1.0.68", ] [[package]] @@ -5090,7 +5127,7 @@ dependencies = [ "rlp", "rsp-primitives", "serde", - "thiserror", + "thiserror 1.0.68", ] [[package]] @@ -5337,6 +5374,26 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "scroll" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ab8598aa408498679922eff7fa985c25d58a90771bd6be794434c5277eab1a6" +dependencies = [ + "scroll_derive", +] + +[[package]] +name = "scroll_derive" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f81c2fde025af7e69b1d1420531c8a8811ca898919db177141a85313b1cb932" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "sdd" version = "3.0.4" @@ -5697,8 +5754,11 @@ dependencies = [ "elf", "enum-map", "eyre", + "gecko_profile", + "goblin", "hashbrown 0.14.5", "hex", + "indicatif", "itertools 0.13.0", "log", "nohash-hasher", @@ -5707,13 +5767,15 @@ dependencies = [ "p3-maybe-rayon", "rand 0.8.5", "rrs-succinct", + "rustc-demangle", "serde", + "serde_json", "sp1-curves", "sp1-primitives", "sp1-stark", "strum", "strum_macros", - "thiserror", + "thiserror 1.0.68", "tiny-keccak", "tracing", "typenum", @@ -5758,7 +5820,7 @@ dependencies = [ "strum", "strum_macros", "tempfile", - "thiserror", + "thiserror 1.0.68", "tracing", "tracing-forest", "tracing-subscriber", @@ -5888,7 +5950,7 @@ dependencies = [ "sp1-recursion-core", "sp1-recursion-gnark-ffi", "sp1-stark", - "thiserror", + "thiserror 1.0.68", "tracing", "tracing-subscriber", ] @@ -5973,7 +6035,7 @@ dependencies = [ "sp1-primitives", "sp1-stark", "static_assertions", - "thiserror", + "thiserror 1.0.68", "tracing", "vec_map", "zkhash", @@ -6046,7 +6108,7 @@ dependencies = [ "strum", "strum_macros", "tempfile", - "thiserror", + "thiserror 1.0.68", "tokio", "tracing", "twirp-rs", @@ -6488,7 +6550,16 @@ version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02dd99dc800bbb97186339685293e1cc5d9df1f8fae2d0aecd9ff1c77efea892" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.68", +] + +[[package]] +name = 
"thiserror" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" +dependencies = [ + "thiserror-impl 2.0.3", ] [[package]] @@ -6502,6 +6573,17 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "thiserror-impl" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "thiserror-impl-no-std" version = "2.0.2" @@ -6758,7 +6840,7 @@ checksum = "ee40835db14ddd1e3ba414292272eddde9dad04d3d4b65509656414d1c42592f" dependencies = [ "ansi_term", "smallvec", - "thiserror", + "thiserror 1.0.68", "tracing", "tracing-subscriber", ] @@ -6814,7 +6896,7 @@ dependencies = [ "reqwest", "serde", "serde_json", - "thiserror", + "thiserror 1.0.68", "tokio", "tower", "url", @@ -6909,6 +6991,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" + [[package]] name = "valuable" version = "0.1.0"