Skip to content

Commit

Permalink
feat(sdk): get_proof_status, request, cycle_limit (#1883)
Browse files Browse the repository at this point in the history
  • Loading branch information
ratankaliani authored Dec 20, 2024
1 parent 445e5e6 commit 653e8c5
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 11 deletions.
6 changes: 5 additions & 1 deletion crates/sdk/src/cpu/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ impl CpuProverBuilder {
/// ```
#[must_use]
pub fn build(self) -> CpuProver {
if self.mock { CpuProver::mock() } else { CpuProver::new() }
if self.mock {
CpuProver::mock()
} else {
CpuProver::new()
}
}
}
65 changes: 61 additions & 4 deletions crates/sdk/src/network/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub struct NetworkProveBuilder<'a> {
pub(crate) timeout: Option<Duration>,
pub(crate) strategy: FulfillmentStrategy,
pub(crate) skip_simulation: bool,
pub(crate) cycle_limit: Option<u64>,
}

impl<'a> NetworkProveBuilder<'a> {
Expand Down Expand Up @@ -231,6 +232,62 @@ impl<'a> NetworkProveBuilder<'a> {
self
}

/// Sets the cycle limit for the proof request.
///
/// # Details
/// The cycle limit determines the maximum number of cycles that the program should take to
/// execute. By default, the cycle limit is determined by simulating the program locally.
/// However, you can manually set it if you know the exact cycle count needed and want to skip
/// the simulation step locally.
///
/// The cycle limit ensures that a prover on the network will stop generating a proof once the
/// cycle limit is reached, which prevents DoS attacks.
///
/// # Example
/// ```rust,no_run
/// use sp1_sdk::{ProverClient, SP1Stdin, Prover};
///
/// let elf = &[1, 2, 3];
/// let stdin = SP1Stdin::new();
///
/// let client = ProverClient::builder().network().build();
/// let (pk, vk) = client.setup(elf);
/// let proof = client.prove(&pk, &stdin)
/// .cycle_limit(1_000_000) // Set 1M cycle limit.
/// .skip_simulation(true) // Skip simulation since the limit is set manually.
/// .run()
/// .unwrap();
/// ```
#[must_use]
pub fn cycle_limit(mut self, cycle_limit: u64) -> Self {
self.cycle_limit = Some(cycle_limit);
self
}

/// Request a proof from the prover network.
///
/// # Details
/// This method will request a proof from the prover network. If the prover fails to request
/// a proof, the method will return an error. It will not wait for the proof to be generated.
///
/// # Example
/// ```rust,no_run
/// use sp1_sdk::{ProverClient, SP1Stdin, Prover};
///
/// let elf = &[1, 2, 3];
/// let stdin = SP1Stdin::new();
///
/// let client = ProverClient::builder().network().build();
/// let (pk, vk) = client.setup(elf);
/// let request_id = client.prove(&pk, &stdin)
/// .request()
/// .unwrap();
/// ```
pub async fn request(self) -> Result<Vec<u8>> {
let Self { prover, mode, pk, stdin, timeout, strategy, skip_simulation, cycle_limit } = self;
prover.request_proof_impl(pk, &stdin, mode, strategy, timeout, skip_simulation, cycle_limit).await
}

/// Run the prover with the built arguments.
///
/// # Details
Expand All @@ -251,7 +308,7 @@ impl<'a> NetworkProveBuilder<'a> {
/// .unwrap();
/// ```
pub fn run(self) -> Result<SP1ProofWithPublicValues> {
let Self { prover, mode, pk, stdin, timeout, strategy, mut skip_simulation } = self;
let Self { prover, mode, pk, stdin, timeout, strategy, mut skip_simulation, cycle_limit } = self;

// Check for deprecated environment variable
if let Ok(val) = std::env::var("SKIP_SIMULATION") {
Expand All @@ -263,7 +320,7 @@ impl<'a> NetworkProveBuilder<'a> {

sp1_dump(&pk.elf, &stdin);

block_on(prover.prove_impl(pk, &stdin, mode, strategy, timeout, skip_simulation))
block_on(prover.prove_impl(pk, &stdin, mode, strategy, timeout, skip_simulation, cycle_limit))
}

/// Run the prover with the built arguments asynchronously.
Expand All @@ -284,7 +341,7 @@ impl<'a> NetworkProveBuilder<'a> {
/// .run_async();
/// ```
pub async fn run_async(self) -> Result<SP1ProofWithPublicValues> {
let Self { prover, mode, pk, stdin, timeout, strategy, mut skip_simulation } = self;
let Self { prover, mode, pk, stdin, timeout, strategy, mut skip_simulation, cycle_limit } = self;

// Check for deprecated environment variable
if let Ok(val) = std::env::var("SKIP_SIMULATION") {
Expand All @@ -296,6 +353,6 @@ impl<'a> NetworkProveBuilder<'a> {

sp1_dump(&pk.elf, &stdin);

prover.prove_impl(pk, &stdin, mode, strategy, timeout, skip_simulation).await
prover.prove_impl(pk, &stdin, mode, strategy, timeout, skip_simulation, cycle_limit).await
}
}
69 changes: 63 additions & 6 deletions crates/sdk/src/network/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
use std::time::{Duration, Instant};

use super::proto::network::GetProofRequestStatusResponse;
use super::prove::NetworkProveBuilder;
use super::DEFAULT_CYCLE_LIMIT;
use crate::cpu::execute::CpuExecuteBuilder;
Expand Down Expand Up @@ -108,6 +109,7 @@ impl NetworkProver {
timeout: None,
strategy: FulfillmentStrategy::Hosted,
skip_simulation: false,
cycle_limit: None,
}
}

Expand All @@ -134,6 +136,26 @@ impl NetworkProver {
self.client.register_program(vk, elf).await
}

/// Gets the status of a proof request.
///
/// # Details
/// * `request_id`: The request ID to get the status of.
///
/// # Example
/// ```rust,no_run
/// use sp1_sdk::{ProverClient}
///
/// let request_id = vec![1u8; 32];
/// let client = ProverClient::builder().network().build();
/// let (status, maybe_proof) = client.get_proof_status(&request_id).await?;
/// ```
pub async fn get_proof_status(
&self,
request_id: &[u8],
) -> Result<(GetProofRequestStatusResponse, Option<SP1ProofWithPublicValues>)> {
self.client.get_proof_request_status(request_id).await
}

/// Requests a proof from the prover network, returning the request ID.
pub(crate) async fn request_proof(
&self,
Expand Down Expand Up @@ -210,7 +232,11 @@ impl NetworkProver {
}
let remaining_timeout = timeout.map(|t| {
let elapsed = start_time.elapsed();
if elapsed < t { t - elapsed } else { Duration::from_secs(0) }
if elapsed < t {
t - elapsed
} else {
Duration::from_secs(0)
}
});

// Get status with retries.
Expand Down Expand Up @@ -247,6 +273,22 @@ impl NetworkProver {
}
}

/// Requests a proof from the prover network.
pub(crate) async fn request_proof_impl(
&self,
pk: &SP1ProvingKey,
stdin: &SP1Stdin,
mode: SP1ProofMode,
strategy: FulfillmentStrategy,
timeout: Option<Duration>,
skip_simulation: bool,
cycle_limit: Option<u64>,
) -> Result<Vec<u8>> {
let vk_hash = self.register_program(&pk.vk, &pk.elf).await?;
let cycle_limit = self.get_cycle_limit(cycle_limit, &pk.elf, stdin, skip_simulation)?;
self.request_proof(&vk_hash, stdin, mode.into(), strategy, cycle_limit, timeout).await
}

/// Requests a proof from the prover network and waits for it to be generated.
pub(crate) async fn prove_impl(
&self,
Expand All @@ -256,16 +298,31 @@ impl NetworkProver {
strategy: FulfillmentStrategy,
timeout: Option<Duration>,
skip_simulation: bool,
cycle_limit: Option<u64>,
) -> Result<SP1ProofWithPublicValues> {
let vk_hash = self.register_program(&pk.vk, &pk.elf).await?;
let cycle_limit = self.get_cycle_limit(&pk.elf, stdin, skip_simulation)?;
let request_id = self
.request_proof(&vk_hash, stdin, mode.into(), strategy, cycle_limit, timeout)
.request_proof_impl(pk, stdin, mode, strategy, timeout, skip_simulation, cycle_limit)
.await?;
self.wait_proof(&request_id, timeout).await
}

fn get_cycle_limit(&self, elf: &[u8], stdin: &SP1Stdin, skip_simulation: bool) -> Result<u64> {
/// The cycle limit is determined according to the following priority:
///
/// 1. If a cycle limit was explicitly set by the requester, use the specified value.
/// 2. If simulation is enabled, calculate the limit by simulating the
/// execution of the program. This is the default behavior.
/// 3. Otherwise, use the default cycle limit ([`DEFAULT_CYCLE_LIMIT`]).
fn get_cycle_limit(
&self,
cycle_limit: Option<u64>,
elf: &[u8],
stdin: &SP1Stdin,
skip_simulation: bool,
) -> Result<u64> {
if let Some(cycle_limit) = cycle_limit {
return Ok(cycle_limit);
}

if skip_simulation {
Ok(DEFAULT_CYCLE_LIMIT)
} else {
Expand All @@ -291,7 +348,7 @@ impl Prover<CpuProverComponents> for NetworkProver {
stdin: &SP1Stdin,
mode: SP1ProofMode,
) -> Result<SP1ProofWithPublicValues> {
block_on(self.prove_impl(pk, stdin, mode, FulfillmentStrategy::Hosted, None, false))
block_on(self.prove_impl(pk, stdin, mode, FulfillmentStrategy::Hosted, None, false, None))
}
}

Expand Down

0 comments on commit 653e8c5

Please sign in to comment.