diff --git a/Cargo.lock b/Cargo.lock
index 4c7c0661..2ebdb168 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -698,7 +698,9 @@ dependencies = [
 name = "dummy_rand_data"
 version = "0.1.0"
 dependencies = [
+ "anyhow",
  "clap",
+ "indexmap 2.4.0",
  "prost",
  "rand",
  "serde_json",
diff --git a/hipcheck/src/engine.rs b/hipcheck/src/engine.rs
index 597b31cc..8559f840 100644
--- a/hipcheck/src/engine.rs
+++ b/hipcheck/src/engine.rs
@@ -1,10 +1,12 @@
 #![allow(unused)]

-use crate::plugin::{ActivePlugin, HcPluginCore, PluginExecutor, PluginResponse, PluginWithConfig};
+use crate::plugin::{ActivePlugin, PluginResponse};
+pub use crate::plugin::{HcPluginCore, PluginExecutor, PluginWithConfig};
 use crate::{hc_error, Result};
+use futures::future::{BoxFuture, FutureExt};
 use serde_json::Value;
 use std::sync::{Arc, LazyLock};
-use tokio::runtime::Runtime;
+use tokio::runtime::{Handle, Runtime};

 // Salsa doesn't natively support async functions, so our recursive `query()` function that
 // interacts with plugins (which use async) has to get a handle to the underlying runtime,
@@ -62,6 +64,54 @@ fn query(
 	}
 }

+// Demonstration of how the above `query()` function would be implemented as async
+pub fn async_query(
+	core: Arc<HcPluginCore>,
+	publisher: String,
+	plugin: String,
+	query: String,
+	key: Value,
+) -> BoxFuture<'static, Result<Value>> {
+	async move {
+		// Find the plugin
+		let Some(p_handle) = core.plugins.get(&plugin) else {
+			return Err(hc_error!("No such plugin {}::{}", publisher, plugin));
+		};
+		// Initiate the query. If remote closed or we got our response immediately,
+		// return
+		let mut ar = match p_handle.query(query, key).await? {
+			PluginResponse::RemoteClosed => {
+				return Err(hc_error!("Plugin channel closed unexpectedly"));
+			}
+			PluginResponse::Completed(v) => {
+				return Ok(v);
+			}
+			PluginResponse::AwaitingResult(a) => a,
+		};
+		// Otherwise, the plugin needs more data to continue. Recursively query
+		// (with salsa memoization) to get the needed data, and resume our
+		// current query by providing the plugin the answer.
+		loop {
+			let answer = async_query(
+				Arc::clone(&core),
+				ar.publisher.clone(),
+				ar.plugin.clone(),
+				ar.query.clone(),
+				ar.key.clone(),
+			)
+			.await?;
+			ar = match p_handle.resume_query(ar, answer).await? {
+				PluginResponse::RemoteClosed => {
+					return Err(hc_error!("Plugin channel closed unexpectedly"));
+				}
+				PluginResponse::Completed(v) => return Ok(v),
+				PluginResponse::AwaitingResult(a) => a,
+			};
+		}
+	}
+	.boxed()
+}
+
 #[salsa::database(HcEngineStorage)]
 pub struct HcEngineImpl {
 	// Query storage
@@ -77,6 +127,7 @@ impl HcEngineImpl {
 	// independent of Salsa.
 	pub fn new(executor: PluginExecutor, plugins: Vec<(PluginWithConfig)>) -> Result<Self> {
 		let runtime = RUNTIME.handle();
+		println!("Starting HcPluginCore");
 		let core = runtime.block_on(HcPluginCore::new(executor, plugins))?;
 		let mut engine = HcEngineImpl {
 			storage: Default::default(),
@@ -84,6 +135,9 @@ impl HcEngineImpl {
 		engine.set_core(Arc::new(core));
 		Ok(engine)
 	}
+	pub fn runtime() -> &'static Handle {
+		RUNTIME.handle()
+	}
 	// TODO - "run" function that takes analysis hierarchy and target, and queries each
 	// analysis plugin to kick off the execution
 }
diff --git a/hipcheck/src/main.rs b/hipcheck/src/main.rs
index 80a14ac4..0c79b96f 100644
--- a/hipcheck/src/main.rs
+++ b/hipcheck/src/main.rs
@@ -39,7 +39,6 @@ use crate::analysis::report_builder::Report;
 use crate::analysis::score::score_results;
 use crate::cache::HcCache;
 use crate::context::Context as _;
-use crate::engine::{HcEngine, HcEngineImpl};
 use crate::error::Error;
 use crate::error::Result;
 use crate::plugin::{Plugin, PluginExecutor, PluginWithConfig};
@@ -641,6 +640,10 @@ fn check_github_token() -> StdResult<(), EnvVarCheckError> {
 }

 fn cmd_plugin() {
+	use crate::engine::{async_query, HcEngine, HcEngineImpl};
+	use std::sync::Arc;
+	use tokio::task::JoinSet;
+
 	let tgt_dir = "./target/debug";
 	let entrypoint = pathbuf![tgt_dir, "dummy_rand_data"];
 	let plugin = Plugin {
@@ -665,19 +668,49 @@ fn cmd_plugin() {
 			return;
 		}
 	};
-	let res = match engine.query(
-		"MITRE".to_owned(),
-		"rand_data".to_owned(),
-		"rand_data".to_owned(),
-		serde_json::json!(7),
-	) {
-		Ok(r) => r,
-		Err(e) => {
-			println!("Query failed: {e}");
-			return;
+	let core = engine.core();
+	let handle = HcEngineImpl::runtime();
+	// @Note - how to initiate multiple queries with async calls
+	handle.block_on(async move {
+		let mut futs = JoinSet::new();
+		for i in 1..10 {
+			let arc_core = Arc::clone(&core);
+			println!("Spawning");
+			futs.spawn(async_query(
+				arc_core,
+				"MITRE".to_owned(),
+				"rand_data".to_owned(),
+				"rand_data".to_owned(),
+				serde_json::json!(i),
+			));
 		}
-	};
-	println!("Result: {res}");
+		while let Some(res) = futs.join_next().await {
+			println!("res: {res:?}");
+		}
+	});
+	// @Note - how to initiate multiple queries with sync calls
+	// let mut conc: Vec<thread::JoinHandle<()>> = vec![];
+	// for i in 0..10 {
+	// 	let fut = thread::spawn(|| {
+	// 		let res = match engine.query(
+	// 			"MITRE".to_owned(),
+	// 			"rand_data".to_owned(),
+	// 			"rand_data".to_owned(),
+	// 			serde_json::json!(i),
+	// 		) {
+	// 			Ok(r) => r,
+	// 			Err(e) => {
+	// 				println!("{i}: Query failed: {e}");
+	// 				return;
+	// 			}
+	// 		};
+	// 		println!("{i}: Result: {res}");
+	// 	});
+	// 	conc.push(fut);
+	// }
+	// while let Some(x) = conc.pop() {
+	// 	x.join().unwrap();
+	// }
 }

 fn cmd_ready(config: &CliConfig) {
diff --git a/plugins/dummy_rand_data/Cargo.toml b/plugins/dummy_rand_data/Cargo.toml
index fe23ae7a..0d5b659c 100644
--- a/plugins/dummy_rand_data/Cargo.toml
+++ b/plugins/dummy_rand_data/Cargo.toml
@@ -5,7 +5,9 @@ edition = "2021"
 publish = false

 [dependencies]
+anyhow = "1.0.86"
 clap = { version = "4.5.16", features = ["derive"] }
+indexmap = "2.4.0"
 prost = "0.13.1"
 rand = "0.8.5"
 serde_json = "1.0.125"
diff --git a/plugins/dummy_rand_data/src/hipcheck_transport.rs b/plugins/dummy_rand_data/src/hipcheck_transport.rs
index fdc38cae..5d50c0f4 100644
--- a/plugins/dummy_rand_data/src/hipcheck_transport.rs
+++ b/plugins/dummy_rand_data/src/hipcheck_transport.rs
@@ -1,6 +1,10 @@
 use crate::hipcheck::{Query as PluginQuery, QueryState};
+use anyhow::{anyhow, Result};
+use indexmap::map::IndexMap;
 use serde_json::Value;
-use tokio::sync::mpsc;
+use std::collections::VecDeque;
+use std::sync::Arc;
+use tokio::sync::{mpsc, Mutex};
 use tonic::{codec::Streaming, Status};

 #[derive(Debug)]
@@ -15,21 +19,21 @@ pub struct Query {
 	pub output: Value,
 }
 impl TryFrom<PluginQuery> for Query {
-	type Error = String;
-	fn try_from(value: PluginQuery) -> Result<Self, Self::Error> {
+	type Error = anyhow::Error;
+	fn try_from(value: PluginQuery) -> Result<Query> {
 		use QueryState::*;
-		let request =
-			match TryInto::<QueryState>::try_into(value.state).map_err(|e| e.to_string())? {
-				QueryUnspecified => return Err("unspecified error from plugin".into()),
-				QueryReplyInProgress => {
-					return Err("invalid state QueryReplyInProgress for conversion to Query".into())
-				}
-				QueryReplyComplete => false,
-				QuerySubmit => true,
-			};
-		let key: Value = serde_json::from_str(value.key.as_str()).map_err(|e| e.to_string())?;
-		let output: Value =
-			serde_json::from_str(value.output.as_str()).map_err(|e| e.to_string())?;
+		let request = match TryInto::<QueryState>::try_into(value.state)? {
+			QueryUnspecified => return Err(anyhow!("unspecified error from plugin")),
+			QueryReplyInProgress => {
+				return Err(anyhow!(
+					"invalid state QueryReplyInProgress for conversion to Query"
+				))
+			}
+			QueryReplyComplete => false,
+			QuerySubmit => true,
+		};
+		let key: Value = serde_json::from_str(value.key.as_str())?;
+		let output: Value = serde_json::from_str(value.output.as_str())?;
 		Ok(Query {
 			id: value.id as usize,
 			request,
@@ -42,14 +46,14 @@ impl TryFrom<PluginQuery> for Query {
 	}
 }
 impl TryFrom<Query> for PluginQuery {
-	type Error = String;
-	fn try_from(value: Query) -> Result<Self, Self::Error> {
+	type Error = anyhow::Error;
+	fn try_from(value: Query) -> Result<PluginQuery> {
 		let state_enum = match value.request {
 			true => QueryState::QuerySubmit,
 			false => QueryState::QueryReplyComplete,
 		};
-		let key = serde_json::to_string(&value.key).map_err(|e| e.to_string())?;
-		let output = serde_json::to_string(&value.output).map_err(|e| e.to_string())?;
+		let key = serde_json::to_string(&value.key)?;
+		let output = serde_json::to_string(&value.output)?;
 		Ok(PluginQuery {
 			id: value.id as i32,
 			state: state_enum as i32,
@@ -62,54 +66,161 @@ impl TryFrom<Query> for PluginQuery {
 	}
 }
+#[derive(Clone, Debug)]
 pub struct HcTransport {
-	rx: Streaming<PluginQuery>,
 	tx: mpsc::Sender<Result<PluginQuery, Status>>,
+	rx: Arc<Mutex<MultiplexedQueryReceiver>>,
 }
 impl HcTransport {
 	pub fn new(rx: Streaming<PluginQuery>, tx: mpsc::Sender<Result<PluginQuery, Status>>) -> Self {
-		HcTransport { rx, tx }
+		HcTransport {
+			rx: Arc::new(Mutex::new(MultiplexedQueryReceiver::new(rx))),
+			tx,
+		}
 	}
-	pub async fn send(&mut self, query: Query) -> Result<(), String> {
+	pub async fn send(&self, query: Query) -> Result<()> {
 		let query: PluginQuery = query.try_into()?;
-		self.tx
-			.send(Ok(query))
-			.await
-			.map_err(|e| format!("sending query failed: {}", e))
+		self.tx.send(Ok(query)).await?;
+		Ok(())
+	}
+	pub async fn recv_new(&self) -> Result<Option<Query>> {
+		let mut rx_handle = self.rx.lock().await;
+		match rx_handle.recv_new().await? {
+			Some(msg) => msg.try_into().map(Some),
+			None => Ok(None),
+		}
 	}
-	pub async fn recv(&mut self) -> Result<Option<Query>, String> {
+	pub async fn recv(&self, id: usize) -> Result<Option<Query>> {
 		use QueryState::*;
-		let Some(mut raw) = self.rx.message().await.map_err(|e| e.to_string())? else {
-			// gRPC channel was closed
+		let id = id as i32;
+		let mut rx_handle = self.rx.lock().await;
+		let Some(mut msg_chunks) = rx_handle.recv(id).await? else {
 			return Ok(None);
 		};
-		let mut state: QueryState =
-			TryInto::<QueryState>::try_into(raw.state).map_err(|e| e.to_string())?;
-		// As long as we expect successive chunks, keep receiving
+		drop(rx_handle);
+		let mut raw = msg_chunks.pop_front().unwrap();
+		let mut state: QueryState = raw.state.try_into()?;
+
+		// If response is the first of a set of chunks, handle
 		if matches!(state, QueryReplyInProgress) {
 			while matches!(state, QueryReplyInProgress) {
-				println!("Retrieving next response");
-				let Some(next) = self.rx.message().await.map_err(|e| e.to_string())? else {
-					return Err("plugin gRPC channel closed while sending chunked message".into());
+				// We expect another message. Pull it off the existing queue,
+				// or get a new one if we have run out
+				let next = match msg_chunks.pop_front() {
+					Some(msg) => msg,
+					None => {
+						// We ran out of messages, get a new batch
+						let mut rx_handle = self.rx.lock().await;
+						match rx_handle.recv(id).await? {
+							Some(x) => {
+								drop(rx_handle);
+								msg_chunks = x;
+							}
+							None => {
+								return Ok(None);
+							}
+						};
+						msg_chunks.pop_front().unwrap()
+					}
 				};
-				// Assert that the ids are consistent
-				if next.id != raw.id {
-					return Err("msg ids from plugin do not match".into());
-				}
-				state = TryInto::<QueryState>::try_into(next.state).map_err(|e| e.to_string())?;
+				// By now we have our "next" message
+				state = next.state.try_into()?;
 				match state {
-					QueryUnspecified => return Err("unspecified error from plugin".to_owned()),
+					QueryUnspecified => return Err(anyhow!("unspecified error from plugin")),
 					QuerySubmit => {
-						return Err(
-							"plugin sent QuerySubmit state when reply chunk expected".to_owned()
-						)
+						return Err(anyhow!(
+							"plugin sent QuerySubmit state when reply chunk expected"
+						))
 					}
 					QueryReplyInProgress | QueryReplyComplete => {
 						raw.output.push_str(next.output.as_str());
 					}
 				};
 			}
+			// Sanity check - after we've left this loop, there should be no left over message
+			if !msg_chunks.is_empty() {
+				return Err(anyhow!(
+					"received additional messages for id '{}' after QueryComplete status message",
+					id
+				));
+			}
 		}
 		raw.try_into().map(Some)
 	}
 }
+
+#[derive(Debug)]
+pub struct MultiplexedQueryReceiver {
+	rx: Streaming<PluginQuery>,
+	// Unlike in HipCheck, backlog is an IndexMap to ensure the earliest received
+	// requests are handled first
+	backlog: IndexMap<i32, VecDeque<PluginQuery>>,
+}
+impl MultiplexedQueryReceiver {
+	pub fn new(rx: Streaming<PluginQuery>) -> Self {
+		Self {
+			rx,
+			backlog: IndexMap::new(),
+		}
+	}
+	pub async fn recv_new(&mut self) -> Result<Option<PluginQuery>> {
+		let opt_unhandled = self.backlog.iter().find(|(k, v)| {
+			if let Some(req) = v.front() {
+				return req.state() == QueryState::QuerySubmit;
+			}
+			false
+		});
+		if let Some((k, v)) = opt_unhandled {
+			let id: i32 = *k;
+			let mut vec = self.backlog.shift_remove(&id).unwrap();
+			// @Note - for now QuerySubmit doesn't chunk so we shouldn't expect
+			// multiple messages in the backlog for a new request
+			assert!(vec.len() == 1);
+			return Ok(vec.pop_front());
+		}
+		// No backlog message, need to operate the receiver
+		loop {
+			let Some(raw) = self.rx.message().await? else {
+				// gRPC channel was closed
+				return Ok(None);
+			};
+			if raw.state() == QueryState::QuerySubmit {
+				return Ok(Some(raw));
+			}
+			match self.backlog.get_mut(&raw.id) {
+				Some(vec) => {
+					vec.push_back(raw);
+				}
+				None => {
+					self.backlog.insert(raw.id, VecDeque::from([raw]));
+				}
+			}
+		}
+	}
+	// @Invariant - this function will never return an empty VecDeque
+	pub async fn recv(&mut self, id: i32) -> Result<Option<VecDeque<PluginQuery>>> {
+		// If we have 1+ messages on backlog for `id`, return them all,
+		// no need to waste time with successive calls
+		if let Some(msgs) = self.backlog.shift_remove(&id) {
+			return Ok(Some(msgs));
+		}
+		// No backlog message, need to operate the receiver
+		loop {
+			let Some(raw) = self.rx.message().await? else {
+				// gRPC channel was closed
+				return Ok(None);
+			};
+			if raw.id == id {
+				return Ok(Some(VecDeque::from([raw])));
+			}
+			match self.backlog.get_mut(&raw.id) {
+				Some(vec) => {
+					vec.push_back(raw);
+				}
+				None => {
+					self.backlog.insert(raw.id, VecDeque::from([raw]));
+				}
+			}
+		}
+	}
+}
diff --git a/plugins/dummy_rand_data/src/main.rs b/plugins/dummy_rand_data/src/main.rs
index beadadd9..ee8d9ee2 100644
--- a/plugins/dummy_rand_data/src/main.rs
+++ b/plugins/dummy_rand_data/src/main.rs
@@ -4,6 +4,7 @@ mod hipcheck;
 mod hipcheck_transport;

 use crate::hipcheck_transport::*;
+use anyhow::{anyhow, Result};
 use clap::Parser;
 use hipcheck::plugin_server::{Plugin, PluginServer};
 use hipcheck::{
@@ -27,6 +28,21 @@ fn get_rand(num_bytes: usize) -> Vec<u8> {
 	vec
 }

+pub async fn handle_rand_data(channel: HcTransport, id: usize, key: u64) -> Result<()> {
+	let res = get_rand(key as usize);
+	let output = serde_json::to_value(res)?;
+	let resp = Query {
+		id,
+		request: false,
+		publisher: "".to_owned(),
+		plugin: "".to_owned(),
+		query: "".to_owned(),
+		key: json!(null),
+		output,
+	};
+	channel.send(resp).await?;
+	Ok(())
+}
 struct RandDataRunner {
 	channel: HcTransport,
 }
@@ -34,42 +50,39 @@ impl RandDataRunner {
 	pub fn new(channel: HcTransport) -> Self {
 		RandDataRunner { channel }
 	}
-	pub fn handle_query(id: usize, name: String, key: Value) -> Result<Query, String> {
+	async fn handle_query(channel: HcTransport, id: usize, name: String, key: Value) -> Result<()> {
 		if name == "rand_data" {
 			let Value::Number(num_size) = &key else {
-				return Err("get_rand argument must be a number".to_owned());
+				return Err(anyhow!("get_rand argument must be a number"));
 			};
 			let Some(size) = num_size.as_u64() else {
-				return Err("get_rand argument must be an unsigned integer".to_owned());
+				return Err(anyhow!("get_rand argument must be an unsigned integer"));
 			};
-			let res = get_rand(size as usize);
-			let output = serde_json::to_value(res).map_err(|e| e.to_string())?;
-			Ok(Query {
-				id,
-				request: false,
-				publisher: "".to_owned(),
-				plugin: "".to_owned(),
-				query: "".to_owned(),
-				key: json!(null),
-				output,
-			})
+			handle_rand_data(channel, id, size).await?;
+			Ok(())
 		} else {
-			Err(format!("unrecognized query '{name}'"))
+			Err(anyhow!("unrecognized query '{}'", name))
 		}
 	}
-	pub async fn run(mut self) -> Result<(), String> {
+	pub async fn run(self) -> Result<()> {
 		loop {
 			eprintln!("Looping");
-			let Some(msg) = self.channel.recv().await? else {
+			let Some(msg) = self.channel.recv_new().await? else {
 				eprintln!("Channel closed by remote");
 				break;
 			};
 			if msg.request {
-				let rsp = RandDataRunner::handle_query(msg.id, msg.query, msg.key)?;
-				eprintln!("Sending response: {rsp:?}");
-				self.channel.send(rsp).await?;
+				let child_channel = self.channel.clone();
+				tokio::spawn(async move {
+					if let Err(e) =
+						RandDataRunner::handle_query(child_channel, msg.id, msg.query, msg.key)
+							.await
+					{
+						eprintln!("handle_query failed: {e}");
+					};
+				});
 			} else {
-				return Err("Did not expect a response-type message here".to_owned());
+				return Err(anyhow!("Did not expect a response-type message here"));
 			}
 		}
 		Ok(())