From 5804f23298b832d956070c8f4f656f8c039e788e Mon Sep 17 00:00:00 2001 From: smtmfft <99081233+smtmfft@users.noreply.github.com> Date: Wed, 25 Dec 2024 01:20:45 -0800 Subject: [PATCH] fix(raiko): avoid duplicate image uploads (#439) * fix(raiko): avoid duplicate image uploads Signed-off-by: smtmfft * revert risc0 bonsai change as it checks the image id Signed-off-by: smtmfft --------- Signed-off-by: smtmfft --- Cargo.lock | 1 + Cargo.toml | 1 + provers/sp1/driver/Cargo.toml | 2 +- provers/sp1/driver/src/lib.rs | 85 ++++++++++++++++++++++++----------- 4 files changed, 63 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 338c5ec8..8573722a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8772,6 +8772,7 @@ dependencies = [ "cargo_metadata 0.18.1", "cfg-if", "chrono", + "dashmap", "dotenv", "once_cell", "raiko-lib", diff --git a/Cargo.toml b/Cargo.toml index e956cc0a..0aad4880 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -140,6 +140,7 @@ tokio-util = { version = "0.7.11" } reqwest = { version = "0.11.22", features = ["json"] } url = "2.5.0" async-trait = "0.1.80" +dashmap = "5.5.3" # crypto kzg = { package = "rust-kzg-zkcrypto", git = "https://github.com/ceciliaz030/rust-kzg.git", branch = "brecht/sp1-patch", default-features = false } diff --git a/provers/sp1/driver/Cargo.toml b/provers/sp1/driver/Cargo.toml index 18ca5794..f9fef1ac 100644 --- a/provers/sp1/driver/Cargo.toml +++ b/provers/sp1/driver/Cargo.toml @@ -35,7 +35,7 @@ bincode = { workspace = true } reth-primitives = { workspace = true } tokio = { workspace = true, optional = true } tracing = { workspace = true, optional = true } - +dashmap = { workspace = true } [build-dependencies] sp1-helper = { workspace = true, optional = true } diff --git a/provers/sp1/driver/src/lib.rs b/provers/sp1/driver/src/lib.rs index 35775a54..ddc691e7 100644 --- a/provers/sp1/driver/src/lib.rs +++ b/provers/sp1/driver/src/lib.rs @@ -1,6 +1,8 @@ #![cfg(feature = "enable")] #![feature(iter_advance_by)] +use dashmap::DashMap; +use once_cell::sync::Lazy; use raiko_lib::{ input::{ AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput, @@ -15,12 +17,15 @@ use serde::{Deserialize, Serialize}; use serde_with::serde_as; use sp1_sdk::{ action, - network::client::NetworkClient, - network::proto::network::{ProofMode, UnclaimReason}, - NetworkProverV1 as NetworkProver, SP1Proof, SP1ProofWithPublicValues, SP1VerifyingKey, + network::{ + client::NetworkClient, + proto::network::{ProofMode, UnclaimReason}, + }, + NetworkProverV1 as NetworkProver, SP1Proof, SP1ProofWithPublicValues, SP1ProvingKey, + SP1VerifyingKey, }; use sp1_sdk::{HashableKey, ProverClient, SP1Stdin}; -use std::{borrow::BorrowMut, env, time::Duration}; +use std::{borrow::BorrowMut, env, sync::Arc, time::Duration}; use tracing::{debug, error, info}; mod proof_verify; @@ -63,7 +68,7 @@ impl From for ProofMode { } } -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] #[serde(rename_all = "lowercase")] pub enum ProverMode { Mock, @@ -99,6 +104,17 @@ pub struct Sp1Response { pub struct Sp1Prover; +#[derive(Clone)] +struct Sp1ProverClient { + pub(crate) client: Arc, + pub(crate) pk: SP1ProvingKey, + pub(crate) vk: SP1VerifyingKey, +} + +//TODO: use prover object to save such local storage members. +static BLOCK_PROOF_CLIENT: Lazy> = Lazy::new(DashMap::new); +static AGGREGATION_CLIENT: Lazy> = Lazy::new(DashMap::new); + impl Prover for Sp1Prover { async fn run( input: GuestInput, @@ -110,21 +126,28 @@ impl Prover for Sp1Prover { let mode = param.prover.clone().unwrap_or_else(get_env_mock); println!("param: {param:?}"); - let mut stdin = SP1Stdin::new(); stdin.write(&input); - // Generate the proof for the given program. - let client = param - .prover - .map(|mode| match mode { - ProverMode::Mock => ProverClient::mock(), - ProverMode::Local => ProverClient::local(), - ProverMode::Network => ProverClient::network(), + let Sp1ProverClient { client, pk, vk } = BLOCK_PROOF_CLIENT + .entry(mode.clone()) + .or_insert_with(|| { + let base_client = match mode { + ProverMode::Mock => ProverClient::mock(), + ProverMode::Local => ProverClient::local(), + ProverMode::Network => ProverClient::network(), + }; + + let client = Arc::new(base_client); + let (pk, vk) = client.setup(ELF); + info!( + "new client and setup() for block {:?}.", + output.header.number + ); + Sp1ProverClient { client, pk, vk } }) - .unwrap_or_else(ProverClient::new); + .clone(); - let (pk, vk) = client.setup(ELF); info!( "Sp1 Prover: block {:?} with vk {:?}", output.header.number, @@ -214,7 +237,7 @@ impl Prover for Sp1Prover { Sp1Response { proof: proof_string, sp1_proof: Some(prove_result), - vkey: Some(vk), + vkey: Some(vk.clone()), } .into(), ) @@ -287,16 +310,28 @@ impl Prover for Sp1Prover { } // Generate the proof for the given program. - let client = param - .prover - .map(|mode| match mode { - ProverMode::Mock => ProverClient::mock(), - ProverMode::Local => ProverClient::local(), - ProverMode::Network => ProverClient::network(), + let Sp1ProverClient { client, pk, vk } = AGGREGATION_CLIENT + .entry(param.prover.clone().unwrap_or_else(get_env_mock)) + .or_insert_with(|| { + let base_client = param + .prover + .map(|mode| match mode { + ProverMode::Mock => ProverClient::mock(), + ProverMode::Local => ProverClient::local(), + ProverMode::Network => ProverClient::network(), + }) + .unwrap_or_else(ProverClient::new); + + let client = Arc::new(base_client); + let (pk, vk) = client.setup(AGGREGATION_ELF); + info!( + "new client and setup() for aggregation based on {:?} proofs with vk {:?}", + input.proofs.len(), + vk.bytes32() + ); + Sp1ProverClient { client, pk, vk } }) - .unwrap_or_else(ProverClient::new); - - let (pk, vk) = client.setup(AGGREGATION_ELF); + .clone(); info!( "sp1 aggregate: {:?} based {:?} blocks with vk {:?}", reth_primitives::hex::encode_prefixed(stark_vk.hash_bytes()),