Skip to content

Commit

Permalink
good
Browse files Browse the repository at this point in the history
  • Loading branch information
mattstam committed Dec 20, 2024
1 parent 9bdf767 commit c018bef
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 60 deletions.
1 change: 0 additions & 1 deletion crates/sdk/src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
pub mod client;
pub mod prover;
mod sign_message;
#[rustfmt::skip]
#[allow(missing_docs)]
#[allow(clippy::default_trait_access)]
Expand Down
23 changes: 18 additions & 5 deletions crates/sdk/src/network/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,11 @@ impl<'a> NetworkProveBuilder<'a> {
/// .unwrap();
/// ```
pub async fn request(self) -> Result<Vec<u8>> {
let Self { prover, mode, pk, stdin, timeout, strategy, skip_simulation, cycle_limit } = self;
prover.request_proof_impl(pk, &stdin, mode, strategy, timeout, skip_simulation, cycle_limit).await
let Self { prover, mode, pk, stdin, timeout, strategy, skip_simulation, cycle_limit } =
self;
prover
.request_proof_impl(pk, &stdin, mode, strategy, timeout, skip_simulation, cycle_limit)
.await
}

/// Run the prover with the built arguments.
Expand All @@ -308,7 +311,8 @@ impl<'a> NetworkProveBuilder<'a> {
/// .unwrap();
/// ```
pub fn run(self) -> Result<SP1ProofWithPublicValues> {
let Self { prover, mode, pk, stdin, timeout, strategy, mut skip_simulation, cycle_limit } = self;
let Self { prover, mode, pk, stdin, timeout, strategy, mut skip_simulation, cycle_limit } =
self;

// Check for deprecated environment variable
if let Ok(val) = std::env::var("SKIP_SIMULATION") {
Expand All @@ -320,7 +324,15 @@ impl<'a> NetworkProveBuilder<'a> {

sp1_dump(&pk.elf, &stdin);

block_on(prover.prove_impl(pk, &stdin, mode, strategy, timeout, skip_simulation, cycle_limit))
block_on(prover.prove_impl(
pk,
&stdin,
mode,
strategy,
timeout,
skip_simulation,
cycle_limit,
))
}

/// Run the prover with the built arguments asynchronously.
Expand All @@ -341,7 +353,8 @@ impl<'a> NetworkProveBuilder<'a> {
/// .run_async();
/// ```
pub async fn run_async(self) -> Result<SP1ProofWithPublicValues> {
let Self { prover, mode, pk, stdin, timeout, strategy, mut skip_simulation, cycle_limit } = self;
let Self { prover, mode, pk, stdin, timeout, strategy, mut skip_simulation, cycle_limit } =
self;

// Check for deprecated environment variable
if let Ok(val) = std::env::var("SKIP_SIMULATION") {
Expand Down
1 change: 1 addition & 0 deletions crates/sdk/src/network/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use super::prove::NetworkProveBuilder;
use super::DEFAULT_CYCLE_LIMIT;
use crate::cpu::execute::CpuExecuteBuilder;
use crate::cpu::CpuProver;
use crate::network::proto::network::GetProofRequestStatusResponse;
use crate::network::{Error, DEFAULT_PROVER_NETWORK_RPC, DEFAULT_TIMEOUT_SECS};
use crate::{
network::client::NetworkClient,
Expand Down
56 changes: 2 additions & 54 deletions crates/sdk/src/network/utils.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#![allow(deprecated)]

//! # Network Utils
//!
//! This module provides utility functions for the network module.
use alloy_signer::{Signature, SignerSync};
use prost::Message;
use serde::Serialize;
use thiserror::Error;

pub(crate) trait Signable: Message {
fn sign<S: SignerSync>(&self, signer: &S) -> Signature;
Expand All @@ -16,55 +16,3 @@ impl<T: Message> Signable for T {
signer.sign_message_sync(&self.encode_to_vec()).unwrap()
}
}

#[derive(Error, Debug)]
pub(crate) enum JsonFormatError {
#[error("Serialization error: {0}")]
SerializationError(String),
}

pub(crate) fn format_json_message<T>(body: &T) -> Result<Vec<u8>, JsonFormatError>
where
T: Message + Serialize,
{
match serde_json::to_string(body) {
Ok(json_str) => {
if json_str.starts_with('"') && json_str.ends_with('"') {
let inner = &json_str[1..json_str.len() - 1];
let unescaped = inner.replace("\\\"", "\"");
Ok(unescaped.into_bytes())
} else {
Ok(json_str.into_bytes())
}
}
Err(e) => Err(JsonFormatError::SerializationError(e.to_string())),
}
}

#[cfg(test)]
mod tests {
use super::*;
use prost::Message as ProstMessage;
use serde::{Deserialize, Serialize};

// Test message for JSON formatting.
#[derive(Clone, ProstMessage, Serialize, Deserialize)]
struct TestMessage {
#[prost(string, tag = 1)]
value: String,
}

#[test]
fn test_format_json_message_simple() {
let msg = TestMessage { value: "hello".to_string() };
let result = format_json_message(&msg).unwrap();
assert_eq!(result, b"{\"value\":\"hello\"}");
}

#[test]
fn test_format_json_message_with_quotes() {
let msg = TestMessage { value: "hello \"world\"".to_string() };
let result = format_json_message(&msg).unwrap();
assert_eq!(result, b"{\"value\":\"hello \\\"world\\\"\"}");
}
}

0 comments on commit c018bef

Please sign in to comment.