Skip to content

Commit

Permalink
feat: implemented submit chunking for RFD #10
Browse files Browse the repository at this point in the history
- Added support for chunking Query submissions
- Added test for submission chunking
- Bumped hipcheck-common to 0.2.0
- Bumped rust-sdk to 0.3.0
- Updated plugins to rely on rust-sdk 0.3.0

Signed-off-by: Patrick Casey <[email protected]>
  • Loading branch information
patrickjcasey committed Jan 6, 2025
1 parent 7399a3d commit 29eb6ad
Show file tree
Hide file tree
Showing 24 changed files with 171 additions and 88 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion hipcheck-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "hipcheck-common"
description = "Common functionality for the Hipcheck gRPC protocol"
repository = "https://github.com/mitre/hipcheck"
version = "0.1.0"
version = "0.2.0"
license = "Apache-2.0"
edition = "2021"

Expand Down
7 changes: 5 additions & 2 deletions hipcheck-common/proto/hipcheck/v1/hipcheck.proto
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,11 @@ enum QueryState {
// Something has gone wrong.
QUERY_STATE_UNSPECIFIED = 0;

// We are submitting a new query.
QUERY_STATE_SUBMIT = 1;
// We are sending a query to a plugin and we are expecting to need to send more chunks
QUERY_STATE_SUBMIT_IN_PROGRESS = 4;

// We are completed submitting a new query.
QUERY_STATE_SUBMIT_COMPLETE = 1;

// We are replying to a query and expect more chunks.
QUERY_STATE_REPLY_IN_PROGRESS = 2;
Expand Down
158 changes: 110 additions & 48 deletions hipcheck-common/src/chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,20 @@ fn estimate_size(msg: &PluginQuery) -> usize {
}

pub fn chunk_with_size(msg: PluginQuery, max_est_size: usize) -> Result<Vec<PluginQuery>> {
// Chunking only does something on response objects, mostly because
// we don't have a state to represent "SubmitInProgress"
if msg.state == QueryState::Submit as i32 {
return Ok(vec![msg]);
}
// in_progress_state - the state the PluginQuery is in for all queries in the resulting Vec,
// EXCEPT the last one
//
// completion_state - the state the PluginQuery is in if it is the last chunked message
let (in_progress_state, completion_state) = match msg.state() {
// if the message gets chunked, then it must either be a reply or submission that is in process
QueryState::Unspecified => return Err(anyhow!("msg in Unspecified query state")),
QueryState::SubmitInProgress | QueryState::SubmitComplete => {
(QueryState::SubmitInProgress, QueryState::SubmitComplete)
}
QueryState::ReplyInProgress | QueryState::ReplyComplete => {
(QueryState::ReplyInProgress, QueryState::ReplyComplete)
}
};

let mut out: Vec<PluginQuery> = vec![];
let mut base: PluginQuery = msg;
Expand All @@ -58,9 +67,9 @@ pub fn chunk_with_size(msg: PluginQuery, max_est_size: usize) -> Result<Vec<Plug
// For this loop, we want to take at most MAX_SIZE bytes because that's
// all that can fit in a PluginQuery
let mut remaining = max_est_size;
let mut query = PluginQuery {
let mut chunked_query = PluginQuery {
id: base.id,
state: QueryState::ReplyInProgress as i32,
state: in_progress_state as i32,
publisher_name: base.publisher_name.clone(),
plugin_name: base.plugin_name.clone(),
query_name: base.query_name.clone(),
Expand All @@ -71,44 +80,47 @@ pub fn chunk_with_size(msg: PluginQuery, max_est_size: usize) -> Result<Vec<Plug

if remaining > 0 && base.key.bytes().len() > 0 {
// steal from key
query.key = drain_at_most_n_bytes(&mut base.key, remaining)?;
remaining -= query.key.bytes().len();
chunked_query.key = drain_at_most_n_bytes(&mut base.key, remaining)?;
remaining -= chunked_query.key.bytes().len();
made_progress = true;
}

if remaining > 0 && base.output.bytes().len() > 0 {
// steal from output
query.output = drain_at_most_n_bytes(&mut base.output, remaining)?;
remaining -= query.output.bytes().len();
chunked_query.output = drain_at_most_n_bytes(&mut base.output, remaining)?;
remaining -= chunked_query.output.bytes().len();
made_progress = true;
}

let mut l = base.concern.len();
// While we still want to steal more bytes and we have more elements of
// `concern` to possibly steal
while remaining > 0 && l > 0 {
let i = l - 1;

let mut i = 0;
while remaining > 0 && i < base.concern.len() {
let c_bytes = base.concern.get(i).unwrap().bytes().len();

if c_bytes > max_est_size {
return Err(anyhow!("Query cannot be chunked, there is a concern that is larger than max chunk size"));
} else if c_bytes <= remaining {
// steal this concern
let concern = base.concern.swap_remove(i);
query.concern.push(concern);
chunked_query.concern.push(concern);
remaining -= c_bytes;
made_progress = true;
}
// since we use `swap_remove`, whether or not we stole a concern we know the element
// currently at `i` is too big for `remainder` (since if we removed, the element at `i`
// now is one we already passed on)
l -= 1;
i += 1;
}

out.push(query);
out.push(chunked_query);
}
out.push(base);

// ensure the last message in the chunked messages is set to the appropriate Complete state
if let Some(last) = out.last_mut() {
last.state = completion_state as i32;
}
Ok(out)
}

Expand All @@ -120,6 +132,14 @@ pub fn prepare(msg: Query) -> Result<Vec<PluginQuery>> {
chunk(msg.try_into()?)
}

/// Determine whether or not the given `QueryState` represents an intermediate InProgress state
fn in_progress_state(state: &QueryState) -> bool {
matches!(
state,
QueryState::ReplyInProgress | QueryState::SubmitInProgress
)
}

#[derive(Default)]
pub struct QuerySynthesizer {
raw: Option<PluginQuery>,
Expand All @@ -138,14 +158,16 @@ impl QuerySynthesizer {
};
}
let raw = self.raw.as_mut().unwrap(); // We know its `Some`, was set above
let mut state = raw
let initial_state: QueryState = raw
.state
.try_into()
.map_err(|_| Error::UnspecifiedQueryState)?;
// holds state of current chunk
let mut current_state: QueryState = initial_state;

// If response is the first of a set of chunks, handle
if matches!(state, QueryState::ReplyInProgress) {
while matches!(state, QueryState::ReplyInProgress) {
if in_progress_state(&current_state) {
while in_progress_state(&current_state) {
// We expect another message. Pull it off the existing queue,
// or get a new one if we have run out
let next = match chunks.next() {
Expand All @@ -156,20 +178,40 @@ impl QuerySynthesizer {
};

// By now we have our "next" message
state = next
current_state = next
.state
.try_into()
.map_err(|_| Error::UnspecifiedQueryState)?;
match state {
QueryState::Unspecified => return Err(Error::UnspecifiedQueryState),
QueryState::Submit => return Err(Error::ReceivedSubmitWhenExpectingReplyChunk),
QueryState::ReplyInProgress | QueryState::ReplyComplete => {
if state == QueryState::ReplyComplete {
raw.state = QueryState::ReplyComplete.into();
match (initial_state, current_state) {
// error out if any states are unspecified
(QueryState::Unspecified, _) | (_, QueryState::Unspecified) => {
return Err(Error::UnspecifiedQueryState)
}
// error out if expecting a Submit messages and a Reply is received
(QueryState::SubmitInProgress, QueryState::ReplyInProgress)
| (QueryState::SubmitInProgress, QueryState::ReplyComplete)
| (QueryState::SubmitComplete, QueryState::ReplyInProgress)
| (QueryState::SubmitComplete, QueryState::ReplyComplete) => {
return Err(Error::ReceivedReplyWhenExpectingSubmitChunk)
}
// error out if expecting a Reply message and Submit is received
(QueryState::ReplyInProgress, QueryState::SubmitInProgress)
| (QueryState::ReplyInProgress, QueryState::SubmitComplete)
| (QueryState::ReplyComplete, QueryState::SubmitInProgress)
| (QueryState::ReplyComplete, QueryState::SubmitComplete) => {
return Err(Error::ReceivedSubmitWhenExpectingReplyChunk)
}
// otherwise we got an expected message type
(_, _) => {
if current_state == QueryState::ReplyComplete {
raw.set_state(QueryState::ReplyComplete);
}
if current_state == QueryState::SubmitComplete {
raw.set_state(QueryState::SubmitComplete);
}
raw.key.push_str(next.key.as_str());
raw.output.push_str(next.output.as_str());
raw.concern.extend_from_slice(next.concern.as_slice());
raw.concern.extend(next.concern);
}
};
}
Expand All @@ -181,7 +223,6 @@ impl QuerySynthesizer {
});
}
}

self.raw.take().unwrap().try_into().map(Some)
}
}
Expand Down Expand Up @@ -209,23 +250,44 @@ mod test {

#[test]
fn test_chunking() {
let query = PluginQuery {
id: 0,
state: QueryState::ReplyComplete as i32,
publisher_name: "".to_owned(),
plugin_name: "".to_owned(),
query_name: "".to_owned(),
// This key will cause the chunk not to occur on a char boundary
key: "aこれは実験です".to_owned(),
output: "".to_owned(),
concern: vec!["< 10".to_owned(), "0123456789".to_owned()],
};
let res = match chunk_with_size(query, 10) {
Ok(r) => r,
Err(e) => {
panic!("{e}");
}
};
assert_eq!(res.len(), 4);
// test both reply and submission chunking
let states = [
(QueryState::SubmitInProgress, QueryState::SubmitComplete),
(QueryState::ReplyInProgress, QueryState::ReplyComplete),
];

for (intermediate_state, final_state) in states.into_iter() {
let orig_query = PluginQuery {
id: 0,
state: final_state as i32,
publisher_name: "".to_owned(),
plugin_name: "".to_owned(),
query_name: "".to_owned(),
// This key will cause the chunk not to occur on a char boundary
key: serde_json::to_string("aこれは実験です").unwrap(),
output: serde_json::to_string("").unwrap(),
concern: vec!["< 10".to_owned(), "0123456789".to_owned()],
};
let res = match chunk_with_size(orig_query.clone(), 10) {
Ok(r) => r,
Err(e) => {
panic!("{e}");
}
};
// ensure first 4 are ...InProgress
assert_eq!(
res.iter()
.filter(|x| x.state() == intermediate_state)
.count(),
4
);
// ensure last one is ...Complete
assert_eq!(res.last().unwrap().state(), final_state);
assert_eq!(res.len(), 5);
// attempt to reassemble message
let mut synth = QuerySynthesizer::default();
let synthesized_query = synth.add(res.into_iter()).unwrap();
assert_eq!(orig_query, synthesized_query.unwrap().try_into().unwrap());
}
}
}
8 changes: 8 additions & 0 deletions hipcheck-common/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,18 @@ pub enum Error {
#[error("unexpected ReplyInProgress state for query")]
UnexpectedReplyInProgress,

/// The `PluginEngine` received a message with the unexpected status `RequestInProgress`
#[error("unexpected RequestInProgress state for query")]
UnexpectedRequestInProgress,

/// The `PluginEngine` received a message with a request-type status when it expected a reply
#[error("remote sent QuerySubmit when reply chunk expected")]
ReceivedSubmitWhenExpectingReplyChunk,

/// The `PluginEngine` received a message with a reply-type status when it expected a submit
#[error("remote sent QueryReply when submit chunk expected")]
ReceivedReplyWhenExpectingSubmitChunk,

/// The `PluginEngine` received additional messages when it did not expect any
#[error("received additional message for ID '{id}' after query completion")]
MoreAfterQueryComplete { id: usize },
Expand Down
16 changes: 10 additions & 6 deletions hipcheck-common/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ impl TryFrom<QueryState> for QueryDirection {
fn try_from(value: QueryState) -> Result<Self, Self::Error> {
match value {
QueryState::Unspecified => Err(Error::UnspecifiedQueryState),
QueryState::Submit => Ok(QueryDirection::Request),
QueryState::SubmitInProgress => Err(Error::UnexpectedRequestInProgress),
QueryState::SubmitComplete => Ok(QueryDirection::Request),
QueryState::ReplyInProgress => Err(Error::UnexpectedReplyInProgress),
QueryState::ReplyComplete => Ok(QueryDirection::Response),
}
Expand All @@ -39,7 +40,7 @@ impl TryFrom<QueryState> for QueryDirection {
impl From<QueryDirection> for QueryState {
fn from(value: QueryDirection) -> Self {
match value {
QueryDirection::Request => QueryState::Submit,
QueryDirection::Request => QueryState::SubmitComplete,
QueryDirection::Response => QueryState::ReplyComplete,
}
}
Expand All @@ -49,15 +50,18 @@ impl TryFrom<PluginQuery> for Query {
type Error = Error;

fn try_from(value: PluginQuery) -> Result<Query, Self::Error> {
let direction = QueryDirection::try_from(value.state())?;
let key = serde_json::from_str(&value.key).map_err(Error::InvalidJsonInQueryKey)?;
let output =
serde_json::from_str(&value.output).map_err(Error::InvalidJsonInQueryOutput)?;
Ok(Query {
id: value.id as usize,
direction: QueryDirection::try_from(value.state())?,
direction,
publisher: value.publisher_name,
plugin: value.plugin_name,
query: value.query_name,
key: serde_json::from_str(value.key.as_str()).map_err(Error::InvalidJsonInQueryKey)?,
output: serde_json::from_str(value.output.as_str())
.map_err(Error::InvalidJsonInQueryOutput)?,
key,
output,
concerns: value.concern,
})
}
Expand Down
2 changes: 1 addition & 1 deletion hipcheck/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ xz2 = "0.1.7"
zip = "2.2.2"
zip-extensions = "0.8.1"
zstd = "0.13.2"
hipcheck-common = { version = "0.1.0", path = "../hipcheck-common" }
hipcheck-common = { version = "0.2.0", path = "../hipcheck-common" }
serde_with = "3.11.0"

[build-dependencies]
Expand Down
4 changes: 2 additions & 2 deletions plugins/activity/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ publish = false

[dependencies]
clap = { version = "4.5.23", features = ["derive"] }
hipcheck-sdk = { version = "0.2.0", path = "../../sdk/rust", features = [
hipcheck-sdk = { version = "0.3.0", path = "../../sdk/rust", features = [
"macros",
] }
jiff = { version = "0.1.14", features = ["serde"] }
Expand All @@ -19,6 +19,6 @@ serde_json = "1.0.134"
tokio = { version = "1.42.0", features = ["rt"] }

[dev-dependencies]
hipcheck-sdk = { version = "0.2.0", path = "../../sdk/rust", features = [
hipcheck-sdk = { version = "0.3.0", path = "../../sdk/rust", features = [
"mock_engine",
] }
4 changes: 2 additions & 2 deletions plugins/affiliation/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ publish = false
[dependencies]
anyhow = "1.0.95"
clap = { version = "4.5.23", features = ["derive"] }
hipcheck-sdk = { version = "0.2.0", path = "../../sdk/rust", features = [
hipcheck-sdk = { version = "0.3.0", path = "../../sdk/rust", features = [
"macros",
] }
kdl = "4.7.1"
Expand All @@ -22,7 +22,7 @@ strum = { version = "0.26.3", features = ["derive"] }
tokio = { version = "1.42.0", features = ["rt"] }

[dev-dependencies]
hipcheck-sdk = { version = "0.2.0", path = "../../sdk/rust", features = [
hipcheck-sdk = { version = "0.3.0", path = "../../sdk/rust", features = [
"macros",
"mock_engine",
] }
Loading

0 comments on commit 29eb6ad

Please sign in to comment.