From 4d4dd3de6097f21cca6d1c0e2dc12d298af5b255 Mon Sep 17 00:00:00 2001 From: mcalancea Date: Fri, 7 Feb 2025 12:37:40 +0200 Subject: [PATCH] `secp256k1` and `sha256_extend` mocks + refactor (#826) - all precompiles have the same interface as their counter-parts in `sp1`; the `secp256k1` curve ops are backed by the `secp` crate. - integrates appropriate dummy circuits for precompiles into the zkvm - added `utils.rs` with some utilities for safely manipulating memory segments of the VM - tested the compatibility of the syscalls (+ keccak, done previously) with sp1 (see the `syscalls.rs` example file). - extended the docs pasted from sp1 Aurel has reviewed an earlier version of this [here](https://github.com/Inversed-Tech/ceno/pull/4). --- Cargo.lock | 64 +++++ Cargo.toml | 1 + ceno_emul/Cargo.toml | 1 + ceno_emul/src/lib.rs | 12 +- ceno_emul/src/syscalls.rs | 36 ++- ceno_emul/src/syscalls/keccak_permute.rs | 106 ++++--- ceno_emul/src/syscalls/secp256k1.rs | 271 ++++++++++++++++++ ceno_emul/src/syscalls/sha256.rs | 78 +++++ ceno_emul/src/utils.rs | 76 +++++ ceno_host/tests/test_elf.rs | 213 +++++++++++++- ceno_rt/src/syscalls.rs | 120 +++++++- .../instructions/riscv/dummy/dummy_ecall.rs | 20 +- .../src/instructions/riscv/dummy/test.rs | 3 +- ceno_zkvm/src/instructions/riscv/rv32im.rs | 89 +++++- examples/examples/ceno_rt_keccak.rs | 4 +- examples/examples/secp256k1_add_syscall.rs | 49 ++++ .../examples/secp256k1_decompress_syscall.rs | 32 +++ examples/examples/secp256k1_double_syscall.rs | 43 +++ examples/examples/sha_extend_syscall.rs | 23 ++ examples/examples/syscalls.rs | 160 +++++++++++ 20 files changed, 1316 insertions(+), 85 deletions(-) create mode 100644 ceno_emul/src/syscalls/secp256k1.rs create mode 100644 ceno_emul/src/syscalls/sha256.rs create mode 100644 ceno_emul/src/utils.rs create mode 100644 examples/examples/secp256k1_add_syscall.rs create mode 100644 examples/examples/secp256k1_decompress_syscall.rs create mode 100644 examples/examples/secp256k1_double_syscall.rs create mode 100644 examples/examples/sha_extend_syscall.rs create mode 100644 examples/examples/syscalls.rs diff --git a/Cargo.lock b/Cargo.lock index 822faba4a..76fac0dfe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -156,6 +156,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + [[package]] name = "base64" version = "0.22.1" @@ -186,6 +192,22 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" +[[package]] +name = "bitcoin-io" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b47c4ab7a93edb0c7198c5535ed9b52b63095f4e9b45279c6736cec4b856baf" + +[[package]] +name = "bitcoin_hashes" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb18c03d0db0247e147a21a6faafd5a7eb851c743db062de72018b6b7e8e4d16" +dependencies = [ + "bitcoin-io", + "hex-conservative", +] + [[package]] name = "bitflags" version = "2.6.0" @@ -295,6 +317,7 @@ dependencies = [ "num-derive", "num-traits", "rrs-succinct", + "secp", "strum", "strum_macros", "tiny-keccak", @@ -917,6 +940,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hex-conservative" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5313b072ce3c597065a808dbf612c4c8e8590bdbf8b579508bf7a762c5eae6cd" +dependencies = [ + "arrayvec", +] + [[package]] name = "humantime" version = "2.1.0" @@ -1893,6 +1925,38 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "secp" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85ed54b1141d8cec428d8a4abf01282755ba4e4c8a621dd23fa2e0ed761814c2" +dependencies = [ + "base16ct", + "once_cell", + "secp256k1", + "subtle", +] + +[[package]] +name = "secp256k1" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b50c5943d326858130af85e049f2661ba3c78b26589b8ab98e65e80ae44a1252" +dependencies = [ + "bitcoin_hashes", + "rand", + "secp256k1-sys", +] + +[[package]] +name = "secp256k1-sys" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4387882333d3aa8cb20530a17c69a3752e97837832f34f6dccc760e715001d9" +dependencies = [ + "cc", +] + [[package]] name = "serde" version = "1.0.217" diff --git a/Cargo.toml b/Cargo.toml index 2fb6dfde7..2d0b62b5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ rand_chacha = { version = "0.3", features = ["serde1"] } rand_core = "0.6" rand_xorshift = "0.3" rayon = "1.10" +secp = "0.4.1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" strum = "0.26" diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index 2bc4c830f..6478a1597 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -17,6 +17,7 @@ itertools.workspace = true num-derive.workspace = true num-traits.workspace = true rrs_lib = { package = "rrs-succinct", version = "0.1.0" } +secp.workspace = true strum.workspace = true strum_macros.workspace = true tiny-keccak.workspace = true diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 5e9d871ef..60a4a3037 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -23,7 +23,17 @@ pub use elf::Program; pub mod disassemble; mod syscalls; -pub use syscalls::{KECCAK_PERMUTE, keccak_permute::KECCAK_WORDS}; +pub use syscalls::{ + KECCAK_PERMUTE, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SHA_EXTEND, SyscallSpec, + keccak_permute::{KECCAK_WORDS, KeccakSpec}, + secp256k1::{ + COORDINATE_WORDS, SECP256K1_ARG_WORDS, Secp256k1AddSpec, Secp256k1DecompressSpec, + Secp256k1DoubleSpec, + }, + sha256::{SHA_EXTEND_WORDS, Sha256ExtendSpec}, +}; + +pub mod utils; pub mod test_utils; diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index d5ca85402..d99bd9025 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -2,16 +2,32 @@ use crate::{RegIdx, Tracer, VMState, Word, WordAddr, WriteOp}; use anyhow::Result; pub mod keccak_permute; +pub mod secp256k1; +pub mod sha256; // Using the same function codes as sp1: // https://github.com/succinctlabs/sp1/blob/013c24ea2fa15a0e7ed94f7d11a7ada4baa39ab9/crates/core/executor/src/syscalls/code.rs -pub use ceno_rt::syscalls::KECCAK_PERMUTE; +pub use ceno_rt::syscalls::{ + KECCAK_PERMUTE, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SHA_EXTEND, +}; + +pub trait SyscallSpec { + const NAME: &'static str; + + const REG_OPS_COUNT: usize; + const MEM_OPS_COUNT: usize; + const CODE: u32; +} /// Trace the inputs and effects of a syscall. pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result { match function_code { KECCAK_PERMUTE => Ok(keccak_permute::keccak_permute(vm)), + SECP256K1_ADD => Ok(secp256k1::secp256k1_add(vm)), + SECP256K1_DOUBLE => Ok(secp256k1::secp256k1_double(vm)), + SECP256K1_DECOMPRESS => Ok(secp256k1::secp256k1_decompress(vm)), + SHA_EXTEND => Ok(sha256::extend(vm)), // TODO: introduce error types. _ => Err(anyhow::anyhow!("Unknown syscall: {}", function_code)), } @@ -22,6 +38,24 @@ pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result, pub reg_ops: Vec, + _marker: (), +} + +impl SyscallWitness { + fn new(mem_ops: Vec, reg_ops: Vec) -> SyscallWitness { + for (i, op) in mem_ops.iter().enumerate() { + assert_eq!( + op.addr, + mem_ops[0].addr + i, + "Dummy circuit expects that mem_ops addresses are consecutive." + ); + } + SyscallWitness { + mem_ops, + reg_ops, + _marker: (), + } + } } /// The effects of a syscall to apply on the VM. diff --git a/ceno_emul/src/syscalls/keccak_permute.rs b/ceno_emul/src/syscalls/keccak_permute.rs index 63aba9201..3748fc3ea 100644 --- a/ceno_emul/src/syscalls/keccak_permute.rs +++ b/ceno_emul/src/syscalls/keccak_permute.rs @@ -1,65 +1,87 @@ -use itertools::{Itertools, izip}; +use itertools::Itertools; use tiny_keccak::keccakf; -use crate::{Change, EmuContext, Platform, VMState, WORD_SIZE, WordAddr, WriteOp}; +use crate::{Change, EmuContext, Platform, VMState, Word, WriteOp, utils::MemoryView}; -use super::{SyscallEffects, SyscallWitness}; +use super::{SyscallEffects, SyscallSpec, SyscallWitness}; const KECCAK_CELLS: usize = 25; // u64 cells pub const KECCAK_WORDS: usize = KECCAK_CELLS * 2; // u32 words +pub struct KeccakSpec; + +impl SyscallSpec for KeccakSpec { + const NAME: &'static str = "KECCAK"; + + const REG_OPS_COUNT: usize = 2; + const MEM_OPS_COUNT: usize = KECCAK_WORDS; + const CODE: u32 = ceno_rt::syscalls::KECCAK_PERMUTE; +} + +/// Wrapper type for the keccak_permute argument that implements conversions +/// from and to VM word-representations according to the syscall spec +pub struct KeccakState(pub [u64; KECCAK_CELLS]); + +impl From<[Word; KECCAK_WORDS]> for KeccakState { + fn from(words: [Word; KECCAK_WORDS]) -> Self { + KeccakState( + words + .chunks_exact(2) + .map(|chunk| (chunk[0] as u64 | ((chunk[1] as u64) << 32))) + .collect_vec() + .try_into() + .expect("failed to parse words into [u64; 25]"), + ) + } +} + +impl From for [Word; KECCAK_WORDS] { + fn from(state: KeccakState) -> [Word; KECCAK_WORDS] { + state + .0 + .iter() + .flat_map(|&elem| [elem as u32, (elem >> 32) as u32]) + .collect_vec() + .try_into() + .unwrap() + } +} + /// Trace the execution of a Keccak permutation. /// /// Compatible with: /// https://github.com/succinctlabs/sp1/blob/013c24ea2fa15a0e7ed94f7d11a7ada4baa39ab9/crates/core/executor/src/syscalls/precompiles/keccak256/permute.rs -/// -/// TODO: test compatibility. pub fn keccak_permute(vm: &VMState) -> SyscallEffects { let state_ptr = vm.peek_register(Platform::reg_arg0()); - // Read the argument `state_ptr`. - let reg_ops = vec![WriteOp::new_register_op( - Platform::reg_arg0(), - Change::new(state_ptr, state_ptr), - 0, // Cycle set later in finalize(). - )]; - - let addrs = (state_ptr..) - .step_by(WORD_SIZE) - .take(KECCAK_WORDS) - .map(WordAddr::from) - .collect_vec(); - - // Read Keccak state. - let input = addrs - .iter() - .map(|&addr| vm.peek_memory(addr)) - .collect::>(); + // for compatibility with sp1 spec + assert_eq!(vm.peek_register(Platform::reg_arg1()), 0); - // Compute Keccak permutation. - let output = { - let mut state = [0_u64; KECCAK_CELLS]; - for (cell, (&lo, &hi)) in izip!(&mut state, input.iter().tuples()) { - *cell = lo as u64 | ((hi as u64) << 32); - } - - keccakf(&mut state); + // Read the argument `state_ptr`. + let reg_ops = vec![ + WriteOp::new_register_op( + Platform::reg_arg0(), + Change::new(state_ptr, state_ptr), + 0, // Cycle set later in finalize(). + ), + WriteOp::new_register_op( + Platform::reg_arg1(), + Change::new(0, 0), + 0, // Cycle set later in finalize(). + ), + ]; - state.into_iter().flat_map(|c| [c as u32, (c >> 32) as u32]) - }; + let mut state_view = MemoryView::::new(vm, state_ptr); + let mut state = KeccakState::from(state_view.words()); + keccakf(&mut state.0); + let output_words: [Word; KECCAK_WORDS] = state.into(); - // Write permuted state. - let mem_ops = izip!(addrs, input, output) - .map(|(addr, before, after)| WriteOp { - addr, - value: Change { before, after }, - previous_cycle: 0, // Cycle set later in finalize(). - }) - .collect_vec(); + state_view.write(output_words); + let mem_ops: Vec = state_view.mem_ops().to_vec(); assert_eq!(mem_ops.len(), KECCAK_WORDS); SyscallEffects { - witness: SyscallWitness { mem_ops, reg_ops }, + witness: SyscallWitness::new(mem_ops, reg_ops), next_pc: None, } } diff --git a/ceno_emul/src/syscalls/secp256k1.rs b/ceno_emul/src/syscalls/secp256k1.rs new file mode 100644 index 000000000..cb307827b --- /dev/null +++ b/ceno_emul/src/syscalls/secp256k1.rs @@ -0,0 +1,271 @@ +use crate::{Change, EmuContext, Platform, VMState, WORD_SIZE, Word, WriteOp, utils::MemoryView}; +use itertools::Itertools; +use secp::{self}; +use std::iter; + +use super::{SyscallEffects, SyscallSpec, SyscallWitness}; + +pub struct Secp256k1AddSpec; +pub struct Secp256k1DoubleSpec; +pub struct Secp256k1DecompressSpec; + +impl SyscallSpec for Secp256k1AddSpec { + const NAME: &'static str = "SECP256K1_ADD"; + + const REG_OPS_COUNT: usize = 2; + const MEM_OPS_COUNT: usize = 2 * SECP256K1_ARG_WORDS; + const CODE: u32 = ceno_rt::syscalls::SECP256K1_ADD; +} + +impl SyscallSpec for Secp256k1DoubleSpec { + const NAME: &'static str = "SECP256K1_DOUBLE"; + + const REG_OPS_COUNT: usize = 2; + const MEM_OPS_COUNT: usize = SECP256K1_ARG_WORDS; + const CODE: u32 = ceno_rt::syscalls::SECP256K1_DOUBLE; +} + +impl SyscallSpec for Secp256k1DecompressSpec { + const NAME: &'static str = "SECP256K1_DECOMPRESS"; + + const REG_OPS_COUNT: usize = 2; + const MEM_OPS_COUNT: usize = 2 * COORDINATE_WORDS; + const CODE: u32 = ceno_rt::syscalls::SECP256K1_DECOMPRESS; +} + +// A secp256k1 point in uncompressed form takes 64 bytes +pub const SECP256K1_ARG_WORDS: usize = 16; + +/// Wrapper type for a point on the secp256k1 curve that implements conversions +/// from and to VM word-representations according to the syscall spec +pub struct SecpPoint(pub secp::Point); + +impl From<[Word; SECP256K1_ARG_WORDS]> for SecpPoint { + fn from(words: [Word; SECP256K1_ARG_WORDS]) -> Self { + // Prepend the "tag" byte as expected by secp + let mut bytes = iter::once(4u8) + .chain(words.iter().flat_map(|word| word.to_le_bytes())) + .collect_vec(); + + // The call-site uses "little endian", while secp uses "big endian" + // We need to reverse the coordinate representations + + // Reverse X coordinate + bytes[1..33].reverse(); + // Reverse Y coordinate + bytes[33..].reverse(); + SecpPoint(secp::Point::from_slice(&bytes).unwrap()) + } +} + +impl From for [Word; SECP256K1_ARG_WORDS] { + fn from(point: SecpPoint) -> [Word; SECP256K1_ARG_WORDS] { + // reuse MaybePoint implementation + SecpMaybePoint(point.0.into()).into() + } +} + +/// Wrapper type for a maybe-point on the secp256k1 curve that implements conversions +/// from and to VM word-representations according to the syscall spec +pub struct SecpMaybePoint(pub secp::MaybePoint); + +impl From for [Word; SECP256K1_ARG_WORDS] { + fn from(maybe_point: SecpMaybePoint) -> [Word; SECP256K1_ARG_WORDS] { + let mut bytes: [u8; 64] = maybe_point.0.serialize_uncompressed()[1..] + .try_into() + .unwrap(); + // The call-site expects "little endian", while secp uses "big endian" + // We need to reverse the coordinate representations + + // Reverse X coordinate + bytes[..32].reverse(); + // Reverse Y coordinate + bytes[32..].reverse(); + bytes + .chunks_exact(4) + .map(|chunk| Word::from_le_bytes(chunk.try_into().unwrap())) + .collect_vec() + .try_into() + .unwrap() + } +} + +/// Trace the execution of a secp256k1_add call +pub fn secp256k1_add(vm: &VMState) -> SyscallEffects { + let p_ptr = vm.peek_register(Platform::reg_arg0()); + let q_ptr = vm.peek_register(Platform::reg_arg1()); + + // Read the argument pointers + let reg_ops = vec![ + WriteOp::new_register_op( + Platform::reg_arg0(), + Change::new(p_ptr, p_ptr), + 0, // Cycle set later in finalize(). + ), + WriteOp::new_register_op( + Platform::reg_arg1(), + Change::new(q_ptr, q_ptr), + 0, // Cycle set later in finalize(). + ), + ]; + + // Memory segments of P and Q + let [mut p_view, q_view] = + [p_ptr, q_ptr].map(|start| MemoryView::::new(vm, start)); + + // Read P and Q from words via wrapper type + let [p, q] = [&p_view, &q_view].map(|view| SecpPoint::from(view.words())); + + // Compute the sum and convert back to words + let sum = SecpMaybePoint(p.0 + q.0); + let output_words: [Word; SECP256K1_ARG_WORDS] = sum.into(); + + p_view.write(output_words); + + let mem_ops = p_view + .mem_ops() + .into_iter() + .chain(q_view.mem_ops()) + .collect_vec(); + + assert_eq!(mem_ops.len(), 2 * SECP256K1_ARG_WORDS); + SyscallEffects { + witness: SyscallWitness::new(mem_ops, reg_ops), + next_pc: None, + } +} + +/// Trace the execution of a secp256k1_double call +pub fn secp256k1_double(vm: &VMState) -> SyscallEffects { + let p_ptr = vm.peek_register(Platform::reg_arg0()); + + // for compatibility with sp1 spec + assert_eq!(vm.peek_register(Platform::reg_arg1()), 0); + + // Read the argument pointers + let reg_ops = vec![ + WriteOp::new_register_op( + Platform::reg_arg0(), + Change::new(p_ptr, p_ptr), + 0, // Cycle set later in finalize(). + ), + WriteOp::new_register_op( + Platform::reg_arg1(), + Change::new(0, 0), + 0, // Cycle set later in finalize(). + ), + ]; + + // P's memory segment + let mut p_view = MemoryView::::new(vm, p_ptr); + // Create point from words via wrapper type + let p = SecpPoint::from(p_view.words()); + + // Compute result and convert back into words + let result = SecpPoint(secp::Scalar::two() * p.0); + let output_words: [Word; SECP256K1_ARG_WORDS] = result.into(); + + p_view.write(output_words); + + let mem_ops = p_view.mem_ops().to_vec(); + + assert_eq!(mem_ops.len(), SECP256K1_ARG_WORDS); + SyscallEffects { + witness: SyscallWitness::new(mem_ops, reg_ops), + next_pc: None, + } +} + +pub const COORDINATE_WORDS: usize = SECP256K1_ARG_WORDS / 2; + +/// Wrapper type for a single coordinate of a point on the secp256k1 curve. +/// It implements conversions from and to VM word-representations according +/// to the spec of syscall +pub struct SecpCoordinate(pub [u8; COORDINATE_WORDS * WORD_SIZE]); + +impl From<[Word; COORDINATE_WORDS]> for SecpCoordinate { + fn from(words: [Word; COORDINATE_WORDS]) -> Self { + let bytes = (words.iter().flat_map(|word| word.to_le_bytes())) + .collect_vec() + .try_into() + .unwrap(); + SecpCoordinate(bytes) + } +} + +impl From for [Word; COORDINATE_WORDS] { + fn from(coord: SecpCoordinate) -> [Word; COORDINATE_WORDS] { + coord + .0 + .chunks_exact(4) + .map(|chunk| Word::from_le_bytes(chunk.try_into().unwrap())) + .collect_vec() + .try_into() + .unwrap() + } +} + +/// Trace the execution of a secp256k1_decompress call +pub fn secp256k1_decompress(vm: &VMState) -> SyscallEffects { + let ptr = vm.peek_register(Platform::reg_arg0()); + let y_is_odd = vm.peek_register(Platform::reg_arg1()); + + // Read the argument pointers + let reg_ops = vec![ + WriteOp::new_register_op( + Platform::reg_arg0(), + Change::new(ptr, ptr), + 0, // Cycle set later in finalize(). + ), + WriteOp::new_register_op( + Platform::reg_arg1(), + Change::new(y_is_odd, y_is_odd), + 0, // Cycle set later in finalize(). + ), + ]; + + // Memory segment of X coordinate + let input_view = MemoryView::::new(vm, ptr); + // Memory segment where Y coordinate will be written + let mut output_view = + MemoryView::::new(vm, ptr + (COORDINATE_WORDS * WORD_SIZE) as u32); + + let point = { + // Encode parity byte according to secp spec + let parity_byte = match y_is_odd { + 0 => 2, + 1 => 3, + _ => panic!("y_is_odd should be 0/1"), + }; + // Read bytes of the X coordinate + let coordinate_bytes = SecpCoordinate::from(input_view.words()).0; + // Prepend parity byte to complete compressed repr. + let bytes = iter::once(parity_byte) + .chain(coordinate_bytes.iter().cloned()) + .collect::>(); + + secp::Point::from_slice(&bytes).unwrap() + }; + + // Get uncompressed repr. of the point and extract the Y-coordinate bytes + // Y-coordinate is the second half after eliminating the "tag" byte + let y_bytes: [u8; 32] = point.serialize_uncompressed()[1..][32..] + .try_into() + .unwrap(); + + // Convert into words via the internal wrapper type + let output_words: [Word; COORDINATE_WORDS] = SecpCoordinate(y_bytes).into(); + + output_view.write(output_words); + + let y_mem_ops = output_view.mem_ops(); + let x_mem_ops = input_view.mem_ops(); + + let mem_ops = x_mem_ops.into_iter().chain(y_mem_ops).collect_vec(); + + assert_eq!(mem_ops.len(), 2 * COORDINATE_WORDS); + SyscallEffects { + witness: SyscallWitness::new(mem_ops, reg_ops), + next_pc: None, + } +} diff --git a/ceno_emul/src/syscalls/sha256.rs b/ceno_emul/src/syscalls/sha256.rs new file mode 100644 index 000000000..99cbff00c --- /dev/null +++ b/ceno_emul/src/syscalls/sha256.rs @@ -0,0 +1,78 @@ +use crate::{Change, EmuContext, Platform, VMState, Word, WriteOp, utils::MemoryView}; + +use super::{SyscallEffects, SyscallSpec, SyscallWitness}; + +pub const SHA_EXTEND_WORDS: usize = 64; // u64 cells + +pub struct Sha256ExtendSpec; + +impl SyscallSpec for Sha256ExtendSpec { + const NAME: &'static str = "SHA256_EXTEND"; + + const REG_OPS_COUNT: usize = 2; + const MEM_OPS_COUNT: usize = SHA_EXTEND_WORDS; + const CODE: u32 = ceno_rt::syscalls::SHA_EXTEND; +} + +/// Wrapper type for the sha_extend argument that implements conversions +/// from and to VM word-representations according to the syscall spec +pub struct ShaExtendWords(pub [Word; SHA_EXTEND_WORDS]); + +impl From<[Word; SHA_EXTEND_WORDS]> for ShaExtendWords { + fn from(value: [Word; SHA_EXTEND_WORDS]) -> Self { + ShaExtendWords(value) + } +} +impl From for [Word; SHA_EXTEND_WORDS] { + fn from(state: ShaExtendWords) -> [Word; SHA_EXTEND_WORDS] { + state.0 + } +} + +/// Based on: https://github.com/succinctlabs/sp1/blob/2aed8fea16a67a5b2983ffc471b2942c2f2512c8/crates/core/machine/src/syscall/precompiles/sha256/extend/mod.rs#L22 +pub fn sha_extend(w: &mut [u32]) { + for i in 16..64 { + let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3); + let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10); + w[i] = w[i - 16] + .wrapping_add(s0) + .wrapping_add(w[i - 7]) + .wrapping_add(s1); + // TODO: why doesn't sp1 use wrapping_add? + } +} + +pub fn extend(vm: &VMState) -> SyscallEffects { + let state_ptr = vm.peek_register(Platform::reg_arg0()); + + // for compatibility with sp1 spec + assert_eq!(vm.peek_register(Platform::reg_arg1()), 0); + + // Read the argument `state_ptr`. + let reg_ops = vec![ + WriteOp::new_register_op( + Platform::reg_arg0(), + Change::new(state_ptr, state_ptr), + 0, // Cycle set later in finalize(). + ), + WriteOp::new_register_op( + Platform::reg_arg1(), + Change::new(0, 0), + 0, // Cycle set later in finalize(). + ), + ]; + + let mut state_view = MemoryView::::new(vm, state_ptr); + let mut sha_extend_words = ShaExtendWords::from(state_view.words()); + sha_extend(&mut sha_extend_words.0); + let output_words: [Word; SHA_EXTEND_WORDS] = sha_extend_words.into(); + + state_view.write(output_words); + let mem_ops = state_view.mem_ops().to_vec(); + + assert_eq!(mem_ops.len(), SHA_EXTEND_WORDS); + SyscallEffects { + witness: SyscallWitness::new(mem_ops, reg_ops), + next_pc: None, + } +} diff --git a/ceno_emul/src/utils.rs b/ceno_emul/src/utils.rs new file mode 100644 index 000000000..0ca214176 --- /dev/null +++ b/ceno_emul/src/utils.rs @@ -0,0 +1,76 @@ +use itertools::{Itertools, izip}; + +use crate::{Change, EmuContext, VMState, WORD_SIZE, Word, WordAddr, WriteOp}; + +/// Utilities for reading/manipulating a memory segment of fixed length +pub struct MemoryView<'a, const LENGTH: usize> { + vm: &'a VMState, + start: WordAddr, + writes: Option<[Word; LENGTH]>, +} + +impl<'a, const LENGTH: usize> MemoryView<'a, LENGTH> { + /// Creates a new memory segment view + /// Asserts that `start` is a multiple of `WORD_SIZE` + pub fn new(vm: &'a VMState, start: u32) -> Self { + assert!(start % WORD_SIZE as u32 == 0); + // TODO: do we need stricter alignment requirements for keccak (u64 array) + MemoryView { + vm, + start: WordAddr::from(start), + writes: None, + } + } + + pub fn iter_addrs(&self) -> impl Iterator { + (self.start..).take(LENGTH) + } + + pub fn addrs(&self) -> [WordAddr; LENGTH] { + self.iter_addrs().collect_vec().try_into().unwrap() + } + + pub fn iter_words(&self) -> impl Iterator + '_ { + self.iter_addrs().map(|addr| self.vm.peek_memory(addr)) + } + + pub fn words(&self) -> [Word; LENGTH] { + self.iter_words().collect_vec().try_into().unwrap() + } + + pub fn iter_bytes(&self) -> impl Iterator + '_ { + self.iter_words().flat_map(|word| word.to_le_bytes()) + } + + pub fn bytes(&self) -> Vec { + self.iter_bytes().collect_vec() + } + + pub fn write(&mut self, writes: [Word; LENGTH]) { + assert!(self.writes.is_none(), "view can only be written once"); + self.writes = Some(writes); + } + + pub fn mem_ops(&self) -> [WriteOp; LENGTH] { + izip!( + self.addrs(), + self.words(), + self.writes.unwrap_or(self.words()) + ) + .map(|(addr, before, after)| WriteOp { + addr, + value: Change { before, after }, + previous_cycle: 0, // Cycle set later in finalize(). + }) + .collect_vec() + .try_into() + .unwrap() + } + + pub fn debug(&self) { + dbg!(self.start, LENGTH); + dbg!(self.addrs()); + dbg!(self.words()); + dbg!(self.bytes()); + } +} diff --git a/ceno_host/tests/test_elf.rs b/ceno_host/tests/test_elf.rs index 97fcd751b..4268a6f0e 100644 --- a/ceno_host/tests/test_elf.rs +++ b/ceno_host/tests/test_elf.rs @@ -2,8 +2,8 @@ use std::{collections::BTreeSet, iter::from_fn, sync::Arc}; use anyhow::Result; use ceno_emul::{ - CENO_PLATFORM, EmuContext, InsnKind, Platform, Program, StepRecord, VMState, WORD_SIZE, - host_utils::read_all_messages, + CENO_PLATFORM, COORDINATE_WORDS, EmuContext, InsnKind, Platform, Program, SECP256K1_ARG_WORDS, + SHA_EXTEND_WORDS, StepRecord, VMState, WORD_SIZE, WordAddr, host_utils::read_all_messages, }; use ceno_host::CenoStdin; use itertools::{Itertools, enumerate, izip}; @@ -248,8 +248,9 @@ fn test_ceno_rt_keccak() -> Result<()> { // Check the syscall effects. for (witness, expect) in izip!(syscalls, keccak_outs) { - assert_eq!(witness.reg_ops.len(), 1); + assert_eq!(witness.reg_ops.len(), 2); assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0()); + assert_eq!(witness.reg_ops[1].register_index(), Platform::reg_arg1()); assert_eq!(witness.mem_ops.len(), expect.len() * 2); let got = witness @@ -271,6 +272,212 @@ fn test_ceno_rt_keccak() -> Result<()> { Ok(()) } +fn bytes_to_words(bytes: [u8; 65]) -> [u32; 16] { + // ignore the tag byte (specific to the secp repr.) + let mut bytes: [u8; 64] = bytes[1..].try_into().unwrap(); + + // Reverse the order of bytes for each coordinate + bytes[0..32].reverse(); + bytes[32..].reverse(); + std::array::from_fn(|i| u32::from_le_bytes(bytes[4 * i..4 * (i + 1)].try_into().unwrap())) +} + +#[test] +fn test_secp256k1_add() -> Result<()> { + let program_elf = ceno_examples::secp256k1_add_syscall; + let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; + let steps = run(&mut state)?; + + let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + assert_eq!(syscalls.len(), 1); + + let witness = syscalls[0]; + assert_eq!(witness.reg_ops.len(), 2); + assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0()); + assert_eq!(witness.reg_ops[1].register_index(), Platform::reg_arg1()); + + let p_address = witness.reg_ops[0].value.after; + assert_eq!(p_address, witness.reg_ops[0].value.before); + let p_address: WordAddr = p_address.into(); + + let q_address = witness.reg_ops[1].value.after; + assert_eq!(q_address, witness.reg_ops[1].value.before); + let q_address: WordAddr = q_address.into(); + + const P_PLUS_Q: [u8; 65] = [ + 4, 188, 11, 115, 232, 35, 63, 79, 186, 163, 11, 207, 165, 64, 247, 109, 81, 125, 56, 83, + 131, 221, 140, 154, 19, 186, 109, 173, 9, 127, 142, 169, 219, 108, 17, 216, 218, 125, 37, + 30, 87, 86, 194, 151, 20, 122, 64, 118, 123, 210, 29, 60, 209, 138, 131, 11, 247, 157, 212, + 209, 123, 162, 111, 197, 70, + ]; + let expect = bytes_to_words(P_PLUS_Q); + + assert_eq!(witness.mem_ops.len(), 2 * SECP256K1_ARG_WORDS); + // Expect first half to consist of read/writes on P + for (i, write_op) in witness.mem_ops.iter().take(SECP256K1_ARG_WORDS).enumerate() { + assert_eq!(write_op.addr, p_address + i); + assert_eq!(write_op.value.after, expect[i]); + } + + // Expect second half to consist of reads on Q + for (i, write_op) in witness + .mem_ops + .iter() + .skip(SECP256K1_ARG_WORDS) + .take(SECP256K1_ARG_WORDS) + .enumerate() + { + assert_eq!(write_op.addr, q_address + i); + assert_eq!(write_op.value.after, write_op.value.before); + } + + Ok(()) +} + +#[test] +fn test_secp256k1_double() -> Result<()> { + let program_elf = ceno_examples::secp256k1_double_syscall; + let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; + + let steps = run(&mut state)?; + + let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + assert_eq!(syscalls.len(), 1); + + let witness = syscalls[0]; + assert_eq!(witness.reg_ops.len(), 2); + assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0()); + + let p_address = witness.reg_ops[0].value.after; + assert_eq!(p_address, witness.reg_ops[0].value.before); + let p_address: WordAddr = p_address.into(); + + const DOUBLE_P: [u8; 65] = [ + 4, 111, 137, 182, 244, 228, 50, 13, 91, 93, 34, 231, 93, 191, 248, 105, 28, 226, 251, 23, + 66, 192, 188, 66, 140, 44, 218, 130, 239, 101, 255, 164, 76, 202, 170, 134, 48, 127, 46, + 14, 9, 192, 64, 102, 67, 163, 33, 48, 157, 140, 217, 10, 97, 231, 183, 28, 129, 177, 185, + 253, 179, 135, 182, 253, 203, + ]; + let expect = bytes_to_words(DOUBLE_P); + + assert_eq!(witness.mem_ops.len(), SECP256K1_ARG_WORDS); + for (i, write_op) in witness.mem_ops.iter().enumerate() { + assert_eq!(write_op.addr, p_address + i); + assert_eq!(write_op.value.after, expect[i]); + } + + Ok(()) +} + +#[test] +fn test_secp256k1_decompress() -> Result<()> { + let program_elf = ceno_examples::secp256k1_decompress_syscall; + let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; + + let steps = run(&mut state)?; + + let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + assert_eq!(syscalls.len(), 1); + + let witness = syscalls[0]; + assert_eq!(witness.reg_ops.len(), 2); + assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0()); + assert_eq!(witness.reg_ops[1].register_index(), Platform::reg_arg1()); + + let x_address = witness.reg_ops[0].value.after; + assert_eq!(x_address, witness.reg_ops[0].value.before); + let x_address: WordAddr = x_address.into(); + // Y coordinate should be written immediately after X coordinate + // X coordinate takes "half an argument" of words + let y_address = x_address + SECP256K1_ARG_WORDS / 2; + + // Complete decompressed point (X and Y) + let mut decompressed: [u8; 65] = [ + 4, 180, 53, 9, 32, 85, 226, 220, 154, 20, 116, 218, 199, 119, 48, 44, 23, 45, 222, 10, 64, + 50, 63, 8, 121, 191, 244, 141, 0, 37, 117, 182, 133, 190, 160, 239, 131, 180, 166, 242, + 145, 107, 249, 24, 168, 27, 69, 86, 58, 86, 159, 10, 210, 164, 20, 152, 148, 67, 37, 222, + 234, 108, 57, 84, 148, + ]; + + decompressed[33..].reverse(); + + // Writes should cover the Y coordinate, i.e latter half of the repr + let expect = bytes_to_words(decompressed)[8..].to_vec(); + + assert_eq!(witness.mem_ops.len(), 2 * COORDINATE_WORDS); + // Reads on X + for (i, write_op) in witness.mem_ops.iter().take(COORDINATE_WORDS).enumerate() { + assert_eq!(write_op.addr, x_address + i); + assert_eq!(write_op.value.after, write_op.value.before); + } + + // Reads/writes on Y + for (i, write_op) in witness + .mem_ops + .iter() + .skip(COORDINATE_WORDS) + .take(COORDINATE_WORDS) + .enumerate() + { + assert_eq!(write_op.addr, y_address + i); + assert_eq!(write_op.value.after, expect[i]); + } + + Ok(()) +} + +#[test] +fn test_sha256_extend() -> Result<()> { + let program_elf = ceno_examples::sha_extend_syscall; + let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; + + let steps = run(&mut state)?; + let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + assert_eq!(syscalls.len(), 1); + + let witness = syscalls[0]; + assert_eq!(witness.reg_ops.len(), 2); + assert_eq!(witness.reg_ops[0].register_index(), Platform::reg_arg0()); + assert_eq!(witness.reg_ops[1].register_index(), Platform::reg_arg1()); + + let state_ptr = witness.reg_ops[0].value.after; + assert_eq!(state_ptr, witness.reg_ops[0].value.before); + let state_ptr: WordAddr = state_ptr.into(); + + let expected = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, 67559435, 1711661200, + 3020350282, 1447362251, 3118632270, 4004188394, 690615167, 6070360, 1105370215, 2385558114, + 2348232513, 507799627, 2098764358, 5845374, 823657968, 2969863067, 3903496557, 4274682881, + 2059629362, 1849247231, 2656047431, 835162919, 2096647516, 2259195856, 1779072524, + 3152121987, 4210324067, 1557957044, 376930560, 982142628, 3926566666, 4164334963, + 789545383, 1028256580, 2867933222, 3843938318, 1135234440, 390334875, 2025924737, + 3318322046, 3436065867, 652746999, 4261492214, 2543173532, 3334668051, 3166416553, + 634956631, + ]; + + assert_eq!(witness.mem_ops.len(), SHA_EXTEND_WORDS); + + for (i, write_op) in witness.mem_ops.iter().enumerate() { + assert_eq!(write_op.addr, state_ptr + i); + assert_eq!(write_op.value.after, expected[i]); + if i < 16 { + // sanity check: first 16 entries remain unchanged + assert_eq!(write_op.value.before, write_op.value.after); + } + } + + Ok(()) +} + +#[test] +fn test_syscalls_compatibility() -> Result<()> { + let program_elf = ceno_examples::syscalls; + let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; + + let _ = run(&mut state)?; + Ok(()) +} + fn unsafe_platform() -> Platform { let mut platform = CENO_PLATFORM; platform.unsafe_ecall_nop = true; diff --git a/ceno_rt/src/syscalls.rs b/ceno_rt/src/syscalls.rs index 1ebef70bc..32a87bfb2 100644 --- a/ceno_rt/src/syscalls.rs +++ b/ceno_rt/src/syscalls.rs @@ -1,17 +1,17 @@ -// Based on https://github.com/succinctlabs/sp1/blob/013c24ea2fa15a0e7ed94f7d11a7ada4baa39ab9/crates/zkvm/entrypoint/src/syscalls/keccak_permute.rs #[cfg(target_os = "zkvm")] use core::arch::asm; pub const KECCAK_PERMUTE: u32 = 0x00_01_01_09; +/// Based on https://github.com/succinctlabs/sp1/blob/013c24ea2fa15a0e7ed94f7d11a7ada4baa39ab9/crates/zkvm/entrypoint/src/syscalls/keccak_permute.rs /// Executes the Keccak256 permutation on the given state. /// -/// ### Safety +/// ### Spec /// -/// The caller must ensure that `state` is valid pointer to data that is aligned along a four -/// byte boundary. +/// - The caller must ensure that `state` is valid pointer to data that is aligned along a four +/// byte boundary. #[allow(unused_variables)] -pub fn keccak_permute(state: &mut [u64; 25]) { +pub fn syscall_keccak_permute(state: &mut [u64; 25]) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -24,3 +24,113 @@ pub fn keccak_permute(state: &mut [u64; 25]) { #[cfg(not(target_os = "zkvm"))] unreachable!() } + +pub const SECP256K1_ADD: u32 = 0x00_01_01_0A; +/// Based on https://github.com/succinctlabs/sp1/blob/dbe622aa4a6a33c88d76298c2a29a1d7ef7e90df/crates/zkvm/entrypoint/src/syscalls/secp256k1.rs +/// Adds two Secp256k1 points. +/// +/// ### Spec +/// - The caller must ensure that `p` and `q` are valid pointers to data that is aligned along a four +/// byte boundary. +/// - Point representation: the first `8` words describe the X-coordinate, the last `8` describe the Y-coordinate. Each +/// coordinate is encoded as follows: its `32` bytes are ordered from lowest significance to highest and then stored into little endian words. +/// For example, the word `p[0]` contains the least significant `4` bytes of `X` and their significance is maintained w.r.t `p[0]` +/// - The caller must ensure that `p` and `q` are valid points on the `secp256k1` curve, and that `p` and `q` are not equal to each other. +/// - The result is stored in the first point. +#[allow(unused_variables)] +pub fn syscall_secp256k1_add(p: *mut [u32; 16], q: *mut [u32; 16]) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "ecall", + in("t0") SECP256K1_ADD, + in("a0") p, + in("a1") q + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} + +pub const SECP256K1_DOUBLE: u32 = 0x00_00_01_0B; + +/// Based on: https://github.com/succinctlabs/sp1/blob/dbe622aa4a6a33c88d76298c2a29a1d7ef7e90df/crates/zkvm/entrypoint/src/syscalls/secp256k1.rs +/// Double a Secp256k1 point. +/// +/// ### Spec +/// - The caller must ensure that `p` is a valid pointer to data that is aligned along a four byte boundary. +/// - Point representation: the first `8` words describe the X-coordinate, the last `8` describe the Y-coordinate. Each +/// coordinate is encoded as follows: its `32` bytes are ordered from lowest significance to highest and then stored into little endian words. +/// For example, the word `p[0]` contains the least significant `4` bytes of `X` and their significance is maintained w.r.t `p[0]` +/// - The result is stored in p +#[allow(unused_variables)] +pub fn syscall_secp256k1_double(p: *mut [u32; 16]) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "ecall", + in("t0") SECP256K1_DOUBLE, + in("a0") p, + in("a1") 0 + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} + +pub const SECP256K1_DECOMPRESS: u32 = 0x00_00_01_0C; + +/// Decompresses a compressed Secp256k1 point. +/// +/// ### Spec +/// - The input array should be 64 bytes long, with the first 32 bytes containing the X coordinate in +/// big-endian format. Note that this byte ordering is different than the one implied in the spec +/// of the `add` and `double` operations +/// - The second half of the input will be overwritten with the Y coordinate of the +/// decompressed point in big-endian format using the point's parity (is_odd). +/// - The caller must ensure that `point` is valid pointer to data that is aligned along a four byte +/// boundary. +#[allow(unused_variables)] +pub fn syscall_secp256k1_decompress(point: &mut [u8; 64], is_odd: bool) { + #[cfg(target_os = "zkvm")] + { + let p = point.as_mut_ptr(); + unsafe { + asm!( + "ecall", + in("t0") SECP256K1_DECOMPRESS, + in("a0") p, + in("a1") is_odd as u8 + ); + } + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} + +pub const SHA_EXTEND: u32 = 0x00_30_01_05; +/// Based on: https://github.com/succinctlabs/sp1/blob/2aed8fea16a67a5b2983ffc471b2942c2f2512c8/crates/zkvm/entrypoint/src/syscalls/sha_extend.rs#L12 +/// Executes the SHA256 extend operation on the given word array. +/// +/// ### Safety +/// +/// The caller must ensure that `w` is valid pointer to data that is aligned along a four byte +/// boundary. +#[allow(unused_variables)] +pub fn syscall_sha256_extend(w: *mut [u32; 64]) { + #[cfg(target_os = "zkvm")] + unsafe { + asm!( + "ecall", + in("t0") SHA_EXTEND, + in("a0") w, + in("a1") 0 + ); + } + + #[cfg(not(target_os = "zkvm"))] + unreachable!() +} diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index cd3156387..80057c991 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use ceno_emul::{Change, InsnKind, KECCAK_WORDS, StepRecord, WORD_SIZE}; +use ceno_emul::{Change, InsnKind, StepRecord, SyscallSpec, WORD_SIZE}; use ff_ext::ExtensionField; use itertools::Itertools; @@ -18,29 +18,13 @@ use crate::{ witness::LkMultiplicity, }; -trait EcallSpec { - const NAME: &'static str; - - const REG_OPS_COUNT: usize; - const MEM_OPS_COUNT: usize; -} - -pub struct KeccakSpec; - -impl EcallSpec for KeccakSpec { - const NAME: &'static str = "KECCAK"; - - const REG_OPS_COUNT: usize = 1; - const MEM_OPS_COUNT: usize = KECCAK_WORDS; -} - /// LargeEcallDummy can handle any instruction and produce its effects, /// including multiple memory operations. /// /// Unsafe: The content is not constrained. pub struct LargeEcallDummy(PhantomData<(E, S)>); -impl Instruction for LargeEcallDummy { +impl Instruction for LargeEcallDummy { type InstructionConfig = LargeEcallConfig; fn name() -> String { diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index 199d8a3a8..49c83245a 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -1,5 +1,4 @@ -use ceno_emul::{Change, InsnKind, StepRecord, encode_rv32}; -use dummy_ecall::KeccakSpec; +use ceno_emul::{Change, InsnKind, KeccakSpec, StepRecord, encode_rv32}; use goldilocks::GoldilocksExt2; use super::*; diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index fefb18ba5..da1df2caf 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -25,8 +25,10 @@ use crate::{ }; use ceno_emul::{ InsnKind::{self, *}, - Platform, StepRecord, + KeccakSpec, Platform, Secp256k1AddSpec, Secp256k1DecompressSpec, Secp256k1DoubleSpec, + Sha256ExtendSpec, StepRecord, SyscallSpec, }; +use dummy::LargeEcallDummy; use ecall::EcallDummy; use ff_ext::ExtensionField; use itertools::{Itertools, izip}; @@ -440,12 +442,37 @@ pub struct GroupedSteps(BTreeMap>); /// Fake version of what is missing in Rv32imConfig, for some tests. pub struct DummyExtraConfig { ecall_config: as Instruction>::InstructionConfig, + keccak_config: as Instruction>::InstructionConfig, + secp256k1_add_config: + as Instruction>::InstructionConfig, + secp256k1_double_config: + as Instruction>::InstructionConfig, + secp256k1_decompress_config: + as Instruction>::InstructionConfig, + sha256_extend_config: + as Instruction>::InstructionConfig, } impl DummyExtraConfig { pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { let ecall_config = cs.register_opcode_circuit::>(); - Self { ecall_config } + let keccak_config = cs.register_opcode_circuit::>(); + let secp256k1_add_config = + cs.register_opcode_circuit::>(); + let secp256k1_double_config = + cs.register_opcode_circuit::>(); + let secp256k1_decompress_config = + cs.register_opcode_circuit::>(); + let sha256_extend_config = + cs.register_opcode_circuit::>(); + Self { + ecall_config, + keccak_config, + secp256k1_add_config, + secp256k1_double_config, + secp256k1_decompress_config, + sha256_extend_config, + } } pub fn generate_fixed_traces( @@ -454,6 +481,11 @@ impl DummyExtraConfig { fixed: &mut ZKVMFixedTraces, ) { fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); + fixed.register_opcode_circuit::>(cs); } pub fn assign_opcode_circuit( @@ -464,17 +496,52 @@ impl DummyExtraConfig { ) -> Result<(), ZKVMError> { let mut steps = steps.0; - macro_rules! assign_opcode { - ($insn_kind:ident,$instruction:ty,$config:ident) => { - witness.assign_opcode_circuit::<$instruction>( - cs, - &self.$config, - steps.remove(&($insn_kind)).unwrap(), - )?; - }; + let mut keccak_steps = Vec::new(); + let mut secp256k1_add_steps = Vec::new(); + let mut secp256k1_double_steps = Vec::new(); + let mut secp256k1_decompress_steps = Vec::new(); + let mut sha256_extend_steps = Vec::new(); + let mut other_steps = Vec::new(); + + if let Some(ecall_steps) = steps.remove(&ECALL) { + for step in ecall_steps { + match step.rs1().unwrap().value { + KeccakSpec::CODE => keccak_steps.push(step), + Secp256k1AddSpec::CODE => secp256k1_add_steps.push(step), + Secp256k1DoubleSpec::CODE => secp256k1_double_steps.push(step), + Secp256k1DecompressSpec::CODE => secp256k1_decompress_steps.push(step), + Sha256ExtendSpec::CODE => sha256_extend_steps.push(step), + _ => other_steps.push(step), + } + } } - assign_opcode!(ECALL, EcallDummy, ecall_config); + witness.assign_opcode_circuit::>( + cs, + &self.keccak_config, + keccak_steps, + )?; + witness.assign_opcode_circuit::>( + cs, + &self.secp256k1_add_config, + secp256k1_add_steps, + )?; + witness.assign_opcode_circuit::>( + cs, + &self.secp256k1_double_config, + secp256k1_double_steps, + )?; + witness.assign_opcode_circuit::>( + cs, + &self.secp256k1_decompress_config, + secp256k1_decompress_steps, + )?; + witness.assign_opcode_circuit::>( + cs, + &self.sha256_extend_config, + sha256_extend_steps, + )?; + witness.assign_opcode_circuit::>(cs, &self.ecall_config, other_steps)?; let _ = steps.remove(&INVALID); let keys: Vec<&InsnKind> = steps.keys().collect::>(); diff --git a/examples/examples/ceno_rt_keccak.rs b/examples/examples/ceno_rt_keccak.rs index 57d9f76ed..fa76acf7d 100644 --- a/examples/examples/ceno_rt_keccak.rs +++ b/examples/examples/ceno_rt_keccak.rs @@ -3,7 +3,7 @@ //! Iterate multiple times and log the state after each iteration. extern crate ceno_rt; -use ceno_rt::{info_out, syscalls::keccak_permute}; +use ceno_rt::{info_out, syscalls::syscall_keccak_permute}; use core::slice; const ITERATIONS: usize = 3; @@ -12,7 +12,7 @@ fn main() { let mut state = [0_u64; 25]; for _ in 0..ITERATIONS { - keccak_permute(&mut state); + syscall_keccak_permute(&mut state); log_state(&state); } } diff --git a/examples/examples/secp256k1_add_syscall.rs b/examples/examples/secp256k1_add_syscall.rs new file mode 100644 index 000000000..07bb054a1 --- /dev/null +++ b/examples/examples/secp256k1_add_syscall.rs @@ -0,0 +1,49 @@ +// Test addition of two curve points. Assert result inside the guest +extern crate ceno_rt; +use ceno_rt::syscalls::syscall_secp256k1_add; + +// Byte repr. of points from https://docs.rs/secp/latest/secp/#arithmetic-1 +const P: [u8; 65] = [ + 4, 180, 53, 9, 32, 85, 226, 220, 154, 20, 116, 218, 199, 119, 48, 44, 23, 45, 222, 10, 64, 50, + 63, 8, 121, 191, 244, 141, 0, 37, 117, 182, 133, 190, 160, 239, 131, 180, 166, 242, 145, 107, + 249, 24, 168, 27, 69, 86, 58, 86, 159, 10, 210, 164, 20, 152, 148, 67, 37, 222, 234, 108, 57, + 84, 148, +]; +const Q: [u8; 65] = [ + 4, 117, 102, 61, 142, 169, 5, 99, 112, 146, 4, 241, 177, 255, 72, 34, 34, 12, 251, 37, 126, + 213, 96, 38, 9, 40, 35, 20, 186, 78, 125, 73, 44, 215, 29, 243, 127, 197, 147, 216, 206, 110, + 116, 63, 96, 72, 143, 182, 205, 11, 234, 96, 127, 206, 19, 1, 103, 103, 219, 255, 25, 229, 210, + 4, 141, +]; +const P_PLUS_Q: [u8; 65] = [ + 4, 188, 11, 115, 232, 35, 63, 79, 186, 163, 11, 207, 165, 64, 247, 109, 81, 125, 56, 83, 131, + 221, 140, 154, 19, 186, 109, 173, 9, 127, 142, 169, 219, 108, 17, 216, 218, 125, 37, 30, 87, + 86, 194, 151, 20, 122, 64, 118, 123, 210, 29, 60, 209, 138, 131, 11, 247, 157, 212, 209, 123, + 162, 111, 197, 70, +]; + +type DecompressedPoint = [u32; 16]; + +/// `bytes` is expected to contain the uncompressed representation of +/// a curve point, as described in https://docs.rs/secp/latest/secp/struct.Point.html +/// +/// The return value is an array of words compatible with the sp1 syscall for `add` and `double` +/// Notably, these words should encode the X and Y coordinates of the point +/// in "little endian" and not "big endian" as is the case of secp +fn bytes_to_words(bytes: [u8; 65]) -> [u32; 16] { + // ignore the tag byte (specific to the secp repr.) + let mut bytes: [u8; 64] = bytes[1..].try_into().unwrap(); + + // Reverse the order of bytes for each coordinate + bytes[0..32].reverse(); + bytes[32..].reverse(); + std::array::from_fn(|i| u32::from_le_bytes(bytes[4 * i..4 * (i + 1)].try_into().unwrap())) +} +fn main() { + let mut p: DecompressedPoint = bytes_to_words(P); + let mut q: DecompressedPoint = bytes_to_words(Q); + let p_plus_q: DecompressedPoint = bytes_to_words(P_PLUS_Q); + + syscall_secp256k1_add(&mut p, &mut q); + assert_eq!(p, p_plus_q); +} diff --git a/examples/examples/secp256k1_decompress_syscall.rs b/examples/examples/secp256k1_decompress_syscall.rs new file mode 100644 index 000000000..6666148ac --- /dev/null +++ b/examples/examples/secp256k1_decompress_syscall.rs @@ -0,0 +1,32 @@ +// Test decompression of curve point. Assert result inside the guest +extern crate ceno_rt; +use ceno_rt::syscalls::syscall_secp256k1_decompress; + +// Byte repr. of point P1 from https://docs.rs/secp/latest/secp/#arithmetic-1 +const COMPRESSED: [u8; 33] = [ + 2, 180, 53, 9, 32, 85, 226, 220, 154, 20, 116, 218, 199, 119, 48, 44, 23, 45, 222, 10, 64, 50, + 63, 8, 121, 191, 244, 141, 0, 37, 117, 182, 133, +]; +const DECOMPRESSED: [u8; 64] = [ + 180, 53, 9, 32, 85, 226, 220, 154, 20, 116, 218, 199, 119, 48, 44, 23, 45, 222, 10, 64, 50, 63, + 8, 121, 191, 244, 141, 0, 37, 117, 182, 133, 190, 160, 239, 131, 180, 166, 242, 145, 107, 249, + 24, 168, 27, 69, 86, 58, 86, 159, 10, 210, 164, 20, 152, 148, 67, 37, 222, 234, 108, 57, 84, + 148, +]; + +fn main() { + let is_odd = match COMPRESSED[0] { + 2 => false, + 3 => true, + _ => panic!("parity byte should be 2 or 3"), + }; + + // ignore parity byte, append 32 zero bytes for writing Y + let mut compressed_with_space: [u8; 64] = [COMPRESSED[1..].to_vec(), vec![0; 32]] + .concat() + .try_into() + .unwrap(); + + syscall_secp256k1_decompress(&mut compressed_with_space, is_odd); + assert_eq!(compressed_with_space, DECOMPRESSED); +} diff --git a/examples/examples/secp256k1_double_syscall.rs b/examples/examples/secp256k1_double_syscall.rs new file mode 100644 index 000000000..8585a8e19 --- /dev/null +++ b/examples/examples/secp256k1_double_syscall.rs @@ -0,0 +1,43 @@ +// Test addition of two curve points. Assert result inside the guest +extern crate ceno_rt; +use ceno_rt::syscalls::syscall_secp256k1_double; + +// Byte repr. of points from https://docs.rs/secp/latest/secp/#arithmetic-1 +const P: [u8; 65] = [ + 4, 180, 53, 9, 32, 85, 226, 220, 154, 20, 116, 218, 199, 119, 48, 44, 23, 45, 222, 10, 64, 50, + 63, 8, 121, 191, 244, 141, 0, 37, 117, 182, 133, 190, 160, 239, 131, 180, 166, 242, 145, 107, + 249, 24, 168, 27, 69, 86, 58, 86, 159, 10, 210, 164, 20, 152, 148, 67, 37, 222, 234, 108, 57, + 84, 148, +]; + +const DOUBLE_P: [u8; 65] = [ + 4, 111, 137, 182, 244, 228, 50, 13, 91, 93, 34, 231, 93, 191, 248, 105, 28, 226, 251, 23, 66, + 192, 188, 66, 140, 44, 218, 130, 239, 101, 255, 164, 76, 202, 170, 134, 48, 127, 46, 14, 9, + 192, 64, 102, 67, 163, 33, 48, 157, 140, 217, 10, 97, 231, 183, 28, 129, 177, 185, 253, 179, + 135, 182, 253, 203, +]; + +type DecompressedPoint = [u32; 16]; + +/// `bytes` is expected to contain the uncompressed representation of +/// a curve point, as described in https://docs.rs/secp/latest/secp/struct.Point.html +/// +/// The return value is an array of words compatible with the sp1 syscall for `add` and `double` +/// Notably, these words should encode the X and Y coordinates of the point +/// in "little endian" and not "big endian" as is the case of secp +fn bytes_to_words(bytes: [u8; 65]) -> [u32; 16] { + // ignore the tag byte (specific to the secp repr.) + let mut bytes: [u8; 64] = bytes[1..].try_into().unwrap(); + + // Reverse the order of bytes for each coordinate + bytes[0..32].reverse(); + bytes[32..].reverse(); + std::array::from_fn(|i| u32::from_le_bytes(bytes[4 * i..4 * (i + 1)].try_into().unwrap())) +} +fn main() { + let mut p: DecompressedPoint = bytes_to_words(P); + let double_p: DecompressedPoint = bytes_to_words(DOUBLE_P); + + syscall_secp256k1_double(&mut p); + assert_eq!(p, double_p); +} diff --git a/examples/examples/sha_extend_syscall.rs b/examples/examples/sha_extend_syscall.rs new file mode 100644 index 000000000..9f9021ce9 --- /dev/null +++ b/examples/examples/sha_extend_syscall.rs @@ -0,0 +1,23 @@ +// Test addition of two curve points. Assert result inside the guest +extern crate ceno_rt; +use std::array; + +use ceno_rt::syscalls::syscall_sha256_extend; + +fn main() { + let mut words: [u32; 64] = array::from_fn(|i| i as u32); + + let expected = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, 67559435, 1711661200, + 3020350282, 1447362251, 3118632270, 4004188394, 690615167, 6070360, 1105370215, 2385558114, + 2348232513, 507799627, 2098764358, 5845374, 823657968, 2969863067, 3903496557, 4274682881, + 2059629362, 1849247231, 2656047431, 835162919, 2096647516, 2259195856, 1779072524, + 3152121987, 4210324067, 1557957044, 376930560, 982142628, 3926566666, 4164334963, + 789545383, 1028256580, 2867933222, 3843938318, 1135234440, 390334875, 2025924737, + 3318322046, 3436065867, 652746999, 4261492214, 2543173532, 3334668051, 3166416553, + 634956631, + ]; + + syscall_sha256_extend(&mut words); + assert_eq!(words, expected); +} diff --git a/examples/examples/syscalls.rs b/examples/examples/syscalls.rs new file mode 100644 index 000000000..d5e6a486c --- /dev/null +++ b/examples/examples/syscalls.rs @@ -0,0 +1,160 @@ +use std::array; + +use ceno_rt::syscalls::{ + syscall_keccak_permute, syscall_secp256k1_add, syscall_secp256k1_decompress, + syscall_secp256k1_double, syscall_sha256_extend, +}; + +/// One unit test for each implemented syscall +/// Meant to be used identically in a sp1 guest to confirm compatibility +pub fn test_syscalls() { + /// `bytes` is expected to contain the uncompressed representation of + /// a curve point, as described in https://docs.rs/secp/latest/secp/struct.Point.html + /// + /// The return value is an array of words compatible with the sp1 syscall for `add` and `double` + /// Notably, these words should encode the X and Y coordinates of the point + /// in "little endian" and not "big endian" as is the case of secp + fn bytes_to_words(bytes: [u8; 65]) -> [u32; 16] { + // ignore the tag byte (specific to the secp repr.) + let mut bytes: [u8; 64] = bytes[1..].try_into().unwrap(); + + // Reverse the order of bytes for each coordinate + bytes[0..32].reverse(); + bytes[32..].reverse(); + std::array::from_fn(|i| u32::from_le_bytes(bytes[4 * i..4 * (i + 1)].try_into().unwrap())) + } + { + const P: [u8; 65] = [ + 4, 180, 53, 9, 32, 85, 226, 220, 154, 20, 116, 218, 199, 119, 48, 44, 23, 45, 222, 10, + 64, 50, 63, 8, 121, 191, 244, 141, 0, 37, 117, 182, 133, 190, 160, 239, 131, 180, 166, + 242, 145, 107, 249, 24, 168, 27, 69, 86, 58, 86, 159, 10, 210, 164, 20, 152, 148, 67, + 37, 222, 234, 108, 57, 84, 148, + ]; + const Q: [u8; 65] = [ + 4, 117, 102, 61, 142, 169, 5, 99, 112, 146, 4, 241, 177, 255, 72, 34, 34, 12, 251, 37, + 126, 213, 96, 38, 9, 40, 35, 20, 186, 78, 125, 73, 44, 215, 29, 243, 127, 197, 147, + 216, 206, 110, 116, 63, 96, 72, 143, 182, 205, 11, 234, 96, 127, 206, 19, 1, 103, 103, + 219, 255, 25, 229, 210, 4, 141, + ]; + const P_PLUS_Q: [u8; 65] = [ + 4, 188, 11, 115, 232, 35, 63, 79, 186, 163, 11, 207, 165, 64, 247, 109, 81, 125, 56, + 83, 131, 221, 140, 154, 19, 186, 109, 173, 9, 127, 142, 169, 219, 108, 17, 216, 218, + 125, 37, 30, 87, 86, 194, 151, 20, 122, 64, 118, 123, 210, 29, 60, 209, 138, 131, 11, + 247, 157, 212, 209, 123, 162, 111, 197, 70, + ]; + + const DOUBLE_P: [u8; 65] = [ + 4, 111, 137, 182, 244, 228, 50, 13, 91, 93, 34, 231, 93, 191, 248, 105, 28, 226, 251, + 23, 66, 192, 188, 66, 140, 44, 218, 130, 239, 101, 255, 164, 76, 202, 170, 134, 48, + 127, 46, 14, 9, 192, 64, 102, 67, 163, 33, 48, 157, 140, 217, 10, 97, 231, 183, 28, + 129, 177, 185, 253, 179, 135, 182, 253, 203, + ]; + { + let mut p = bytes_to_words(P); + let mut q = bytes_to_words(Q); + let p_plus_q = bytes_to_words(P_PLUS_Q); + syscall_secp256k1_add(&mut p, &mut q); + + assert!(p == p_plus_q); + } + + { + let mut p = bytes_to_words(P); + let double_p = bytes_to_words(DOUBLE_P); + + syscall_secp256k1_double(&mut p); + assert!(p == double_p); + } + } + + { + const COMPRESSED: [u8; 33] = [ + 2, 180, 53, 9, 32, 85, 226, 220, 154, 20, 116, 218, 199, 119, 48, 44, 23, 45, 222, 10, + 64, 50, 63, 8, 121, 191, 244, 141, 0, 37, 117, 182, 133, + ]; + const DECOMPRESSED: [u8; 64] = [ + 180, 53, 9, 32, 85, 226, 220, 154, 20, 116, 218, 199, 119, 48, 44, 23, 45, 222, 10, 64, + 50, 63, 8, 121, 191, 244, 141, 0, 37, 117, 182, 133, 190, 160, 239, 131, 180, 166, 242, + 145, 107, 249, 24, 168, 27, 69, 86, 58, 86, 159, 10, 210, 164, 20, 152, 148, 67, 37, + 222, 234, 108, 57, 84, 148, + ]; + + let is_odd = match COMPRESSED[0] { + 2 => false, + 3 => true, + _ => panic!("parity byte should be 2 or 3"), + }; + + // ignore parity byte, append 32 zero bytes for writing Y + let mut compressed_with_space: [u8; 64] = [COMPRESSED[1..].to_vec(), vec![0; 32]] + .concat() + .try_into() + .unwrap(); + + // Note that in the case of the `decompress` syscall the X-coordinate which is part of + // the compressed representation has type [u8; 64] and expects the bytes + // to be "big-endian". + // + // Contrast with the format used for `add` and `double`, where arrays of words are used + // and "little-endian" ordering is expected. + syscall_secp256k1_decompress(&mut compressed_with_space, is_odd); + assert!(compressed_with_space == DECOMPRESSED); + } + + { + let mut state = [0u64; 25]; + syscall_keccak_permute(&mut state); + + const KECCAK_ON_ZEROS: [u64; 25] = [ + 17376452488221285863, + 9571781953733019530, + 15391093639620504046, + 13624874521033984333, + 10027350355371872343, + 18417369716475457492, + 10448040663659726788, + 10113917136857017974, + 12479658147685402012, + 3500241080921619556, + 16959053435453822517, + 12224711289652453635, + 9342009439668884831, + 4879704952849025062, + 140226327413610143, + 424854978622500449, + 7259519967065370866, + 7004910057750291985, + 13293599522548616907, + 10105770293752443592, + 10668034807192757780, + 1747952066141424100, + 1654286879329379778, + 8500057116360352059, + 16929593379567477321, + ]; + + assert!(state == KECCAK_ON_ZEROS); + } + + { + let mut words: [u32; 64] = array::from_fn(|i| i as u32); + + let expected = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 34013193, 67559435, 1711661200, + 3020350282, 1447362251, 3118632270, 4004188394, 690615167, 6070360, 1105370215, + 2385558114, 2348232513, 507799627, 2098764358, 5845374, 823657968, 2969863067, + 3903496557, 4274682881, 2059629362, 1849247231, 2656047431, 835162919, 2096647516, + 2259195856, 1779072524, 3152121987, 4210324067, 1557957044, 376930560, 982142628, + 3926566666, 4164334963, 789545383, 1028256580, 2867933222, 3843938318, 1135234440, + 390334875, 2025924737, 3318322046, 3436065867, 652746999, 4261492214, 2543173532, + 3334668051, 3166416553, 634956631, + ]; + + syscall_sha256_extend(&mut words); + assert_eq!(words, expected); + } +} + +fn main() { + test_syscalls(); +}