From 546b163607517219f15ddc6314fb85a0d6a68790 Mon Sep 17 00:00:00 2001 From: jlanson Date: Fri, 23 Aug 2024 08:51:06 -0400 Subject: [PATCH] feat: plugin comms interface can handle multiple active sessions --- Cargo.lock | 2 +- hipcheck/src/plugin/mod.rs | 81 +++++++++++++------ hipcheck/src/plugin/types.rs | 153 ++++++++++++++++++++++++++++++----- 3 files changed, 190 insertions(+), 46 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9ce265d5..0d30a08c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -690,7 +690,7 @@ dependencies = [ [[package]] name = "dummy_rand_data" -version = "3.5.0" +version = "0.1.0" dependencies = [ "clap", "prost", diff --git a/hipcheck/src/plugin/mod.rs b/hipcheck/src/plugin/mod.rs index ff6a17cc..3306deca 100644 --- a/hipcheck/src/plugin/mod.rs +++ b/hipcheck/src/plugin/mod.rs @@ -8,7 +8,7 @@ use crate::Result; use futures::future::join_all; use serde_json::Value; use std::collections::HashMap; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, Mutex}; pub fn dummy() { let plugin = Plugin { @@ -41,9 +41,59 @@ pub async fn initialize_plugins( Ok(out) } +struct ActivePlugin { + next_id: Mutex, + channel: PluginTransport, +} +impl ActivePlugin { + pub fn new(channel: PluginTransport) -> Self { + ActivePlugin { + next_id: Mutex::new(1), + channel, + } + } + async fn get_unique_id(&self) -> usize { + let mut id_lock = self.next_id.lock().await; + let res: usize = *id_lock; + // even IDs reserved for plugin-originated queries, so skip to next odd ID + *id_lock += 2; + drop(id_lock); + res + } + pub async fn query(&self, name: String, key: Value) -> Result { + let id = self.get_unique_id().await; + let query = Query { + id, + request: true, + publisher: "".to_owned(), + plugin: self.channel.name().to_owned(), + query: name, + key, + output: serde_json::json!(null), + }; + Ok(self.channel.query(query).await?.into()) + } + pub async fn resume_query( + &self, + state: AwaitingResult, + output: Value, + ) -> Result { + let query = Query { + id: state.id, + request: false, + publisher: state.publisher, + plugin: state.plugin, + query: state.query, + key: serde_json::json!(null), + output, + }; + Ok(self.channel.query(query).await?.into()) + } +} + pub struct HcPluginCore { executor: PluginExecutor, - plugins: HashMap, + plugins: HashMap, } impl HcPluginCore { // When this object is returned, the plugins are all connected but the @@ -69,36 +119,21 @@ impl HcPluginCore { }) .collect(); // Use configs to initialize corresponding plugin - let plugins = HashMap::::from_iter( + let plugins = HashMap::::from_iter( initialize_plugins(mapped_ctxs) .await? .into_iter() - .map(|p| (p.name().to_owned(), p)), + .map(|p| (p.name().to_owned(), ActivePlugin::new(p))), ); // Now we have a set of started and initialized plugins to interact with Ok(HcPluginCore { executor, plugins }) } // @Temporary pub async fn run(&mut self) -> Result<()> { - let channel = self.plugins.get_mut("rand_data").unwrap(); - match channel - .send(Query { - id: 1, - request: true, - publisher: "".to_owned(), - plugin: "".to_owned(), - query: "rand_data".to_owned(), - key: serde_json::json!(7), - output: serde_json::json!(null), - }) - .await - { - Ok(q) => q, - Err(e) => { - println!("Failed: {e}"); - } - }; - let resp = channel.recv().await?; + let handle = self.plugins.get("rand_data").unwrap(); + let resp = handle + .query("rand_data".to_owned(), serde_json::json!(7)) + .await?; println!("Plugin response: {resp:?}"); Ok(()) } diff --git a/hipcheck/src/plugin/types.rs b/hipcheck/src/plugin/types.rs index f9b1b339..3057bc3a 100644 --- a/hipcheck/src/plugin/types.rs +++ b/hipcheck/src/plugin/types.rs @@ -5,10 +5,11 @@ use crate::hipcheck::{ }; use crate::{hc_error, Error, Result, StdResult}; use serde_json::Value; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::convert::TryFrom; use std::ops::Not; use std::process::Child; +use tokio::sync::{mpsc, Mutex}; use tonic::codec::Streaming; use tonic::transport::Channel; @@ -161,7 +162,7 @@ impl PluginContext { } pub async fn initiate_query_protocol( &mut self, - mut rx: tokio::sync::mpsc::Receiver, + mut rx: mpsc::Receiver, ) -> Result> { let stream = async_stream::stream! { while let Some(item) = rx.recv().await { @@ -185,8 +186,9 @@ impl PluginContext { ); self.set_configuration(&config).await?.as_result()?; let default_policy_expr = self.get_default_policy_expression().await?; - let (tx, mut out_rx) = tokio::sync::mpsc::channel::(10); + let (tx, mut out_rx) = mpsc::channel::(10); let rx = self.initiate_query_protocol(out_rx).await?; + let rx = Mutex::new(MultiplexedQueryReceiver::new(rx)); Ok(PluginTransport { schemas, default_policy_expr, @@ -271,46 +273,102 @@ impl TryFrom for PluginQuery { } } +pub struct MultiplexedQueryReceiver { + rx: Streaming, + backlog: HashMap>, +} +impl MultiplexedQueryReceiver { + pub fn new(rx: Streaming) -> Self { + Self { + rx, + backlog: HashMap::new(), + } + } + // @Invariant - this function will never return an empty VecDeque + pub async fn recv(&mut self, id: i32) -> Result>> { + // If we have 1+ messages on backlog for `id`, return them all, + // no need to waste time with successive calls + if let Some(msgs) = self.backlog.remove(&id) { + return Ok(Some(msgs)); + } + // No backlog message, need to operate the receiver + loop { + let Some(raw) = self.rx.message().await? else { + // gRPC channel was closed + return Ok(None); + }; + if raw.id == id { + return Ok(Some(VecDeque::from([raw]))); + } + match self.backlog.get_mut(&raw.id) { + Some(vec) => { + vec.push_back(raw); + } + None => { + self.backlog.insert(raw.id, VecDeque::from([raw])); + } + } + } + } +} + // Encapsulate an "initialized" state of a Plugin with interfaces that abstract // query chunking to produce whole messages for the Hipcheck engine pub struct PluginTransport { pub schemas: HashMap, pub default_policy_expr: String, // TODO - update with policy_expr type ctx: PluginContext, - tx: tokio::sync::mpsc::Sender, - rx: Streaming, + tx: mpsc::Sender, + rx: Mutex, } impl PluginTransport { pub fn name(&self) -> &str { &self.ctx.plugin.name } - pub async fn send(&mut self, query: Query) -> Result<()> { + pub async fn query(&self, query: Query) -> Result> { + use QueryState::*; + + // Send the query let query: PluginQuery = query.try_into()?; - eprintln!("Sending query: {query:?}"); + let id = query.id; self.tx .send(query) .await - .map_err(|e| hc_error!("sending query failed: {}", e)) - } - pub async fn recv(&mut self) -> Result> { - use QueryState::*; - let Some(mut raw) = self.rx.message().await? else { - // gRPC channel was closed + .map_err(|e| hc_error!("sending query failed: {}", e))?; + + // Get initial response batch + let mut rx_handle = self.rx.lock().await; + let Some(mut msg_chunks) = rx_handle.recv(id).await? else { return Ok(None); }; + drop(rx_handle); + + let mut raw = msg_chunks.pop_front().unwrap(); let mut state: QueryState = raw.state.try_into()?; - // As long as we expect successive chunks, keep receiving + + // If response is the first of a set of chunks, handle if matches!(state, QueryReplyInProgress) { while matches!(state, QueryReplyInProgress) { - let Some(next) = self.rx.message().await? else { - return Err(hc_error!( - "plugin gRPC channel closed while sending chunked message" - )); + // We expect another message. Pull it off the existing queue, + // or get a new one if we have run out + let next = match msg_chunks.pop_front() { + Some(msg) => msg, + None => { + // We ran out of messages, get a new batch + let mut rx_handle = self.rx.lock().await; + match rx_handle.recv(id).await? { + Some(x) => { + drop(rx_handle); + msg_chunks = x; + } + None => { + return Ok(None); + } + }; + msg_chunks.pop_front().unwrap() + } }; - // Assert that the ids are consistent - if next.id != raw.id { - return Err(hc_error!("msg ids from plugin do not match")); - } + // By now we have our "next" message state = next.state.try_into()?; match state { QueryUnspecified => return Err(hc_error!("unspecified error from plugin")), @@ -324,6 +382,13 @@ impl PluginTransport { } }; } + // Sanity check - after we've left this loop, there should be no left over message + if !msg_chunks.is_empty() { + return Err(hc_error!( + "received additional messages for id '{}' after QueryComplete status message", + id + )); + } } raw.try_into().map(Some) } @@ -341,3 +406,47 @@ impl From for (PluginContext, Value) { (value.0, value.1) } } + +#[derive(Clone, Debug)] +pub struct AwaitingResult { + pub id: usize, + pub publisher: String, + pub plugin: String, + pub query: String, + pub key: Value, +} +impl From for AwaitingResult { + fn from(value: Query) -> Self { + AwaitingResult { + id: value.id, + publisher: value.publisher, + plugin: value.plugin, + query: value.query, + key: value.key, + } + } +} + +#[derive(Clone, Debug)] +pub enum PluginResponse { + RemoteClosed, + AwaitingResult(AwaitingResult), + Completed(Value), +} +impl From> for PluginResponse { + fn from(value: Option) -> Self { + match value { + Some(q) => q.into(), + None => PluginResponse::RemoteClosed, + } + } +} +impl From for PluginResponse { + fn from(value: Query) -> Self { + if !value.request { + PluginResponse::Completed(value.output) + } else { + PluginResponse::AwaitingResult(value.into()) + } + } +}