Skip to content

Commit

Permalink
separate debug functinality from web.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
kevindeforth committed Dec 17, 2024
1 parent c1964d7 commit 437caa1
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 209 deletions.
17 changes: 9 additions & 8 deletions node/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ use crate::sign::PresignatureStorage;
use crate::sign_request::SignRequestStorage;
use crate::tracking;
use crate::triple::TripleStorage;
#[cfg(not(test))]
use crate::web::start_web_server;
#[cfg(test)]
use crate::web_test::start_web_server_testing;
use clap::ArgAction;
use clap::Parser;
use near_crypto::SecretKey;
Expand Down Expand Up @@ -196,15 +199,13 @@ impl Cli {
)?;

let (root_task, _) = tracking::start_root_task(async move {
#[cfg(test)]
let root_task_handle = tracking::current_task();

let _root_task_handle = tracking::current_task();
let mpc_client_cell = Arc::new(OnceCell::new());
#[cfg(test)]
let _web_server_handle = tracking::spawn(
"web server",
start_web_server(
root_task_handle,
start_web_server_testing(
_root_task_handle,
config.web_ui.clone(),
Some(mpc_client_cell.clone()),
)
Expand Down Expand Up @@ -292,12 +293,12 @@ impl Cli {
let config = config.into_full_config(mpc_config, secrets);

let (root_task, _) = tracking::start_root_task(async move {
#[cfg(test)]
let root_task_handle = tracking::current_task();
let _root_task_handle = tracking::current_task();
#[cfg(test)]
let _web_server_handle = tracking::spawn_checked(
"web server",
start_web_server(root_task_handle, config.web_ui.clone(), None).await?,
start_web_server_testing(_root_task_handle, config.web_ui.clone(), None)
.await?,
);

#[cfg(not(test))]
Expand Down
2 changes: 2 additions & 0 deletions node/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ mod tracing;
mod tracking;
mod triple;
mod web;
#[cfg(test)]
mod web_test;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
Expand Down
207 changes: 6 additions & 201 deletions node/src/web.rs
Original file line number Diff line number Diff line change
@@ -1,47 +1,18 @@
#[cfg(not(test))]
use crate::config::WebUIConfig;
#[cfg(test)]
use crate::mpc_client::MpcClient;
#[cfg(test)]
use crate::sign_request::SignatureRequest;
#[cfg(test)]
use crate::tracking::TaskHandle;
#[cfg(test)]
use crate::tracking::{self};
#[cfg(test)]
use anyhow::Context;
use axum::body::Body;
#[cfg(test)]
use axum::extract::{Query, State};
use axum::http::{Response, StatusCode};
use axum::response::IntoResponse;
#[cfg(not(test))]
use axum::{routing::get, Router};
#[cfg(not(test))]
use futures::future::BoxFuture;
#[cfg(not(test))]
use futures::FutureExt;
#[cfg(test)]
use futures::{stream, StreamExt, TryStreamExt};
#[cfg(test)]
use k256::elliptic_curve::scalar::FromUintUnchecked;
#[cfg(test)]
use k256::sha2::{Digest, Sha256};
use k256::Scalar;
#[cfg(test)]
use k256::U256;
use prometheus::{default_registry, Encoder, TextEncoder};
#[cfg(test)]
use rand::Rng;
use serde::{Deserialize, Serialize};
#[cfg(test)]
use std::sync::Arc;
#[cfg(test)]
use std::time::Duration;
#[cfg(test)]
use tokio::sync::OnceCell;
#[cfg(test)]
use tokio::time;

/// Wrapper to make Axum understand how to convert anyhow::Error into a 500
/// response.
struct AnyhowErrorWrapper(anyhow::Error);
pub(crate) struct AnyhowErrorWrapper(anyhow::Error);

impl From<anyhow::Error> for AnyhowErrorWrapper {
fn from(e: anyhow::Error) -> Self {
Expand All @@ -58,180 +29,14 @@ impl IntoResponse for AnyhowErrorWrapper {
}
}

async fn metrics() -> String {
pub(crate) async fn metrics() -> String {
let metric_families = default_registry().gather();
let mut buffer = vec![];
let encoder = TextEncoder::new();
encoder.encode(&metric_families, &mut buffer).unwrap();
String::from_utf8(buffer).unwrap()
}

#[cfg(test)]
async fn debug_tasks(State(state): State<DebugWebServerState>) -> String {
format!("{:?}", state.root_task_handle.report())
}

#[cfg(test)]
fn generate_ids(repeat: usize, seed: u64) -> Vec<[u8; 32]> {
let mut rng: rand_xorshift::XorShiftRng = rand::SeedableRng::seed_from_u64(seed);
(0..repeat).map(|_| rng.gen::<[u8; 32]>()).collect()
}

#[cfg(test)]
async fn debug_index(
State(state): State<DebugWebServerState>,
Query(query): Query<DebugIndexRequest>,
) -> Result<(), AnyhowErrorWrapper> {
let Some(mpc_client) = state.mpc_client.unwrap().get().cloned() else {
return Err(anyhow::anyhow!("MPC client not ready").into());
};
let repeat = query.repeat.unwrap_or(1);
for id in generate_ids(repeat, query.seed) {
mpc_client.clone().add_sign_request(&SignatureRequest {
id,
msg_hash: sha256hash(query.msg.as_bytes()),
tweak: query.tweak,
entropy: query.entropy,
timestamp_nanosec: 0,
});
}
Ok(())
}

#[cfg(test)]
async fn debug_sign(
State(state): State<DebugWebServerState>,
Query(query): Query<DebugSignatureRequest>,
) -> Result<axum::Json<Vec<DebugSignatureOutput>>, AnyhowErrorWrapper> {
let Some(mpc_client) = state.mpc_client.unwrap().get().cloned() else {
return Err(anyhow::anyhow!("MPC client not ready").into());
};
let client = Arc::new(mpc_client);
let result = state
.task_handle
.scope("debug_sign", async move {
let repeat = query.repeat.unwrap_or(1);
let ids = generate_ids(repeat, query.seed);
let timeout = Duration::from_secs(query.timeout.unwrap_or(60));
let signatures = time::timeout(
timeout,
stream::iter(ids.clone())
.map(|id| {
tracking::spawn(
&format!("debug sign #{:?}", id),
client.clone().make_signature(id),
)
.map(|result| anyhow::Ok(result??))
})
.buffered(query.parallelism.unwrap_or(repeat))
.try_collect::<Vec<_>>(),
)
.await
.context("timeout")?
.context("signature failed")?;

anyhow::Ok(axum::Json(
signatures
.into_iter()
.map(|s| DebugSignatureOutput {
big_r: format!("{:?}", s.big_r),
s: format!("{:?}", s.s),
})
.collect(),
))
})
.await?;
Ok(result)
}

#[cfg(test)]
fn sha256hash(data: &[u8]) -> k256::Scalar {
let mut hasher = Sha256::new();
hasher.update(data);
let result = hasher.finalize();
let mut bytes = [0u8; 32];
bytes.copy_from_slice(&result);
Scalar::from_uint_unchecked(U256::from_be_slice(&bytes))
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct DebugIndexRequest {
#[serde(default)]
repeat: Option<usize>,
#[serde(default)]
seed: u64,
msg: String,
#[serde(default)]
tweak: Scalar,
#[serde(default)]
entropy: [u8; 32],
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct DebugSignatureRequest {
#[serde(default)]
repeat: Option<usize>,
#[serde(default)]
seed: u64,
#[serde(default)]
parallelism: Option<usize>,
#[serde(default)]
timeout: Option<u64>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct DebugSignatureOutput {
big_r: String,
s: String,
}

#[cfg(test)]
#[derive(Clone)]
struct DebugWebServerState {
/// Task handle for the task that runs the web server. Used by debug_sign.
task_handle: Arc<TaskHandle>,
/// Root task handle for the whole program. Used by debug_tasks.
root_task_handle: Arc<TaskHandle>,
/// MPC client, for signing. We take a OnceCell, so that we can start the
/// web server (for debugging) before the MPC client is ready.
/// Used by debug_index and debug_sign.
mpc_client: Option<Arc<OnceCell<MpcClient>>>,
}

#[cfg(test)]
pub async fn start_web_server(
root_task_handle: Arc<TaskHandle>,
config: WebUIConfig,
mpc_client: Option<Arc<OnceCell<MpcClient>>>,
) -> anyhow::Result<BoxFuture<'static, anyhow::Result<()>>> {
let router = Router::new().route("/metrics", get(metrics));

let router = router.route("/debug/tasks", get(debug_tasks));

let router = if mpc_client.is_some() {
router
.route("/debug/index", get(debug_index))
.route("/debug/sign", get(debug_sign))
} else {
router
};

let web_server_state = DebugWebServerState {
task_handle: tracking::current_task(),
root_task_handle: root_task_handle.clone(),
mpc_client: mpc_client.clone(),
};

let router = router.with_state(web_server_state);
let tcp_listener =
tokio::net::TcpListener::bind(&format!("{}:{}", config.host, config.port)).await?;
Ok(async move {
axum::serve(tcp_listener, router).await?;
anyhow::Ok(())
}
.boxed())
}

/// Starts the web server. This is an async function that returns a future.
/// The function itself will return error if the server cannot be started.
///
Expand Down
Loading

0 comments on commit 437caa1

Please sign in to comment.