Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(raiko): avoid duplicate image uploads #439

Merged
merged 3 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ tokio-util = { version = "0.7.11" }
reqwest = { version = "0.11.22", features = ["json"] }
url = "2.5.0"
async-trait = "0.1.80"
dashmap = "5.5.3"

# crypto
kzg = { package = "rust-kzg-zkcrypto", git = "https://github.com/ceciliaz030/rust-kzg.git", branch = "brecht/sp1-patch", default-features = false }
Expand Down
31 changes: 29 additions & 2 deletions provers/risc0/driver/src/bonsai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
use alloy_primitives::B256;
use bonsai_sdk::blocking::{Client, SessionId};
use log::{debug, error, info, warn};
use once_cell::sync::OnceCell;
use raiko_lib::{
primitives::keccak::keccak,
prover::{IdWrite, ProofKey, ProverError, ProverResult},
Expand All @@ -19,6 +20,7 @@ use std::{
fmt::Debug,
fs,
path::{Path, PathBuf},
sync::Arc,
};
use tokio::time::{sleep as tokio_async_sleep, Duration};

Expand Down Expand Up @@ -263,6 +265,30 @@ pub async fn cancel_proof(uuid: String) -> anyhow::Result<()> {
Ok(())
}

struct BonsaiClient {
pub(crate) client: Arc<Client>,
}

impl BonsaiClient {
pub fn instance(
encoded_image_id: &String,
elf: Vec<u8>,
) -> Result<&'static BonsaiClient, BonsaiExecutionError> {
static INSTANCE: OnceCell<BonsaiClient> = OnceCell::new();

Ok(INSTANCE
.get_or_try_init(|| {
let client = Client::from_env(risc0_zkvm::VERSION)?;
client.upload_img(&encoded_image_id, elf)?;

Ok(BonsaiClient {
client: Arc::new(client),
})
})
.map_err(|e| BonsaiExecutionError::SdkFailure(e))?)
}
smtmfft marked this conversation as resolved.
Show resolved Hide resolved
}

