diff --git a/sdk/rust/src/engine.rs b/sdk/rust/src/engine.rs index 0021de42..d7338ef2 100644 --- a/sdk/rust/src/engine.rs +++ b/sdk/rust/src/engine.rs @@ -14,12 +14,12 @@ use hipcheck_common::{ types::{Query, QueryDirection}, }; use serde::Serialize; -use std::sync::Arc; use std::{ collections::{HashMap, VecDeque}, future::poll_fn, pin::Pin, result::Result as StdResult, + sync::Arc, }; use tokio::sync::mpsc::{self, error::TrySendError}; use tonic::Status; @@ -33,6 +33,42 @@ impl From for Error { type SessionTracker = HashMap>>; +/// Used for building a up a `Vec` of keys to send to specific hipcheck plugin +pub struct QueryBuilder<'engine> { + keys: Vec, + target: QueryTarget, + plugin_engine: &'engine mut PluginEngine, +} + +impl<'engine> QueryBuilder<'engine> { + /// Create a new `QueryBuilder` to dynamically add keys to send to `target` plugin + fn new(plugin_engine: &'engine mut PluginEngine, target: T) -> Result> + where + T: TryInto>, + { + let target: QueryTarget = target.try_into().map_err(|e| e.into())?; + Ok(Self { + plugin_engine, + target, + keys: vec![], + }) + } + + /// Add a key to the internal list of keys to be sent to `target` + /// + /// Returns the index `key` was inserted was inserted to + pub fn query(&mut self, key: JsonValue) -> usize { + let len = self.keys.len(); + self.keys.push(key); + len + } + + /// Send all of the provided keys to `target` plugin endpont and wait for query results + pub async fn send(self) -> Result> { + self.plugin_engine.batch_query(self.target, self.keys).await + } +} + /// Manages a particular query session. /// /// This struct invokes a `Query` trait object, passing a handle to itself to `Query::run()`. This @@ -57,6 +93,16 @@ impl PluginEngine { mock_responses.into() } + /// Convenience function to expose a `QueryBuilder` to make it convenient to dynamically build + /// up queries to plugins and send them off to the `target` plugin, in as few GRPC calls as + /// possible + pub fn batch(&mut self, target: T) -> Result + where + T: TryInto>, + { + QueryBuilder::new(self, target) + } + async fn query_inner( &mut self, target: QueryTarget, @@ -547,3 +593,35 @@ impl MockResponses { Ok(()) } } + +#[cfg(test)] +mod test { + use super::*; + + #[cfg(feature = "mock_engine")] + #[tokio::test] + async fn test_query_builder() { + let mut mock_responses = MockResponses::new(); + mock_responses + .insert("mitre/foo", "abcd", Ok(1234)) + .unwrap(); + mock_responses + .insert("mitre/foo", "efgh", Ok(5678)) + .unwrap(); + let mut engine = PluginEngine::mock(mock_responses); + let mut builder = engine.batch("mitre/foo").unwrap(); + let idx = builder.query("abcd".into()); + assert_eq!(idx, 0); + let idx = builder.query("efgh".into()); + assert_eq!(idx, 1); + let response = builder.send().await.unwrap(); + assert_eq!( + response.first().unwrap(), + &>::into(1234) + ); + assert_eq!( + response.get(1).unwrap(), + &>::into(5678) + ); + } +}