pub async fn prove_bonsai<O: Eq + Debug + DeserializeOwned>(
encoded_input: Vec<u32>,
elf: &[u8],
Expand All @@ -279,8 +305,9 @@ pub async fn prove_bonsai<O: Eq + Debug + DeserializeOwned>(
// Prepare input data
let input_data = bytemuck::cast_slice(&encoded_input).to_vec();

let client = Client::from_env(risc0_zkvm::VERSION)?;
client.upload_img(&encoded_image_id, elf.to_vec())?;
let client = BonsaiClient::instance(&encoded_image_id, elf.to_vec())?
.client
.clone();
// upload input
let input_id = client.upload_input(input_data.clone())?;

Expand Down
2 changes: 1 addition & 1 deletion provers/sp1/driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ bincode = { workspace = true }
reth-primitives = { workspace = true }
tokio = { workspace = true, optional = true }
tracing = { workspace = true, optional = true }

dashmap = { workspace = true }

[build-dependencies]
sp1-helper = { workspace = true, optional = true }
Expand Down
85 changes: 60 additions & 25 deletions provers/sp1/driver/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#![cfg(feature = "enable")]
#![feature(iter_advance_by)]

use dashmap::DashMap;
use once_cell::sync::Lazy;
use raiko_lib::{
input::{
AggregationGuestInput, AggregationGuestOutput, GuestInput, GuestOutput,
Expand All @@ -15,12 +17,15 @@ use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use sp1_sdk::{
action,
network::client::NetworkClient,
network::proto::network::{ProofMode, UnclaimReason},
NetworkProverV1 as NetworkProver, SP1Proof, SP1ProofWithPublicValues, SP1VerifyingKey,
network::{
client::NetworkClient,
proto::network::{ProofMode, UnclaimReason},
},
NetworkProverV1 as NetworkProver, SP1Proof, SP1ProofWithPublicValues, SP1ProvingKey,
SP1VerifyingKey,
};
use sp1_sdk::{HashableKey, ProverClient, SP1Stdin};
use std::{borrow::BorrowMut, env, time::Duration};
use std::{borrow::BorrowMut, env, sync::Arc, time::Duration};
use tracing::{debug, error, info};

mod proof_verify;
Expand Down Expand Up @@ -63,7 +68,7 @@ impl From<RecursionMode> for ProofMode {
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum ProverMode {
Mock,
Expand Down Expand Up @@ -99,6 +104,17 @@ pub struct Sp1Response {

pub struct Sp1Prover;

#[derive(Clone)]
struct Sp1ProverClient {
pub(crate) client: Arc<ProverClient>,
pub(crate) pk: SP1ProvingKey,
pub(crate) vk: SP1VerifyingKey,
}

//TODO: use prover object to save such local storage members.
static BLOCK_PROOF_CLIENT: Lazy<DashMap<ProverMode, Sp1ProverClient>> = Lazy::new(DashMap::new);
static AGGREGATION_CLIENT: Lazy<DashMap<ProverMode, Sp1ProverClient>> = Lazy::new(DashMap::new);

impl Prover for Sp1Prover {
async fn run(
input: GuestInput,
Expand All @@ -110,21 +126,28 @@ impl Prover for Sp1Prover {
let mode = param.prover.clone().unwrap_or_else(get_env_mock);

println!("param: {param:?}");

let mut stdin = SP1Stdin::new();
stdin.write(&input);

// Generate the proof for the given program.
let client = param
.prover
.map(|mode| match mode {
ProverMode::Mock => ProverClient::mock(),
ProverMode::Local => ProverClient::local(),
ProverMode::Network => ProverClient::network(),
let Sp1ProverClient { client, pk, vk } = BLOCK_PROOF_CLIENT
.entry(mode.clone())
.or_insert_with(|| {
let base_client = match mode {
ProverMode::Mock => ProverClient::mock(),
ProverMode::Local => ProverClient::local(),
ProverMode::Network => ProverClient::network(),
};

let client = Arc::new(base_client);
let (pk, vk) = client.setup(ELF);
info!(
"new client and setup() for block {:?}.",
output.header.number
);
Sp1ProverClient { client, pk, vk }
})
.unwrap_or_else(ProverClient::new);
.clone();

let (pk, vk) = client.setup(ELF);
info!(
"Sp1 Prover: block {:?} with vk {:?}",
output.header.number,
Expand Down Expand Up @@ -214,7 +237,7 @@ impl Prover for Sp1Prover {
Sp1Response {
proof: proof_string,
sp1_proof: Some(prove_result),
vkey: Some(vk),
vkey: Some(vk.clone()),
}
.into(),
)
Expand Down Expand Up @@ -287,16 +310,28 @@ impl Prover for Sp1Prover {
}

// Generate the proof for the given program.
let client = param
.prover
.map(|mode| match mode {
ProverMode::Mock => ProverClient::mock(),
ProverMode::Local => ProverClient::local(),
ProverMode::Network => ProverClient::network(),
let Sp1ProverClient { client, pk, vk } = AGGREGATION_CLIENT
.entry(param.prover.clone().unwrap_or_else(get_env_mock))
.or_insert_with(|| {
let base_client = param
.prover
.map(|mode| match mode {
ProverMode::Mock => ProverClient::mock(),
ProverMode::Local => ProverClient::local(),
ProverMode::Network => ProverClient::network(),
})
.unwrap_or_else(ProverClient::new);

let client = Arc::new(base_client);
let (pk, vk) = client.setup(AGGREGATION_ELF);
info!(
"new client and setup() for aggregation based on {:?} proofs with vk {:?}",
input.proofs.len(),
vk.bytes32()
);
Sp1ProverClient { client, pk, vk }
})
.unwrap_or_else(ProverClient::new);

let (pk, vk) = client.setup(AGGREGATION_ELF);
.clone();
info!(
"sp1 aggregate: {:?} based {:?} blocks with vk {:?}",
reth_primitives::hex::encode_prefixed(stark_vk.hash_bytes()),
Expand Down