From 0e7f7ec79e48ff99387b2d06576122bae8469810 Mon Sep 17 00:00:00 2001 From: Andrew Lilley Brinker Date: Tue, 27 Aug 2024 17:40:06 -0700 Subject: [PATCH] feat: Update proto def to pass 'buf lint' This commit modifies the protobuf definition for Hipcheck plugins to meet the expectations of 'buf lint', which checks protobuf definitions against a set of best practices. The main goal here is to make the service definition more defensive against forward-compatibility issues. This also includes updates to the code that interacts with the generated code to make sure it still compiles. Signed-off-by: Andrew Lilley Brinker --- .buf.yaml | 5 + Cargo.lock | 5 + hipcheck/Cargo.toml | 1 + hipcheck/build.rs | 2 +- hipcheck/src/main.rs | 2 +- hipcheck/src/plugin/manager.rs | 4 +- hipcheck/src/plugin/types.rs | 192 +++-- plugins/dummy_rand_data/Cargo.toml | 5 + plugins/dummy_rand_data/build.rs | 6 + .../query_schema_get_rand.json | 0 plugins/dummy_rand_data/src/hipcheck.rs | 679 ------------------ plugins/dummy_rand_data/src/main.rs | 100 +-- .../{hipcheck_transport.rs => transport.rs} | 232 +++--- plugins/dummy_sha256/Cargo.toml | 5 + plugins/dummy_sha256/build.rs | 6 + .../{src => schema}/query_schema_sha256.json | 0 plugins/dummy_sha256/src/hipcheck.rs | 679 ------------------ .../dummy_sha256/src/hipcheck_transport.rs | 226 ------ plugins/dummy_sha256/src/main.rs | 205 +++--- plugins/dummy_sha256/src/transport.rs | 365 ++++++++++ .../hipcheck/v1}/hipcheck.proto | 132 ++-- 21 files changed, 947 insertions(+), 1904 deletions(-) create mode 100644 .buf.yaml create mode 100644 plugins/dummy_rand_data/build.rs rename plugins/dummy_rand_data/{src => schema}/query_schema_get_rand.json (100%) delete mode 100644 plugins/dummy_rand_data/src/hipcheck.rs rename plugins/dummy_rand_data/src/{hipcheck_transport.rs => transport.rs} (51%) create mode 100644 plugins/dummy_sha256/build.rs rename plugins/dummy_sha256/{src => schema}/query_schema_sha256.json (100%) delete mode 100644 plugins/dummy_sha256/src/hipcheck.rs delete mode 100644 plugins/dummy_sha256/src/hipcheck_transport.rs create mode 100644 plugins/dummy_sha256/src/transport.rs rename {hipcheck/proto => proto/hipcheck/v1}/hipcheck.proto (70%) diff --git a/.buf.yaml b/.buf.yaml new file mode 100644 index 00000000..bb18c2c8 --- /dev/null +++ b/.buf.yaml @@ -0,0 +1,5 @@ +version: v2 +lint: + use: + - DEFAULT + rpc_allow_same_request_response: true diff --git a/Cargo.lock b/Cargo.lock index 415eaf32..536a47f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -700,6 +700,7 @@ version = "0.1.0" dependencies = [ "anyhow", "clap", + "futures", "indexmap 2.4.0", "prost", "rand", @@ -707,6 +708,7 @@ dependencies = [ "tokio", "tokio-stream", "tonic", + "tonic-build", ] [[package]] @@ -715,6 +717,7 @@ version = "0.1.0" dependencies = [ "anyhow", "clap", + "futures", "indexmap 2.4.0", "prost", "rand", @@ -723,6 +726,7 @@ dependencies = [ "tokio", "tokio-stream", "tonic", + "tonic-build", ] [[package]] @@ -1174,6 +1178,7 @@ dependencies = [ "test-log", "thiserror", "tokio", + "tokio-stream", "toml", "tonic", "tonic-build", diff --git a/hipcheck/Cargo.toml b/hipcheck/Cargo.toml index 53c7d5ad..677c948c 100644 --- a/hipcheck/Cargo.toml +++ b/hipcheck/Cargo.toml @@ -131,6 +131,7 @@ which = { version = "6.0.3", default-features = false } xml-rs = "0.8.20" xz2 = "0.1.7" zip = "2.1.6" +tokio-stream = "0.1.15" [build-dependencies] diff --git a/hipcheck/build.rs b/hipcheck/build.rs index 39eceaa8..be11f191 100644 --- a/hipcheck/build.rs +++ b/hipcheck/build.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 fn main() -> anyhow::Result<()> { - tonic_build::compile_protos("proto/hipcheck.proto")?; + tonic_build::compile_protos("../proto/hipcheck/v1/hipcheck.proto")?; Ok(()) } diff --git a/hipcheck/src/main.rs b/hipcheck/src/main.rs index deccb6bb..6dc6651f 100644 --- a/hipcheck/src/main.rs +++ b/hipcheck/src/main.rs @@ -27,7 +27,7 @@ mod util; mod version; pub mod hipcheck { - include!(concat!(env!("OUT_DIR"), "/hipcheck.rs")); + include!(concat!(env!("OUT_DIR"), "/hipcheck.v1.rs")); } use crate::analysis::report_builder::build_report; diff --git a/hipcheck/src/plugin/manager.rs b/hipcheck/src/plugin/manager.rs index 5e1af99b..a978c5d5 100644 --- a/hipcheck/src/plugin/manager.rs +++ b/hipcheck/src/plugin/manager.rs @@ -1,4 +1,4 @@ -use crate::hipcheck::plugin_client::PluginClient; +use crate::hipcheck::plugin_service_client::PluginServiceClient; use crate::plugin::{HcPluginClient, Plugin, PluginContext}; use crate::{hc_error, Result, F64}; use futures::future::join_all; @@ -106,7 +106,7 @@ impl PluginExecutor { .mul_f64(jitter_percent); sleep_until(Instant::now() + sleep_duration).await; if let Ok(grpc) = - PluginClient::connect(format!("http://127.0.0.1:{port_str}")).await + PluginServiceClient::connect(format!("http://127.0.0.1:{port_str}")).await { opt_grpc = Some(grpc); break; diff --git a/hipcheck/src/plugin/types.rs b/hipcheck/src/plugin/types.rs index 26ac2ac3..0decf2b7 100644 --- a/hipcheck/src/plugin/types.rs +++ b/hipcheck/src/plugin/types.rs @@ -1,19 +1,30 @@ -use crate::hipcheck::plugin_client::PluginClient; -use crate::hipcheck::{ - Configuration, ConfigurationResult as PluginConfigResult, ConfigurationStatus, Empty, - Query as PluginQuery, QueryState, Schema as PluginSchema, +use crate::{ + hc_error, + hipcheck::{ + plugin_service_client::PluginServiceClient, ConfigurationStatus, Empty, + GetDefaultPolicyExpressionRequest, GetQuerySchemasRequest, + GetQuerySchemasResponse as PluginSchema, InitiateQueryProtocolRequest, + InitiateQueryProtocolResponse, Query as PluginQuery, QueryState, SetConfigurationRequest, + SetConfigurationResponse as PluginConfigResult, + }, + Error, Result, }; -use crate::{hc_error, Error, Result, StdResult}; +use futures::{Stream, StreamExt}; use serde_json::Value; -use std::collections::{HashMap, VecDeque}; -use std::convert::TryFrom; -use std::ops::Not; -use std::process::Child; +use std::{ + collections::{HashMap, VecDeque}, + convert::TryFrom, + future::{self, poll_fn}, + ops::Not as _, + pin::Pin, + process::Child, + result::Result as StdResult, +}; use tokio::sync::{mpsc, Mutex}; -use tonic::codec::Streaming; -use tonic::transport::Channel; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{codec::Streaming, transport::Channel, Code, Status}; -pub type HcPluginClient = PluginClient; +pub type HcPluginClient = PluginServiceClient; #[derive(Clone, Debug)] pub struct Plugin { @@ -28,6 +39,7 @@ pub struct Schema { pub key_schema: Value, pub output_schema: Value, } + impl TryFrom for Schema { type Error = crate::error::Error; fn try_from(value: PluginSchema) -> Result { @@ -46,6 +58,7 @@ pub struct ConfigurationResult { pub status: ConfigurationStatus, pub message: Option, } + impl TryFrom for ConfigurationResult { type Error = crate::error::Error; fn try_from(value: PluginConfigResult) -> Result { @@ -54,6 +67,7 @@ impl TryFrom for ConfigurationResult { Ok(ConfigurationResult { status, message }) } } + // hipcheck::ConfigurationStatus has an enum that captures both error and success // scenarios. The below code allows interpreting the struct as a Rust Result. If // the success variant was the status, Ok(()) is returned, otherwise the code @@ -70,37 +84,42 @@ impl ConfigurationResult { )) } } + pub enum ConfigErrorType { Unknown = 0, MissingRequiredConfig = 2, UnrecognizedConfig = 3, InvalidConfigValue = 4, } + impl TryFrom for ConfigErrorType { type Error = crate::error::Error; fn try_from(value: ConfigurationStatus) -> Result { use ConfigErrorType::*; use ConfigurationStatus::*; Ok(match value as i32 { - x if x == ErrorUnknown as i32 => Unknown, - x if x == ErrorMissingRequiredConfiguration as i32 => MissingRequiredConfig, - x if x == ErrorUnrecognizedConfiguration as i32 => UnrecognizedConfig, - x if x == ErrorInvalidConfigurationValue as i32 => InvalidConfigValue, + x if x == Unspecified as i32 => Unknown, + x if x == MissingRequiredConfiguration as i32 => MissingRequiredConfig, + x if x == UnrecognizedConfiguration as i32 => UnrecognizedConfig, + x if x == InvalidConfigurationValue as i32 => InvalidConfigValue, x => { return Err(hc_error!("status value '{}' is not an error", x)); } }) } } + pub struct ConfigError { error: ConfigErrorType, message: Option, } + impl ConfigError { pub fn new(error: ConfigErrorType, message: Option) -> Self { ConfigError { error, message } } } + impl std::fmt::Display for ConfigError { fn fmt(&self, f: &mut std::fmt::Formatter) -> StdResult<(), std::fmt::Error> { use ConfigErrorType::*; @@ -126,11 +145,17 @@ pub struct PluginContext { pub grpc: HcPluginClient, pub proc: Child, } + // Redefinition of `grpc` field's functions with more useful types, additional // error & sanity checking impl PluginContext { pub async fn get_query_schemas(&mut self) -> Result> { - let mut res = self.grpc.get_query_schemas(Empty {}).await?; + let mut res = self + .grpc + .get_query_schemas(GetQuerySchemasRequest { + empty: Some(Empty {}), + }) + .await?; let stream = res.get_mut(); let mut schema_builder: HashMap = HashMap::new(); while let Some(msg) = stream.message().await? { @@ -149,36 +174,58 @@ impl PluginContext { .map(TryInto::try_into) .collect() } + pub async fn set_configuration(&mut self, conf: &Value) -> Result { - let conf_query = Configuration { + let req = SetConfigurationRequest { configuration: serde_json::to_string(&conf)?, }; - let res = self.grpc.set_configuration(conf_query).await?; + let res = self.grpc.set_configuration(req).await?; res.into_inner().try_into() } + // TODO - the String in the result should be replaced with a structured // type once the policy expression code is integrated pub async fn get_default_policy_expression(&mut self) -> Result { - let mut res = self.grpc.get_default_policy_expression(Empty {}).await?; - Ok(res.get_ref().policy_expression.to_owned()) + let req = GetDefaultPolicyExpressionRequest { + empty: Some(Empty {}), + }; + + let res = self.grpc.get_default_policy_expression(req).await?; + let policy_expression = res.get_ref().policy_expression.to_owned(); + Ok(policy_expression) } + pub async fn initiate_query_protocol( &mut self, mut rx: mpsc::Receiver, - ) -> Result> { - let stream = async_stream::stream! { - while let Some(item) = rx.recv().await { - yield item; - } - }; - match self.grpc.initiate_query_protocol(stream).await { - Ok(resp) => Ok(resp.into_inner()), - Err(e) => Err(hc_error!( - "query protocol initiation failed with tonic status code {}", - e - )), - } + ) -> Result { + // Convert the receiver into a stream. + let stream = ReceiverStream::new(rx) + .map(|query| InitiateQueryProtocolRequest { query: Some(query) }); + + // Make the gRPC request. + let resp = self + .grpc + .initiate_query_protocol(stream) + .await + .map_err(|err| { + hc_error!( + "query protocol initiation failed with tonic status code {}", + err + ) + })?; + + // Pull out the inner query from the response. + let stream = resp.into_inner().map(|response| { + response.and_then(|res| { + res.query + .ok_or_else(|| Status::new(Code::Unknown, "no query present in response")) + }) + }); + + Ok(Box::new(stream)) } + pub async fn initialize(mut self, config: Value) -> Result { let schemas = HashMap::::from_iter( self.get_query_schemas() @@ -227,22 +274,27 @@ pub struct Query { pub key: Value, pub output: Value, } + impl TryFrom for Query { type Error = Error; + fn try_from(value: PluginQuery) -> Result { use QueryState::*; + let request = match TryInto::::try_into(value.state)? { - QueryUnspecified => return Err(hc_error!("unspecified error from plugin")), - QueryReplyInProgress => { + Unspecified => return Err(hc_error!("unspecified error from plugin")), + ReplyInProgress => { return Err(hc_error!( "invalid state QueryReplyInProgress for conversion to Query" )) } - QueryReplyComplete => false, - QuerySubmit => true, + ReplyComplete => false, + Submit => true, }; + let key: Value = serde_json::from_str(value.key.as_str())?; let output: Value = serde_json::from_str(value.output.as_str())?; + Ok(Query { id: value.id as usize, request, @@ -254,15 +306,19 @@ impl TryFrom for Query { }) } } + impl TryFrom for PluginQuery { type Error = crate::error::Error; + fn try_from(value: Query) -> Result { let state_enum = match value.request { - true => QueryState::QuerySubmit, - false => QueryState::QueryReplyComplete, + true => QueryState::Submit, + false => QueryState::ReplyComplete, }; + let key = serde_json::to_string(&value.key)?; let output = serde_json::to_string(&value.output)?; + Ok(PluginQuery { id: value.id as i32, state: state_enum as i32, @@ -275,18 +331,47 @@ impl TryFrom for PluginQuery { } } -#[derive(Debug)] pub struct MultiplexedQueryReceiver { - rx: Streaming, + rx: QueryStream, backlog: HashMap>, } + +impl std::fmt::Debug for MultiplexedQueryReceiver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("MultiplexedQueryReceiver") + .field("rx", &"") + .field("backlog", &self.backlog) + .finish() + } +} + +/// Helper type for a stream of query messages. +/// +/// Note that the inner item is a `Result` because the inner +/// `query` field on the message can technically be missing. +/// +/// This case is handled in the `message` method to flatten +/// that kind of error so code consuming the stream doesn't +/// have to worry about it. +type QueryStream = Box> + Send + Unpin + 'static>; + impl MultiplexedQueryReceiver { - pub fn new(rx: Streaming) -> Self { + pub fn new(rx: QueryStream) -> Self { Self { rx, backlog: HashMap::new(), } } + + /// Poll the underlying stream future to get the next query, if present. + async fn message(&mut self) -> StdResult, Status> { + match poll_fn(|cx| Pin::new(self.rx.as_mut()).poll_next(cx)).await { + Some(Ok(m)) => Ok(Some(m)), + Some(Err(e)) => Err(e), + None => Ok(None), + } + } + // @Invariant - this function will never return an empty VecDeque pub async fn recv(&mut self, id: i32) -> Result>> { // If we have 1+ messages on backlog for `id`, return them all, @@ -294,15 +379,18 @@ impl MultiplexedQueryReceiver { if let Some(msgs) = self.backlog.remove(&id) { return Ok(Some(msgs)); } + // No backlog message, need to operate the receiver loop { - let Some(raw) = self.rx.message().await? else { + let Some(raw) = self.message().await? else { // gRPC channel was closed return Ok(None); }; + if raw.id == id { return Ok(Some(VecDeque::from([raw]))); } + match self.backlog.get_mut(&raw.id) { Some(vec) => { vec.push_back(raw); @@ -325,10 +413,12 @@ pub struct PluginTransport { tx: mpsc::Sender, rx: Mutex, } + impl PluginTransport { pub fn name(&self) -> &str { &self.ctx.plugin.name } + pub async fn query(&self, query: Query) -> Result> { use QueryState::*; @@ -351,8 +441,8 @@ impl PluginTransport { let mut state: QueryState = raw.state.try_into()?; // If response is the first of a set of chunks, handle - if matches!(state, QueryReplyInProgress) { - while matches!(state, QueryReplyInProgress) { + if matches!(state, ReplyInProgress) { + while matches!(state, ReplyInProgress) { // We expect another message. Pull it off the existing queue, // or get a new one if we have run out let next = match msg_chunks.pop_front() { @@ -375,13 +465,13 @@ impl PluginTransport { // By now we have our "next" message state = next.state.try_into()?; match state { - QueryUnspecified => return Err(hc_error!("unspecified error from plugin")), - QuerySubmit => { + Unspecified => return Err(hc_error!("unspecified error from plugin")), + Submit => { return Err(hc_error!( "plugin sent QuerySubmit state when reply chunk expected" )) } - QueryReplyInProgress | QueryReplyComplete => { + ReplyInProgress | ReplyComplete => { raw.output.push_str(next.output.as_str()); } }; @@ -404,6 +494,7 @@ impl From for (Plugin, Value) { (value.0, value.1) } } + pub struct PluginContextWithConfig(pub PluginContext, pub Value); impl From for (PluginContext, Value) { fn from(value: PluginContextWithConfig) -> Self { @@ -419,6 +510,7 @@ pub struct AwaitingResult { pub query: String, pub key: Value, } + impl From for AwaitingResult { fn from(value: Query) -> Self { AwaitingResult { @@ -437,6 +529,7 @@ pub enum PluginResponse { AwaitingResult(AwaitingResult), Completed(Value), } + impl From> for PluginResponse { fn from(value: Option) -> Self { match value { @@ -445,6 +538,7 @@ impl From> for PluginResponse { } } } + impl From for PluginResponse { fn from(value: Query) -> Self { if !value.request { diff --git a/plugins/dummy_rand_data/Cargo.toml b/plugins/dummy_rand_data/Cargo.toml index 0d5b659c..af51ea32 100644 --- a/plugins/dummy_rand_data/Cargo.toml +++ b/plugins/dummy_rand_data/Cargo.toml @@ -7,6 +7,7 @@ publish = false [dependencies] anyhow = "1.0.86" clap = { version = "4.5.16", features = ["derive"] } +futures = "0.3.30" indexmap = "2.4.0" prost = "0.13.1" rand = "0.8.5" @@ -14,3 +15,7 @@ serde_json = "1.0.125" tokio = { version = "1.39.2", features = ["rt"] } tokio-stream = "0.1.15" tonic = "0.12.1" + +[build-dependencies] +anyhow = "1.0.86" +tonic-build = "0.12.1" diff --git a/plugins/dummy_rand_data/build.rs b/plugins/dummy_rand_data/build.rs new file mode 100644 index 00000000..759819ea --- /dev/null +++ b/plugins/dummy_rand_data/build.rs @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 + +fn main() -> anyhow::Result<()> { + tonic_build::compile_protos("../../proto/hipcheck/v1/hipcheck.proto")?; + Ok(()) +} diff --git a/plugins/dummy_rand_data/src/query_schema_get_rand.json b/plugins/dummy_rand_data/schema/query_schema_get_rand.json similarity index 100% rename from plugins/dummy_rand_data/src/query_schema_get_rand.json rename to plugins/dummy_rand_data/schema/query_schema_get_rand.json diff --git a/plugins/dummy_rand_data/src/hipcheck.rs b/plugins/dummy_rand_data/src/hipcheck.rs deleted file mode 100644 index 50ce6bf2..00000000 --- a/plugins/dummy_rand_data/src/hipcheck.rs +++ /dev/null @@ -1,679 +0,0 @@ -#![allow(clippy::enum_variant_names)] - -// This file is @generated by prost-build. -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Configuration { - /// JSON string containing configuration data expected by the plugin, - /// pulled from the user's policy file. - #[prost(string, tag = "1")] - pub configuration: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ConfigurationResult { - /// The status of the configuration call. - #[prost(enumeration = "ConfigurationStatus", tag = "1")] - pub status: i32, - /// An optional error message, if there was an error. - #[prost(string, tag = "2")] - pub message: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PolicyExpression { - /// A policy expression, if the plugin has a default policy. - /// This MUST be filled in with any default values pulled from the plugin's - /// configuration. Hipcheck will only request the default policy _after_ - /// configuring the plugin. - #[prost(string, tag = "1")] - pub policy_expression: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Schema { - /// The name of the query being described by the schemas provided. - /// - /// If either the key and/or output schemas result in a message which is - /// too big, they may be chunked across multiple replies in the stream. - /// Replies with matching query names should have their fields concatenated - /// in the order received to reconstruct the chunks. - #[prost(string, tag = "1")] - pub query_name: ::prost::alloc::string::String, - /// The key schema, in JSON Schema format. - #[prost(string, tag = "2")] - pub key_schema: ::prost::alloc::string::String, - /// The output schema, in JSON Schema format. - #[prost(string, tag = "3")] - pub output_schema: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Query { - /// The ID of the request, used to associate requests and replies. - /// Odd numbers = initiated by `hc`. - /// Even numbers = initiated by a plugin. - #[prost(int32, tag = "1")] - pub id: i32, - /// The state of the query, indicating if this is a request or a reply, - /// and if it's a reply whether it's the end of the reply. - #[prost(enumeration = "QueryState", tag = "2")] - pub state: i32, - /// Publisher name and plugin name, when sent from Hipcheck to a plugin - /// to initiate a fresh query, are used by the receiving plugin to validate - /// that the query was intended for them. - /// - /// When a plugin is making a query to another plugin through Hipcheck, it's - /// used to indicate the destination plugin, and to indicate the plugin that - /// is replying when Hipcheck sends back the reply. - #[prost(string, tag = "3")] - pub publisher_name: ::prost::alloc::string::String, - #[prost(string, tag = "4")] - pub plugin_name: ::prost::alloc::string::String, - /// The name of the query being made, so the responding plugin knows what - /// to do with the provided data. - #[prost(string, tag = "5")] - pub query_name: ::prost::alloc::string::String, - /// The key for the query, as a JSON object. This is the data that Hipcheck's - /// incremental computation system will use to cache the response. - #[prost(string, tag = "6")] - pub key: ::prost::alloc::string::String, - /// The response for the query, as a JSON object. This will be cached by - /// Hipcheck for future queries matching the publisher name, plugin name, - /// query name, and key. - #[prost(string, tag = "7")] - pub output: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, Copy, PartialEq, ::prost::Message)] -pub struct Empty {} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum ConfigurationStatus { - /// An unknown error occured. - ErrorUnknown = 0, - /// No error; the operation was successful. - ErrorNone = 1, - /// The user failed to provide a required configuration item. - ErrorMissingRequiredConfiguration = 2, - /// The user provided a configuration item whose name was not recognized. - ErrorUnrecognizedConfiguration = 3, - /// The user provided a configuration item whose value is invalid. - ErrorInvalidConfigurationValue = 4, -} -impl ConfigurationStatus { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - ConfigurationStatus::ErrorUnknown => "ERROR_UNKNOWN", - ConfigurationStatus::ErrorNone => "ERROR_NONE", - ConfigurationStatus::ErrorMissingRequiredConfiguration => { - "ERROR_MISSING_REQUIRED_CONFIGURATION" - } - ConfigurationStatus::ErrorUnrecognizedConfiguration => { - "ERROR_UNRECOGNIZED_CONFIGURATION" - } - ConfigurationStatus::ErrorInvalidConfigurationValue => { - "ERROR_INVALID_CONFIGURATION_VALUE" - } - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "ERROR_UNKNOWN" => Some(Self::ErrorUnknown), - "ERROR_NONE" => Some(Self::ErrorNone), - "ERROR_MISSING_REQUIRED_CONFIGURATION" => Some(Self::ErrorMissingRequiredConfiguration), - "ERROR_UNRECOGNIZED_CONFIGURATION" => Some(Self::ErrorUnrecognizedConfiguration), - "ERROR_INVALID_CONFIGURATION_VALUE" => Some(Self::ErrorInvalidConfigurationValue), - _ => None, - } - } -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum QueryState { - /// Something has gone wrong. - QueryUnspecified = 0, - /// We are submitting a new query. - QuerySubmit = 1, - /// We are replying to a query and expect more chunks. - QueryReplyInProgress = 2, - /// We are closing a reply to a query. If a query response is in one chunk, - /// just send this. If a query is in more than one chunk, send this with - /// the last message in the reply. This tells the receiver that all chunks - /// have been received. - QueryReplyComplete = 3, -} -impl QueryState { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - QueryState::QueryUnspecified => "QUERY_UNSPECIFIED", - QueryState::QuerySubmit => "QUERY_SUBMIT", - QueryState::QueryReplyInProgress => "QUERY_REPLY_IN_PROGRESS", - QueryState::QueryReplyComplete => "QUERY_REPLY_COMPLETE", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "QUERY_UNSPECIFIED" => Some(Self::QueryUnspecified), - "QUERY_SUBMIT" => Some(Self::QuerySubmit), - "QUERY_REPLY_IN_PROGRESS" => Some(Self::QueryReplyInProgress), - "QUERY_REPLY_COMPLETE" => Some(Self::QueryReplyComplete), - _ => None, - } - } -} -/// Generated client implementations. -pub mod plugin_client { - #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] - use tonic::codegen::http::Uri; - use tonic::codegen::*; - #[derive(Debug, Clone)] - pub struct PluginClient { - inner: tonic::client::Grpc, - } - impl PluginClient { - /// Attempt to create a new client by connecting to a given endpoint. - pub async fn connect(dst: D) -> Result - where - D: TryInto, - D::Error: Into, - { - let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; - Ok(Self::new(conn)) - } - } - impl PluginClient - where - T: tonic::client::GrpcService, - T::Error: Into, - T::ResponseBody: Body + Send + 'static, - ::Error: Into + Send, - { - pub fn new(inner: T) -> Self { - let inner = tonic::client::Grpc::new(inner); - Self { inner } - } - pub fn with_origin(inner: T, origin: Uri) -> Self { - let inner = tonic::client::Grpc::with_origin(inner, origin); - Self { inner } - } - pub fn with_interceptor( - inner: T, - interceptor: F, - ) -> PluginClient> - where - F: tonic::service::Interceptor, - T::ResponseBody: Default, - T: tonic::codegen::Service< - http::Request, - Response = http::Response< - >::ResponseBody, - >, - >, - >>::Error: - Into + Send + Sync, - { - PluginClient::new(InterceptedService::new(inner, interceptor)) - } - /// Compress requests with the given encoding. - /// - /// This requires the server to support it otherwise it might respond with an - /// error. - #[must_use] - pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.inner = self.inner.send_compressed(encoding); - self - } - /// Enable decompressing responses. - #[must_use] - pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.inner = self.inner.accept_compressed(encoding); - self - } - /// Limits the maximum size of a decoded message. - /// - /// Default: `4MB` - #[must_use] - pub fn max_decoding_message_size(mut self, limit: usize) -> Self { - self.inner = self.inner.max_decoding_message_size(limit); - self - } - /// Limits the maximum size of an encoded message. - /// - /// Default: `usize::MAX` - #[must_use] - pub fn max_encoding_message_size(mut self, limit: usize) -> Self { - self.inner = self.inner.max_encoding_message_size(limit); - self - } - /// * - /// Get schemas for all supported queries by the plugin. - /// - /// This is used by Hipcheck to validate that: - /// - /// - The plugin supports a default query taking a `target` type if used - /// as a top-level plugin in the user's policy file. - /// - That requests sent to the plugin and data returned by the plugin - /// match the schema during execution. - pub async fn get_query_schemas( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response>, - tonic::Status, - > { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static("/hipcheck.Plugin/GetQuerySchemas"); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("hipcheck.Plugin", "GetQuerySchemas")); - self.inner.server_streaming(req, path, codec).await - } - /// * - /// Hipcheck sends all child nodes for the plugin from the user's policy - /// file to configure the plugin. - pub async fn set_configuration( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static("/hipcheck.Plugin/SetConfiguration"); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("hipcheck.Plugin", "SetConfiguration")); - self.inner.unary(req, path, codec).await - } - /// * - /// Get the default policy for a plugin, which may additionally depend on - /// the plugin's configuration. - pub async fn get_default_policy_expression( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = - http::uri::PathAndQuery::from_static("/hipcheck.Plugin/GetDefaultPolicyExpression"); - let mut req = request.into_request(); - req.extensions_mut().insert(GrpcMethod::new( - "hipcheck.Plugin", - "GetDefaultPolicyExpression", - )); - self.inner.unary(req, path, codec).await - } - /// * - /// Open a bidirectional streaming RPC to enable a request/response - /// protocol between Hipcheck and a plugin, where Hipcheck can issue - /// queries to the plugin, and the plugin may issue queries to _other_ - /// plugins through Hipcheck. - /// - /// Queries are cached by the publisher name, plugin name, query name, - /// and key, and if a match is found for those four values, then - /// Hipcheck will respond with the cached result of that prior matching - /// query rather than running the query again. - pub async fn initiate_query_protocol( - &mut self, - request: impl tonic::IntoStreamingRequest, - ) -> std::result::Result< - tonic::Response>, - tonic::Status, - > { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = - http::uri::PathAndQuery::from_static("/hipcheck.Plugin/InitiateQueryProtocol"); - let mut req = request.into_streaming_request(); - req.extensions_mut() - .insert(GrpcMethod::new("hipcheck.Plugin", "InitiateQueryProtocol")); - self.inner.streaming(req, path, codec).await - } - } -} -/// Generated server implementations. -pub mod plugin_server { - #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] - use tonic::codegen::*; - /// Generated trait containing gRPC methods that should be implemented for use with PluginServer. - #[async_trait] - pub trait Plugin: Send + Sync + 'static { - /// Server streaming response type for the GetQuerySchemas method. - type GetQuerySchemasStream: tonic::codegen::tokio_stream::Stream< - Item = std::result::Result, - > + Send - + 'static; - /// * - /// Get schemas for all supported queries by the plugin. - /// - /// This is used by Hipcheck to validate that: - /// - /// - The plugin supports a default query taking a `target` type if used - /// as a top-level plugin in the user's policy file. - /// - That requests sent to the plugin and data returned by the plugin - /// match the schema during execution. - async fn get_query_schemas( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// * - /// Hipcheck sends all child nodes for the plugin from the user's policy - /// file to configure the plugin. - async fn set_configuration( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// * - /// Get the default policy for a plugin, which may additionally depend on - /// the plugin's configuration. - async fn get_default_policy_expression( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// Server streaming response type for the InitiateQueryProtocol method. - type InitiateQueryProtocolStream: tonic::codegen::tokio_stream::Stream< - Item = std::result::Result, - > + Send - + 'static; - /// * - /// Open a bidirectional streaming RPC to enable a request/response - /// protocol between Hipcheck and a plugin, where Hipcheck can issue - /// queries to the plugin, and the plugin may issue queries to _other_ - /// plugins through Hipcheck. - /// - /// Queries are cached by the publisher name, plugin name, query name, - /// and key, and if a match is found for those four values, then - /// Hipcheck will respond with the cached result of that prior matching - /// query rather than running the query again. - async fn initiate_query_protocol( - &self, - request: tonic::Request>, - ) -> std::result::Result, tonic::Status>; - } - #[derive(Debug)] - pub struct PluginServer { - inner: Arc, - accept_compression_encodings: EnabledCompressionEncodings, - send_compression_encodings: EnabledCompressionEncodings, - max_decoding_message_size: Option, - max_encoding_message_size: Option, - } - impl PluginServer { - pub fn new(inner: T) -> Self { - Self::from_arc(Arc::new(inner)) - } - pub fn from_arc(inner: Arc) -> Self { - Self { - inner, - accept_compression_encodings: Default::default(), - send_compression_encodings: Default::default(), - max_decoding_message_size: None, - max_encoding_message_size: None, - } - } - pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService - where - F: tonic::service::Interceptor, - { - InterceptedService::new(Self::new(inner), interceptor) - } - /// Enable decompressing requests with the given encoding. - #[must_use] - pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.accept_compression_encodings.enable(encoding); - self - } - /// Compress responses with the given encoding, if the client supports it. - #[must_use] - pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.send_compression_encodings.enable(encoding); - self - } - /// Limits the maximum size of a decoded message. - /// - /// Default: `4MB` - #[must_use] - pub fn max_decoding_message_size(mut self, limit: usize) -> Self { - self.max_decoding_message_size = Some(limit); - self - } - /// Limits the maximum size of an encoded message. - /// - /// Default: `usize::MAX` - #[must_use] - pub fn max_encoding_message_size(mut self, limit: usize) -> Self { - self.max_encoding_message_size = Some(limit); - self - } - } - impl tonic::codegen::Service> for PluginServer - where - T: Plugin, - B: Body + Send + 'static, - B::Error: Into + Send + 'static, - { - type Response = http::Response; - type Error = std::convert::Infallible; - type Future = BoxFuture; - fn poll_ready( - &mut self, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - fn call(&mut self, req: http::Request) -> Self::Future { - match req.uri().path() { - "/hipcheck.Plugin/GetQuerySchemas" => { - #[allow(non_camel_case_types)] - struct GetQuerySchemasSvc(pub Arc); - impl tonic::server::ServerStreamingService for GetQuerySchemasSvc { - type Response = super::Schema; - type ResponseStream = T::GetQuerySchemasStream; - type Future = - BoxFuture, tonic::Status>; - fn call(&mut self, request: tonic::Request) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::get_query_schemas(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = GetQuerySchemasSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.server_streaming(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/hipcheck.Plugin/SetConfiguration" => { - #[allow(non_camel_case_types)] - struct SetConfigurationSvc(pub Arc); - impl tonic::server::UnaryService for SetConfigurationSvc { - type Response = super::ConfigurationResult; - type Future = BoxFuture, tonic::Status>; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::set_configuration(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = SetConfigurationSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/hipcheck.Plugin/GetDefaultPolicyExpression" => { - #[allow(non_camel_case_types)] - struct GetDefaultPolicyExpressionSvc(pub Arc); - impl tonic::server::UnaryService for GetDefaultPolicyExpressionSvc { - type Response = super::PolicyExpression; - type Future = BoxFuture, tonic::Status>; - fn call(&mut self, request: tonic::Request) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::get_default_policy_expression(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = GetDefaultPolicyExpressionSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/hipcheck.Plugin/InitiateQueryProtocol" => { - #[allow(non_camel_case_types)] - struct InitiateQueryProtocolSvc(pub Arc); - impl tonic::server::StreamingService for InitiateQueryProtocolSvc { - type Response = super::Query; - type ResponseStream = T::InitiateQueryProtocolStream; - type Future = - BoxFuture, tonic::Status>; - fn call( - &mut self, - request: tonic::Request>, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::initiate_query_protocol(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = InitiateQueryProtocolSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.streaming(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - _ => Box::pin(async move { - Ok(http::Response::builder() - .status(200) - .header("grpc-status", tonic::Code::Unimplemented as i32) - .header( - http::header::CONTENT_TYPE, - tonic::metadata::GRPC_CONTENT_TYPE, - ) - .body(empty_body()) - .unwrap()) - }), - } - } - } - impl Clone for PluginServer { - fn clone(&self) -> Self { - let inner = self.inner.clone(); - Self { - inner, - accept_compression_encodings: self.accept_compression_encodings, - send_compression_encodings: self.send_compression_encodings, - max_decoding_message_size: self.max_decoding_message_size, - max_encoding_message_size: self.max_encoding_message_size, - } - } - } - impl tonic::server::NamedService for PluginServer { - const NAME: &'static str = "hipcheck.Plugin"; - } -} diff --git a/plugins/dummy_rand_data/src/main.rs b/plugins/dummy_rand_data/src/main.rs index 494b61fc..35d0fece 100644 --- a/plugins/dummy_rand_data/src/main.rs +++ b/plugins/dummy_rand_data/src/main.rs @@ -1,24 +1,27 @@ -#![allow(unused_variables)] - -mod hipcheck; -mod hipcheck_transport; +mod transport; +mod proto { + include!(concat!(env!("OUT_DIR"), "/hipcheck.v1.rs")); +} -use crate::hipcheck_transport::*; +use crate::{ + proto::{ + plugin_service_server::{PluginService, PluginServiceServer}, + ConfigurationStatus, GetDefaultPolicyExpressionRequest, GetDefaultPolicyExpressionResponse, + GetQuerySchemasRequest, GetQuerySchemasResponse, InitiateQueryProtocolRequest, + InitiateQueryProtocolResponse, SetConfigurationRequest, SetConfigurationResponse, + }, + transport::*, +}; use anyhow::{anyhow, Result}; use clap::Parser; -use hipcheck::plugin_server::{Plugin, PluginServer}; -use hipcheck::{ - Configuration, ConfigurationResult, ConfigurationStatus, Empty, PolicyExpression, - Query as PluginQuery, Schema, -}; use serde_json::{json, Value}; use std::pin::Pin; use tokio::sync::mpsc; use tokio_stream::{wrappers::ReceiverStream, Stream}; use tonic::{transport::Server, Request, Response, Status, Streaming}; -static GET_RAND_KEY_SCHEMA: &str = include_str!("query_schema_get_rand.json"); -static GET_RAND_OUTPUT_SCHEMA: &str = include_str!("query_schema_get_rand.json"); +static GET_RAND_KEY_SCHEMA: &str = include_str!("../schema/query_schema_get_rand.json"); +static GET_RAND_OUTPUT_SCHEMA: &str = include_str!("../schema/query_schema_get_rand.json"); fn reduce(input: u64) -> u64 { input % 7 @@ -30,7 +33,7 @@ pub async fn handle_rand_data(mut session: QuerySession, key: u64) -> Result<()> eprintln!("RAND-{id}: key: {key}, reduced: {sha_input}"); let sha_req = Query { - request: true, + direction: QueryDirection::Request, publisher: "MITRE".to_owned(), plugin: "sha256".to_owned(), query: "sha256".to_owned(), @@ -43,23 +46,21 @@ pub async fn handle_rand_data(mut session: QuerySession, key: u64) -> Result<()> return Err(anyhow!("channel closed prematurely by remote")); }; - if res.request { + if res.direction == QueryDirection::Request { return Err(anyhow!("expected response from remote")); } let mut sha_vec: Vec = serde_json::from_value(res.output)?; eprintln!("RAND-{id}: hash: {sha_vec:02x?}"); - let key_vec = key.to_le_bytes().to_vec(); - - for (i, val) in key_vec.into_iter().enumerate() { - *sha_vec.get_mut(i).unwrap() += val; + for (sha_val, key_val) in Iterator::zip(sha_vec.iter_mut(), key.to_le_bytes()) { + *sha_val += key_val; } eprintln!("RAND-{id}: output: {sha_vec:02x?}"); let output = serde_json::to_value(sha_vec)?; let resp = Query { - request: false, + direction: QueryDirection::Response, publisher: "".to_owned(), plugin: "".to_owned(), query: "".to_owned(), @@ -78,28 +79,28 @@ async fn handle_session(mut session: QuerySession) -> Result<()> { return Ok(()); }; - if !query.request { + if query.direction == QueryDirection::Response { return Err(anyhow!("Expected request from remote")); } let name = query.query; let key = query.key; - if name == "rand_data" { - let Value::Number(num_size) = &key else { - return Err(anyhow!("get_rand argument must be a number")); - }; + if name != "rand_data" { + return Err(anyhow!("unrecognized query '{}'", name)); + } - let Some(size) = num_size.as_u64() else { - return Err(anyhow!("get_rand argument must be an unsigned integer")); - }; + let Value::Number(num_size) = &key else { + return Err(anyhow!("get_rand argument must be a number")); + }; - handle_rand_data(session, size).await?; + let Some(size) = num_size.as_u64() else { + return Err(anyhow!("get_rand argument must be an unsigned integer")); + }; - Ok(()) - } else { - Err(anyhow!("unrecognized query '{}'", name)) - } + handle_rand_data(session, size).await?; + + Ok(()) } struct RandDataRunner { @@ -133,29 +134,32 @@ impl RandDataRunner { #[derive(Debug)] struct RandDataPlugin { - pub schema: Schema, + pub schema: GetQuerySchemasResponse, } impl RandDataPlugin { pub fn new() -> Self { - let schema = Schema { + let schema = GetQuerySchemasResponse { query_name: "rand_data".to_owned(), key_schema: GET_RAND_KEY_SCHEMA.to_owned(), output_schema: GET_RAND_OUTPUT_SCHEMA.to_owned(), }; + RandDataPlugin { schema } } } #[tonic::async_trait] -impl Plugin for RandDataPlugin { +impl PluginService for RandDataPlugin { type GetQuerySchemasStream = - Pin> + Send + 'static>>; - type InitiateQueryProtocolStream = ReceiverStream>; + Pin> + Send + 'static>>; + + type InitiateQueryProtocolStream = + ReceiverStream>; async fn get_query_schemas( &self, - _request: Request, + _request: Request, ) -> Result, Status> { Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok(self .schema @@ -164,29 +168,29 @@ impl Plugin for RandDataPlugin { async fn set_configuration( &self, - request: Request, - ) -> Result, Status> { - Ok(Response::new(ConfigurationResult { - status: ConfigurationStatus::ErrorNone as i32, + _request: Request, + ) -> Result, Status> { + Ok(Response::new(SetConfigurationResponse { + status: ConfigurationStatus::None as i32, message: "".to_owned(), })) } async fn get_default_policy_expression( &self, - request: Request, - ) -> Result, Status> { - Ok(Response::new(PolicyExpression { + _request: Request, + ) -> Result, Status> { + Ok(Response::new(GetDefaultPolicyExpressionResponse { policy_expression: "".to_owned(), })) } async fn initiate_query_protocol( &self, - request: Request>, + request: Request>, ) -> Result, Status> { let rx = request.into_inner(); - let (tx, out_rx) = mpsc::channel::>(4); + let (tx, out_rx) = mpsc::channel::>(4); tokio::spawn(async move { let channel = HcSessionSocket::new(tx, rx); @@ -210,7 +214,7 @@ async fn main() -> Result<(), Box> { let args = Args::try_parse().map_err(Box::new)?; let addr = format!("127.0.0.1:{}", args.port); let plugin = RandDataPlugin::new(); - let svc = PluginServer::new(plugin); + let svc = PluginServiceServer::new(plugin); Server::builder() .add_service(svc) diff --git a/plugins/dummy_rand_data/src/hipcheck_transport.rs b/plugins/dummy_rand_data/src/transport.rs similarity index 51% rename from plugins/dummy_rand_data/src/hipcheck_transport.rs rename to plugins/dummy_rand_data/src/transport.rs index e206860b..7a26145f 100644 --- a/plugins/dummy_rand_data/src/hipcheck_transport.rs +++ b/plugins/dummy_rand_data/src/transport.rs @@ -1,14 +1,21 @@ -use crate::hipcheck::{Query as PluginQuery, QueryState}; +use crate::proto::{ + InitiateQueryProtocolRequest, InitiateQueryProtocolResponse, Query as PluginQuery, QueryState, +}; use anyhow::{anyhow, Result}; +use futures::Stream; use serde_json::Value; -use std::collections::{HashMap, VecDeque}; -use tokio::sync::mpsc; -use tonic::{codec::Streaming, Status}; +use std::{ + collections::{HashMap, VecDeque}, + future::poll_fn, + ops::Not as _, + pin::Pin, +}; +use tokio::sync::mpsc::{self, error::TrySendError}; +use tonic::{Status, Streaming}; #[derive(Debug)] pub struct Query { - // if false, response - pub request: bool, + pub direction: QueryDirection, pub publisher: String, pub plugin: String, pub query: String, @@ -16,33 +23,47 @@ pub struct Query { pub output: Value, } -impl TryFrom for Query { - type Error = anyhow::Error; +#[derive(Debug, PartialEq, Eq)] +pub enum QueryDirection { + Request, + Response, +} - fn try_from(value: PluginQuery) -> Result { - use QueryState::*; +impl TryFrom for QueryDirection { + type Error = anyhow::Error; - let request = match TryInto::::try_into(value.state)? { - QueryUnspecified => return Err(anyhow!("unspecified error from plugin")), - QueryReplyInProgress => { - return Err(anyhow!( - "invalid state QueryReplyInProgress for conversion to Query" - )) + fn try_from(value: QueryState) -> std::result::Result { + match value { + QueryState::Unspecified => { + Err(anyhow!("unspecified error; query is in an invalid state")) } - QueryReplyComplete => false, - QuerySubmit => true, - }; + QueryState::Submit => Ok(QueryDirection::Request), + QueryState::ReplyInProgress => Err(anyhow!("invalid state QueryReplyInProgress")), + QueryState::ReplyComplete => Ok(QueryDirection::Response), + } + } +} + +impl From for QueryState { + fn from(value: QueryDirection) -> Self { + match value { + QueryDirection::Request => QueryState::Submit, + QueryDirection::Response => QueryState::ReplyComplete, + } + } +} - let key: Value = serde_json::from_str(value.key.as_str())?; - let output: Value = serde_json::from_str(value.output.as_str())?; +impl TryFrom for Query { + type Error = anyhow::Error; + fn try_from(value: PluginQuery) -> Result { Ok(Query { - request, + direction: QueryDirection::try_from(value.state())?, publisher: value.publisher_name, plugin: value.plugin_name, query: value.query_name, - key, - output, + key: serde_json::from_str(value.key.as_str())?, + output: serde_json::from_str(value.output.as_str())?, }) } } @@ -51,7 +72,7 @@ type SessionTracker = HashMap>>; pub struct QuerySession { id: usize, - tx: mpsc::Sender>, + tx: mpsc::Sender>, rx: mpsc::Receiver>, // So that we can remove ourselves when we get dropped drop_tx: mpsc::Sender, @@ -65,17 +86,13 @@ impl QuerySession { // Roughly equivalent to TryFrom, but the `id` field value // comes from the QuerySession fn convert(&self, value: Query) -> Result { - let state_enum = match value.request { - true => QueryState::QuerySubmit, - false => QueryState::QueryReplyComplete, - }; - + let state: QueryState = value.direction.into(); let key = serde_json::to_string(&value.key)?; let output = serde_json::to_string(&value.output)?; Ok(PluginQuery { id: self.id() as i32, - state: state_enum as i32, + state: state as i32, publisher_name: value.publisher, plugin_name: value.plugin, query_name: value.query, @@ -127,8 +144,13 @@ impl QuerySession { pub async fn send(&self, query: Query) -> Result<()> { eprintln!("RAND-session: sending query"); - let query: PluginQuery = self.convert(query)?; + + let query = InitiateQueryProtocolResponse { + query: Some(self.convert(query)?), + }; + self.tx.send(Ok(query)).await?; + Ok(()) } @@ -139,14 +161,15 @@ impl QuerySession { let Some(mut msg_chunks) = self.recv_raw().await? else { return Ok(None); }; + let mut raw = msg_chunks.pop_front().unwrap(); eprintln!("RAND-session: recv got raw {raw:?}"); let mut state: QueryState = raw.state.try_into()?; // If response is the first of a set of chunks, handle - if matches!(state, QueryReplyInProgress) { - while matches!(state, QueryReplyInProgress) { + if matches!(state, ReplyInProgress) { + while matches!(state, ReplyInProgress) { // We expect another message. Pull it off the existing queue, // or get a new one if we have run out let next = match msg_chunks.pop_front() { @@ -168,20 +191,20 @@ impl QuerySession { // By now we have our "next" message state = next.state.try_into()?; match state { - QueryUnspecified => return Err(anyhow!("unspecified error from plugin")), - QuerySubmit => { + Unspecified => return Err(anyhow!("unspecified error from plugin")), + Submit => { return Err(anyhow!( "plugin sent QuerySubmit state when reply chunk expected" )) } - QueryReplyInProgress | QueryReplyComplete => { + ReplyInProgress | ReplyComplete => { raw.output.push_str(next.output.as_str()); } }; } // Sanity check - after we've left this loop, there should be no left over message - if !msg_chunks.is_empty() { + if msg_chunks.is_empty().not() { return Err(anyhow!( "received additional messages for id '{}' after QueryComplete status message", self.id @@ -196,9 +219,6 @@ impl QuerySession { impl Drop for QuerySession { // Notify to have self removed from session tracker fn drop(&mut self) { - use mpsc::error::TrySendError; - let raw_id = self.id as i32; - while let Err(e) = self.drop_tx.try_send(self.id as i32) { match e { TrySendError::Closed(_) => { @@ -210,20 +230,37 @@ impl Drop for QuerySession { } } -#[derive(Debug)] pub struct HcSessionSocket { - tx: mpsc::Sender>, - rx: Streaming, + tx: mpsc::Sender>, + rx: Streaming, drop_tx: mpsc::Sender, drop_rx: mpsc::Receiver, sessions: SessionTracker, } +// This is implemented manually since the stream trait object +// can't impl `Debug`. +impl std::fmt::Debug for HcSessionSocket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HcSessionSocket") + .field("tx", &self.tx) + .field("rx", &"") + .field("drop_tx", &self.drop_tx) + .field("drop_rx", &self.drop_rx) + .field("sessions", &self.sessions) + .finish() + } +} + impl HcSessionSocket { - pub fn new(tx: mpsc::Sender>, rx: Streaming) -> Self { + pub fn new( + tx: mpsc::Sender>, + rx: Streaming, + ) -> Self { // channel for QuerySession objects to notify us they dropped // @Todo - make this configurable let (drop_tx, drop_rx) = mpsc::channel(10); + Self { tx, rx, @@ -233,69 +270,96 @@ impl HcSessionSocket { } } + /// Clean up completed sessions by going through all drop messages. fn cleanup_sessions(&mut self) { - // Pull off all existing drop notifications while let Ok(id) = self.drop_rx.try_recv() { - if self.sessions.remove(&id).is_none() { - eprintln!( + match self.sessions.remove(&id) { + Some(_) => eprintln!("Cleaned up session {id}"), + None => eprintln!( "WARNING: HcSessionSocket got request to drop a session that does not exist" - ); - } else { - eprintln!("Cleaned up session {id}"); + ), } } } + async fn message(&mut self) -> Result, Status> { + let fut = poll_fn(|cx| Pin::new(&mut self.rx).poll_next(cx)); + + match fut.await { + Some(Ok(m)) => Ok(m.query), + Some(Err(e)) => Err(e), + None => Ok(None), + } + } + pub async fn listen(&mut self) -> Result> { loop { eprintln!("RAND: listening"); - let Some(raw) = self.rx.message().await? else { + let Some(raw) = self.message().await? else { return Ok(None); }; + let id = raw.id; // While we were waiting for a message, some session objects may have // dropped, handle them before we look at the ID of this message. - // The downside of this strategy is that once we receive our last message, - // we won't clean up any sessions that close after + // The downside of this strategy is that once we receive our last message, + // we won't clean up any sessions that close after self.cleanup_sessions(); - let id = raw.id; + match self.decide_action(&raw) { + Ok(HandleAction::ForwardMsgToExistingSession(tx)) => { + eprintln!("RAND-listen: forwarding message to session {id}"); - // If there is already a session with this ID, forward msg - if let Some(tx) = self.sessions.get_mut(&id) { - eprintln!("RAND-listen: forwarding message to session {id}"); + if let Err(_e) = tx.send(Some(raw)).await { + eprintln!("Error forwarding msg to session {id}"); + self.sessions.remove(&id); + }; + } + Ok(HandleAction::CreateSession) => { + eprintln!("RAND-listen: creating new session {id}"); - if let Err(e) = tx.send(Some(raw)).await { - eprintln!("Error forwarding msg to session {id}"); - self.sessions.remove(&id); - }; - // If got a new query ID, create session - } else if raw.state() == QueryState::QuerySubmit { - eprintln!("RAND-listen: creating new session {id}"); - - let (in_tx, rx) = mpsc::channel::>(10); - let tx = self.tx.clone(); - - let session = QuerySession { - id: id as usize, - tx, - rx, - drop_tx: self.drop_tx.clone(), - }; + let (in_tx, rx) = mpsc::channel::>(10); + let tx = self.tx.clone(); + + let session = QuerySession { + id: id as usize, + tx, + rx, + drop_tx: self.drop_tx.clone(), + }; - in_tx - .send(Some(raw)) - .await - .expect("Failed sending message to newly created Session, should never happen"); + in_tx.send(Some(raw)).await.expect( + "Failed sending message to newly created Session, should never happen", + ); - eprintln!("RAND-listen: adding new session {id} to tracker"); - self.sessions.insert(id, in_tx); + eprintln!("RAND-listen: adding new session {id} to tracker"); + self.sessions.insert(id, in_tx); - return Ok(Some(session)); - } else { - eprintln!("Got query with id {}, does not match existing session and is not new QuerySubmit", id); + return Ok(Some(session)); + } + Err(e) => eprintln!("error: {}", e), } } } + + fn decide_action(&mut self, query: &PluginQuery) -> Result> { + if let Some(tx) = self.sessions.get_mut(&query.id) { + return Ok(HandleAction::ForwardMsgToExistingSession(tx)); + } + + if query.state() == QueryState::Submit { + return Ok(HandleAction::CreateSession); + } + + Err(anyhow!( + "Got query with id {}, does not match existing session and is not new QuerySubmit", + query.id + )) + } +} + +enum HandleAction<'s> { + ForwardMsgToExistingSession(&'s mut mpsc::Sender>), + CreateSession, } diff --git a/plugins/dummy_sha256/Cargo.toml b/plugins/dummy_sha256/Cargo.toml index 409cb015..56aef2ec 100644 --- a/plugins/dummy_sha256/Cargo.toml +++ b/plugins/dummy_sha256/Cargo.toml @@ -8,6 +8,7 @@ publish = false anyhow = "1.0.86" clap = { version = "4.5.16", features = ["derive"] } indexmap = "2.4.0" +futures = "0.3.30" prost = "0.13.1" rand = "0.8.5" serde_json = "1.0.125" @@ -15,3 +16,7 @@ sha2 = "0.10.8" tokio = { version = "1.39.2", features = ["rt"] } tokio-stream = "0.1.15" tonic = "0.12.1" + +[build-dependencies] +anyhow = "1.0.86" +tonic-build = "0.12.1" diff --git a/plugins/dummy_sha256/build.rs b/plugins/dummy_sha256/build.rs new file mode 100644 index 00000000..759819ea --- /dev/null +++ b/plugins/dummy_sha256/build.rs @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: Apache-2.0 + +fn main() -> anyhow::Result<()> { + tonic_build::compile_protos("../../proto/hipcheck/v1/hipcheck.proto")?; + Ok(()) +} diff --git a/plugins/dummy_sha256/src/query_schema_sha256.json b/plugins/dummy_sha256/schema/query_schema_sha256.json similarity index 100% rename from plugins/dummy_sha256/src/query_schema_sha256.json rename to plugins/dummy_sha256/schema/query_schema_sha256.json diff --git a/plugins/dummy_sha256/src/hipcheck.rs b/plugins/dummy_sha256/src/hipcheck.rs deleted file mode 100644 index 50ce6bf2..00000000 --- a/plugins/dummy_sha256/src/hipcheck.rs +++ /dev/null @@ -1,679 +0,0 @@ -#![allow(clippy::enum_variant_names)] - -// This file is @generated by prost-build. -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Configuration { - /// JSON string containing configuration data expected by the plugin, - /// pulled from the user's policy file. - #[prost(string, tag = "1")] - pub configuration: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ConfigurationResult { - /// The status of the configuration call. - #[prost(enumeration = "ConfigurationStatus", tag = "1")] - pub status: i32, - /// An optional error message, if there was an error. - #[prost(string, tag = "2")] - pub message: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PolicyExpression { - /// A policy expression, if the plugin has a default policy. - /// This MUST be filled in with any default values pulled from the plugin's - /// configuration. Hipcheck will only request the default policy _after_ - /// configuring the plugin. - #[prost(string, tag = "1")] - pub policy_expression: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Schema { - /// The name of the query being described by the schemas provided. - /// - /// If either the key and/or output schemas result in a message which is - /// too big, they may be chunked across multiple replies in the stream. - /// Replies with matching query names should have their fields concatenated - /// in the order received to reconstruct the chunks. - #[prost(string, tag = "1")] - pub query_name: ::prost::alloc::string::String, - /// The key schema, in JSON Schema format. - #[prost(string, tag = "2")] - pub key_schema: ::prost::alloc::string::String, - /// The output schema, in JSON Schema format. - #[prost(string, tag = "3")] - pub output_schema: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Query { - /// The ID of the request, used to associate requests and replies. - /// Odd numbers = initiated by `hc`. - /// Even numbers = initiated by a plugin. - #[prost(int32, tag = "1")] - pub id: i32, - /// The state of the query, indicating if this is a request or a reply, - /// and if it's a reply whether it's the end of the reply. - #[prost(enumeration = "QueryState", tag = "2")] - pub state: i32, - /// Publisher name and plugin name, when sent from Hipcheck to a plugin - /// to initiate a fresh query, are used by the receiving plugin to validate - /// that the query was intended for them. - /// - /// When a plugin is making a query to another plugin through Hipcheck, it's - /// used to indicate the destination plugin, and to indicate the plugin that - /// is replying when Hipcheck sends back the reply. - #[prost(string, tag = "3")] - pub publisher_name: ::prost::alloc::string::String, - #[prost(string, tag = "4")] - pub plugin_name: ::prost::alloc::string::String, - /// The name of the query being made, so the responding plugin knows what - /// to do with the provided data. - #[prost(string, tag = "5")] - pub query_name: ::prost::alloc::string::String, - /// The key for the query, as a JSON object. This is the data that Hipcheck's - /// incremental computation system will use to cache the response. - #[prost(string, tag = "6")] - pub key: ::prost::alloc::string::String, - /// The response for the query, as a JSON object. This will be cached by - /// Hipcheck for future queries matching the publisher name, plugin name, - /// query name, and key. - #[prost(string, tag = "7")] - pub output: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, Copy, PartialEq, ::prost::Message)] -pub struct Empty {} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum ConfigurationStatus { - /// An unknown error occured. - ErrorUnknown = 0, - /// No error; the operation was successful. - ErrorNone = 1, - /// The user failed to provide a required configuration item. - ErrorMissingRequiredConfiguration = 2, - /// The user provided a configuration item whose name was not recognized. - ErrorUnrecognizedConfiguration = 3, - /// The user provided a configuration item whose value is invalid. - ErrorInvalidConfigurationValue = 4, -} -impl ConfigurationStatus { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - ConfigurationStatus::ErrorUnknown => "ERROR_UNKNOWN", - ConfigurationStatus::ErrorNone => "ERROR_NONE", - ConfigurationStatus::ErrorMissingRequiredConfiguration => { - "ERROR_MISSING_REQUIRED_CONFIGURATION" - } - ConfigurationStatus::ErrorUnrecognizedConfiguration => { - "ERROR_UNRECOGNIZED_CONFIGURATION" - } - ConfigurationStatus::ErrorInvalidConfigurationValue => { - "ERROR_INVALID_CONFIGURATION_VALUE" - } - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "ERROR_UNKNOWN" => Some(Self::ErrorUnknown), - "ERROR_NONE" => Some(Self::ErrorNone), - "ERROR_MISSING_REQUIRED_CONFIGURATION" => Some(Self::ErrorMissingRequiredConfiguration), - "ERROR_UNRECOGNIZED_CONFIGURATION" => Some(Self::ErrorUnrecognizedConfiguration), - "ERROR_INVALID_CONFIGURATION_VALUE" => Some(Self::ErrorInvalidConfigurationValue), - _ => None, - } - } -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum QueryState { - /// Something has gone wrong. - QueryUnspecified = 0, - /// We are submitting a new query. - QuerySubmit = 1, - /// We are replying to a query and expect more chunks. - QueryReplyInProgress = 2, - /// We are closing a reply to a query. If a query response is in one chunk, - /// just send this. If a query is in more than one chunk, send this with - /// the last message in the reply. This tells the receiver that all chunks - /// have been received. - QueryReplyComplete = 3, -} -impl QueryState { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - QueryState::QueryUnspecified => "QUERY_UNSPECIFIED", - QueryState::QuerySubmit => "QUERY_SUBMIT", - QueryState::QueryReplyInProgress => "QUERY_REPLY_IN_PROGRESS", - QueryState::QueryReplyComplete => "QUERY_REPLY_COMPLETE", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "QUERY_UNSPECIFIED" => Some(Self::QueryUnspecified), - "QUERY_SUBMIT" => Some(Self::QuerySubmit), - "QUERY_REPLY_IN_PROGRESS" => Some(Self::QueryReplyInProgress), - "QUERY_REPLY_COMPLETE" => Some(Self::QueryReplyComplete), - _ => None, - } - } -} -/// Generated client implementations. -pub mod plugin_client { - #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] - use tonic::codegen::http::Uri; - use tonic::codegen::*; - #[derive(Debug, Clone)] - pub struct PluginClient { - inner: tonic::client::Grpc, - } - impl PluginClient { - /// Attempt to create a new client by connecting to a given endpoint. - pub async fn connect(dst: D) -> Result - where - D: TryInto, - D::Error: Into, - { - let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; - Ok(Self::new(conn)) - } - } - impl PluginClient - where - T: tonic::client::GrpcService, - T::Error: Into, - T::ResponseBody: Body + Send + 'static, - ::Error: Into + Send, - { - pub fn new(inner: T) -> Self { - let inner = tonic::client::Grpc::new(inner); - Self { inner } - } - pub fn with_origin(inner: T, origin: Uri) -> Self { - let inner = tonic::client::Grpc::with_origin(inner, origin); - Self { inner } - } - pub fn with_interceptor( - inner: T, - interceptor: F, - ) -> PluginClient> - where - F: tonic::service::Interceptor, - T::ResponseBody: Default, - T: tonic::codegen::Service< - http::Request, - Response = http::Response< - >::ResponseBody, - >, - >, - >>::Error: - Into + Send + Sync, - { - PluginClient::new(InterceptedService::new(inner, interceptor)) - } - /// Compress requests with the given encoding. - /// - /// This requires the server to support it otherwise it might respond with an - /// error. - #[must_use] - pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.inner = self.inner.send_compressed(encoding); - self - } - /// Enable decompressing responses. - #[must_use] - pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.inner = self.inner.accept_compressed(encoding); - self - } - /// Limits the maximum size of a decoded message. - /// - /// Default: `4MB` - #[must_use] - pub fn max_decoding_message_size(mut self, limit: usize) -> Self { - self.inner = self.inner.max_decoding_message_size(limit); - self - } - /// Limits the maximum size of an encoded message. - /// - /// Default: `usize::MAX` - #[must_use] - pub fn max_encoding_message_size(mut self, limit: usize) -> Self { - self.inner = self.inner.max_encoding_message_size(limit); - self - } - /// * - /// Get schemas for all supported queries by the plugin. - /// - /// This is used by Hipcheck to validate that: - /// - /// - The plugin supports a default query taking a `target` type if used - /// as a top-level plugin in the user's policy file. - /// - That requests sent to the plugin and data returned by the plugin - /// match the schema during execution. - pub async fn get_query_schemas( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result< - tonic::Response>, - tonic::Status, - > { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static("/hipcheck.Plugin/GetQuerySchemas"); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("hipcheck.Plugin", "GetQuerySchemas")); - self.inner.server_streaming(req, path, codec).await - } - /// * - /// Hipcheck sends all child nodes for the plugin from the user's policy - /// file to configure the plugin. - pub async fn set_configuration( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = http::uri::PathAndQuery::from_static("/hipcheck.Plugin/SetConfiguration"); - let mut req = request.into_request(); - req.extensions_mut() - .insert(GrpcMethod::new("hipcheck.Plugin", "SetConfiguration")); - self.inner.unary(req, path, codec).await - } - /// * - /// Get the default policy for a plugin, which may additionally depend on - /// the plugin's configuration. - pub async fn get_default_policy_expression( - &mut self, - request: impl tonic::IntoRequest, - ) -> std::result::Result, tonic::Status> { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = - http::uri::PathAndQuery::from_static("/hipcheck.Plugin/GetDefaultPolicyExpression"); - let mut req = request.into_request(); - req.extensions_mut().insert(GrpcMethod::new( - "hipcheck.Plugin", - "GetDefaultPolicyExpression", - )); - self.inner.unary(req, path, codec).await - } - /// * - /// Open a bidirectional streaming RPC to enable a request/response - /// protocol between Hipcheck and a plugin, where Hipcheck can issue - /// queries to the plugin, and the plugin may issue queries to _other_ - /// plugins through Hipcheck. - /// - /// Queries are cached by the publisher name, plugin name, query name, - /// and key, and if a match is found for those four values, then - /// Hipcheck will respond with the cached result of that prior matching - /// query rather than running the query again. - pub async fn initiate_query_protocol( - &mut self, - request: impl tonic::IntoStreamingRequest, - ) -> std::result::Result< - tonic::Response>, - tonic::Status, - > { - self.inner.ready().await.map_err(|e| { - tonic::Status::new( - tonic::Code::Unknown, - format!("Service was not ready: {}", e.into()), - ) - })?; - let codec = tonic::codec::ProstCodec::default(); - let path = - http::uri::PathAndQuery::from_static("/hipcheck.Plugin/InitiateQueryProtocol"); - let mut req = request.into_streaming_request(); - req.extensions_mut() - .insert(GrpcMethod::new("hipcheck.Plugin", "InitiateQueryProtocol")); - self.inner.streaming(req, path, codec).await - } - } -} -/// Generated server implementations. -pub mod plugin_server { - #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] - use tonic::codegen::*; - /// Generated trait containing gRPC methods that should be implemented for use with PluginServer. - #[async_trait] - pub trait Plugin: Send + Sync + 'static { - /// Server streaming response type for the GetQuerySchemas method. - type GetQuerySchemasStream: tonic::codegen::tokio_stream::Stream< - Item = std::result::Result, - > + Send - + 'static; - /// * - /// Get schemas for all supported queries by the plugin. - /// - /// This is used by Hipcheck to validate that: - /// - /// - The plugin supports a default query taking a `target` type if used - /// as a top-level plugin in the user's policy file. - /// - That requests sent to the plugin and data returned by the plugin - /// match the schema during execution. - async fn get_query_schemas( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// * - /// Hipcheck sends all child nodes for the plugin from the user's policy - /// file to configure the plugin. - async fn set_configuration( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// * - /// Get the default policy for a plugin, which may additionally depend on - /// the plugin's configuration. - async fn get_default_policy_expression( - &self, - request: tonic::Request, - ) -> std::result::Result, tonic::Status>; - /// Server streaming response type for the InitiateQueryProtocol method. - type InitiateQueryProtocolStream: tonic::codegen::tokio_stream::Stream< - Item = std::result::Result, - > + Send - + 'static; - /// * - /// Open a bidirectional streaming RPC to enable a request/response - /// protocol between Hipcheck and a plugin, where Hipcheck can issue - /// queries to the plugin, and the plugin may issue queries to _other_ - /// plugins through Hipcheck. - /// - /// Queries are cached by the publisher name, plugin name, query name, - /// and key, and if a match is found for those four values, then - /// Hipcheck will respond with the cached result of that prior matching - /// query rather than running the query again. - async fn initiate_query_protocol( - &self, - request: tonic::Request>, - ) -> std::result::Result, tonic::Status>; - } - #[derive(Debug)] - pub struct PluginServer { - inner: Arc, - accept_compression_encodings: EnabledCompressionEncodings, - send_compression_encodings: EnabledCompressionEncodings, - max_decoding_message_size: Option, - max_encoding_message_size: Option, - } - impl PluginServer { - pub fn new(inner: T) -> Self { - Self::from_arc(Arc::new(inner)) - } - pub fn from_arc(inner: Arc) -> Self { - Self { - inner, - accept_compression_encodings: Default::default(), - send_compression_encodings: Default::default(), - max_decoding_message_size: None, - max_encoding_message_size: None, - } - } - pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService - where - F: tonic::service::Interceptor, - { - InterceptedService::new(Self::new(inner), interceptor) - } - /// Enable decompressing requests with the given encoding. - #[must_use] - pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.accept_compression_encodings.enable(encoding); - self - } - /// Compress responses with the given encoding, if the client supports it. - #[must_use] - pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { - self.send_compression_encodings.enable(encoding); - self - } - /// Limits the maximum size of a decoded message. - /// - /// Default: `4MB` - #[must_use] - pub fn max_decoding_message_size(mut self, limit: usize) -> Self { - self.max_decoding_message_size = Some(limit); - self - } - /// Limits the maximum size of an encoded message. - /// - /// Default: `usize::MAX` - #[must_use] - pub fn max_encoding_message_size(mut self, limit: usize) -> Self { - self.max_encoding_message_size = Some(limit); - self - } - } - impl tonic::codegen::Service> for PluginServer - where - T: Plugin, - B: Body + Send + 'static, - B::Error: Into + Send + 'static, - { - type Response = http::Response; - type Error = std::convert::Infallible; - type Future = BoxFuture; - fn poll_ready( - &mut self, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - fn call(&mut self, req: http::Request) -> Self::Future { - match req.uri().path() { - "/hipcheck.Plugin/GetQuerySchemas" => { - #[allow(non_camel_case_types)] - struct GetQuerySchemasSvc(pub Arc); - impl tonic::server::ServerStreamingService for GetQuerySchemasSvc { - type Response = super::Schema; - type ResponseStream = T::GetQuerySchemasStream; - type Future = - BoxFuture, tonic::Status>; - fn call(&mut self, request: tonic::Request) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::get_query_schemas(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = GetQuerySchemasSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.server_streaming(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/hipcheck.Plugin/SetConfiguration" => { - #[allow(non_camel_case_types)] - struct SetConfigurationSvc(pub Arc); - impl tonic::server::UnaryService for SetConfigurationSvc { - type Response = super::ConfigurationResult; - type Future = BoxFuture, tonic::Status>; - fn call( - &mut self, - request: tonic::Request, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::set_configuration(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = SetConfigurationSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/hipcheck.Plugin/GetDefaultPolicyExpression" => { - #[allow(non_camel_case_types)] - struct GetDefaultPolicyExpressionSvc(pub Arc); - impl tonic::server::UnaryService for GetDefaultPolicyExpressionSvc { - type Response = super::PolicyExpression; - type Future = BoxFuture, tonic::Status>; - fn call(&mut self, request: tonic::Request) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::get_default_policy_expression(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = GetDefaultPolicyExpressionSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.unary(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - "/hipcheck.Plugin/InitiateQueryProtocol" => { - #[allow(non_camel_case_types)] - struct InitiateQueryProtocolSvc(pub Arc); - impl tonic::server::StreamingService for InitiateQueryProtocolSvc { - type Response = super::Query; - type ResponseStream = T::InitiateQueryProtocolStream; - type Future = - BoxFuture, tonic::Status>; - fn call( - &mut self, - request: tonic::Request>, - ) -> Self::Future { - let inner = Arc::clone(&self.0); - let fut = async move { - ::initiate_query_protocol(&inner, request).await - }; - Box::pin(fut) - } - } - let accept_compression_encodings = self.accept_compression_encodings; - let send_compression_encodings = self.send_compression_encodings; - let max_decoding_message_size = self.max_decoding_message_size; - let max_encoding_message_size = self.max_encoding_message_size; - let inner = self.inner.clone(); - let fut = async move { - let method = InitiateQueryProtocolSvc(inner); - let codec = tonic::codec::ProstCodec::default(); - let mut grpc = tonic::server::Grpc::new(codec) - .apply_compression_config( - accept_compression_encodings, - send_compression_encodings, - ) - .apply_max_message_size_config( - max_decoding_message_size, - max_encoding_message_size, - ); - let res = grpc.streaming(method, req).await; - Ok(res) - }; - Box::pin(fut) - } - _ => Box::pin(async move { - Ok(http::Response::builder() - .status(200) - .header("grpc-status", tonic::Code::Unimplemented as i32) - .header( - http::header::CONTENT_TYPE, - tonic::metadata::GRPC_CONTENT_TYPE, - ) - .body(empty_body()) - .unwrap()) - }), - } - } - } - impl Clone for PluginServer { - fn clone(&self) -> Self { - let inner = self.inner.clone(); - Self { - inner, - accept_compression_encodings: self.accept_compression_encodings, - send_compression_encodings: self.send_compression_encodings, - max_decoding_message_size: self.max_decoding_message_size, - max_encoding_message_size: self.max_encoding_message_size, - } - } - } - impl tonic::server::NamedService for PluginServer { - const NAME: &'static str = "hipcheck.Plugin"; - } -} diff --git a/plugins/dummy_sha256/src/hipcheck_transport.rs b/plugins/dummy_sha256/src/hipcheck_transport.rs deleted file mode 100644 index 5d50c0f4..00000000 --- a/plugins/dummy_sha256/src/hipcheck_transport.rs +++ /dev/null @@ -1,226 +0,0 @@ -use crate::hipcheck::{Query as PluginQuery, QueryState}; -use anyhow::{anyhow, Result}; -use indexmap::map::IndexMap; -use serde_json::Value; -use std::collections::VecDeque; -use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; -use tonic::{codec::Streaming, Status}; - -#[derive(Debug)] -pub struct Query { - pub id: usize, - // if false, response - pub request: bool, - pub publisher: String, - pub plugin: String, - pub query: String, - pub key: Value, - pub output: Value, -} -impl TryFrom for Query { - type Error = anyhow::Error; - fn try_from(value: PluginQuery) -> Result { - use QueryState::*; - let request = match TryInto::::try_into(value.state)? { - QueryUnspecified => return Err(anyhow!("unspecified error from plugin")), - QueryReplyInProgress => { - return Err(anyhow!( - "invalid state QueryReplyInProgress for conversion to Query" - )) - } - QueryReplyComplete => false, - QuerySubmit => true, - }; - let key: Value = serde_json::from_str(value.key.as_str())?; - let output: Value = serde_json::from_str(value.output.as_str())?; - Ok(Query { - id: value.id as usize, - request, - publisher: value.publisher_name, - plugin: value.plugin_name, - query: value.query_name, - key, - output, - }) - } -} -impl TryFrom for PluginQuery { - type Error = anyhow::Error; - fn try_from(value: Query) -> Result { - let state_enum = match value.request { - true => QueryState::QuerySubmit, - false => QueryState::QueryReplyComplete, - }; - let key = serde_json::to_string(&value.key)?; - let output = serde_json::to_string(&value.output)?; - Ok(PluginQuery { - id: value.id as i32, - state: state_enum as i32, - publisher_name: value.publisher, - plugin_name: value.plugin, - query_name: value.query, - key, - output, - }) - } -} - -#[derive(Clone, Debug)] -pub struct HcTransport { - tx: mpsc::Sender>, - rx: Arc>, -} -impl HcTransport { - pub fn new(rx: Streaming, tx: mpsc::Sender>) -> Self { - HcTransport { - rx: Arc::new(Mutex::new(MultiplexedQueryReceiver::new(rx))), - tx, - } - } - pub async fn send(&self, query: Query) -> Result<()> { - let query: PluginQuery = query.try_into()?; - self.tx.send(Ok(query)).await?; - Ok(()) - } - pub async fn recv_new(&self) -> Result> { - let mut rx_handle = self.rx.lock().await; - match rx_handle.recv_new().await? { - Some(msg) => msg.try_into().map(Some), - None => Ok(None), - } - } - pub async fn recv(&self, id: usize) -> Result> { - use QueryState::*; - let id = id as i32; - let mut rx_handle = self.rx.lock().await; - let Some(mut msg_chunks) = rx_handle.recv(id).await? else { - return Ok(None); - }; - drop(rx_handle); - let mut raw = msg_chunks.pop_front().unwrap(); - let mut state: QueryState = raw.state.try_into()?; - - // If response is the first of a set of chunks, handle - if matches!(state, QueryReplyInProgress) { - while matches!(state, QueryReplyInProgress) { - // We expect another message. Pull it off the existing queue, - // or get a new one if we have run out - let next = match msg_chunks.pop_front() { - Some(msg) => msg, - None => { - // We ran out of messages, get a new batch - let mut rx_handle = self.rx.lock().await; - match rx_handle.recv(id).await? { - Some(x) => { - drop(rx_handle); - msg_chunks = x; - } - None => { - return Ok(None); - } - }; - msg_chunks.pop_front().unwrap() - } - }; - // By now we have our "next" message - state = next.state.try_into()?; - match state { - QueryUnspecified => return Err(anyhow!("unspecified error from plugin")), - QuerySubmit => { - return Err(anyhow!( - "plugin sent QuerySubmit state when reply chunk expected" - )) - } - QueryReplyInProgress | QueryReplyComplete => { - raw.output.push_str(next.output.as_str()); - } - }; - } - // Sanity check - after we've left this loop, there should be no left over message - if !msg_chunks.is_empty() { - return Err(anyhow!( - "received additional messages for id '{}' after QueryComplete status message", - id - )); - } - } - raw.try_into().map(Some) - } -} - -#[derive(Debug)] -pub struct MultiplexedQueryReceiver { - rx: Streaming, - // Unlike in HipCheck, backlog is an IndexMap to ensure the earliest received - // requests are handled first - backlog: IndexMap>, -} -impl MultiplexedQueryReceiver { - pub fn new(rx: Streaming) -> Self { - Self { - rx, - backlog: IndexMap::new(), - } - } - pub async fn recv_new(&mut self) -> Result> { - let opt_unhandled = self.backlog.iter().find(|(k, v)| { - if let Some(req) = v.front() { - return req.state() == QueryState::QuerySubmit; - } - false - }); - if let Some((k, v)) = opt_unhandled { - let id: i32 = *k; - let mut vec = self.backlog.shift_remove(&id).unwrap(); - // @Note - for now QuerySubmit doesn't chunk so we shouldn't expect - // multiple messages in the backlog for a new request - assert!(vec.len() == 1); - return Ok(vec.pop_front()); - } - // No backlog message, need to operate the receiver - loop { - let Some(raw) = self.rx.message().await? else { - // gRPC channel was closed - return Ok(None); - }; - if raw.state() == QueryState::QuerySubmit { - return Ok(Some(raw)); - } - match self.backlog.get_mut(&raw.id) { - Some(vec) => { - vec.push_back(raw); - } - None => { - self.backlog.insert(raw.id, VecDeque::from([raw])); - } - } - } - } - // @Invariant - this function will never return an empty VecDeque - pub async fn recv(&mut self, id: i32) -> Result>> { - // If we have 1+ messages on backlog for `id`, return them all, - // no need to waste time with successive calls - if let Some(msgs) = self.backlog.shift_remove(&id) { - return Ok(Some(msgs)); - } - // No backlog message, need to operate the receiver - loop { - let Some(raw) = self.rx.message().await? else { - // gRPC channel was closed - return Ok(None); - }; - if raw.id == id { - return Ok(Some(VecDeque::from([raw]))); - } - match self.backlog.get_mut(&raw.id) { - Some(vec) => { - vec.push_back(raw); - } - None => { - self.backlog.insert(raw.id, VecDeque::from([raw])); - } - } - } - } -} diff --git a/plugins/dummy_sha256/src/main.rs b/plugins/dummy_sha256/src/main.rs index 88d87d48..0ea65fda 100644 --- a/plugins/dummy_sha256/src/main.rs +++ b/plugins/dummy_sha256/src/main.rs @@ -1,15 +1,21 @@ -#![allow(unused_variables)] - -mod hipcheck; -mod hipcheck_transport; +mod transport; +mod proto { + include!(concat!(env!("OUT_DIR"), "/hipcheck.v1.rs")); +} -use crate::hipcheck_transport::*; +use crate::{ + proto::{ + plugin_service_server::{PluginService, PluginServiceServer}, + ConfigurationStatus, GetDefaultPolicyExpressionRequest, GetQuerySchemasRequest, + SetConfigurationRequest, SetConfigurationResponse, + }, + transport::*, +}; use anyhow::{anyhow, Result}; use clap::Parser; -use hipcheck::plugin_server::{Plugin, PluginServer}; -use hipcheck::{ - Configuration, ConfigurationResult, ConfigurationStatus, Empty, PolicyExpression, - Query as PluginQuery, Schema, +use proto::{ + GetDefaultPolicyExpressionResponse, GetQuerySchemasResponse, InitiateQueryProtocolRequest, + InitiateQueryProtocolResponse, }; use serde_json::{json, Value}; use sha2::{Digest, Sha256}; @@ -18,145 +24,165 @@ use tokio::sync::mpsc; use tokio_stream::{wrappers::ReceiverStream, Stream}; use tonic::{transport::Server, Request, Response, Status, Streaming}; -static SHA256_KEY_SCHEMA: &str = include_str!("query_schema_sha256.json"); -static SHA256_OUTPUT_SCHEMA: &str = include_str!("query_schema_sha256.json"); +static SHA256_KEY_SCHEMA: &str = include_str!("../schema/query_schema_sha256.json"); +static SHA256_OUTPUT_SCHEMA: &str = include_str!("../schema/query_schema_sha256.json"); -fn sha256(content: Vec) -> Vec { +fn sha256(content: &[u8]) -> Vec { let mut hasher = Sha256::new(); hasher.update(content); hasher.finalize().to_vec() } -pub async fn handle_sha256(channel: HcTransport, id: usize, key: Vec) -> Result<()> { - println!("SHA256-{id}: Key: {key:02x?}"); +async fn handle_sha256(session: QuerySession, key: &[u8]) -> Result<()> { + println!("Key: {key:02x?}"); let res = sha256(key); - println!("SHA256-{id}: Hash: {res:02x?}"); + + println!("Hash: {res:02x?}"); let output = serde_json::to_value(res)?; + let resp = Query { - id, - request: false, + direction: QueryDirection::Response, publisher: "".to_owned(), plugin: "".to_owned(), query: "".to_owned(), key: json!(null), output, }; - channel.send(resp).await?; + + session.send(resp).await?; + + Ok(()) +} + +async fn handle_session(mut session: QuerySession) -> Result<()> { + let Some(query) = session.recv().await? else { + eprintln!("session closed by remote"); + return Ok(()); + }; + + if query.direction == QueryDirection::Response { + return Err(anyhow!("Expected request from remote")); + } + + let name = query.query; + let key = query.key; + + if name != "sha256" { + return Err(anyhow!("unrecognized query '{}'", name)); + } + + let Value::Array(data) = &key else { + return Err(anyhow!("get_sha256 argument must be an array")); + }; + + let data = data + .iter() + .map(|elem| elem.as_u64().map(|num| num as u8)) + .collect::>>() + .ok_or_else(|| anyhow!("non-numeric data in get_sha256 array argument"))?; + + handle_sha256(session, &data[..]).await?; + Ok(()) } + struct Sha256Runner { - channel: HcTransport, + channel: HcSessionSocket, } + impl Sha256Runner { - pub fn new(channel: HcTransport) -> Self { + pub fn new(channel: HcSessionSocket) -> Self { Sha256Runner { channel } } - async fn handle_query(channel: HcTransport, id: usize, name: String, key: Value) -> Result<()> { - if name == "sha256" { - let Value::Array(val_vec) = &key else { - return Err(anyhow!("get_rand argument must be a number")); - }; - let byte_vec = val_vec - .iter() - .map(|x| { - let Value::Number(val_byte) = x else { - return Err(anyhow!("expected all integers")); - }; - let Some(byte) = val_byte.as_u64() else { - return Err(anyhow!( - "sha256 input array must contain only unsigned integers" - )); - }; - Ok(byte as u8) - }) - .collect::>>()?; - handle_sha256(channel, id, byte_vec).await?; - Ok(()) - } else { - Err(anyhow!("unrecognized query '{}'", name)) - } - } - pub async fn run(self) -> Result<()> { + + pub async fn run(mut self) -> Result<()> { loop { eprintln!("SHA256: Looping"); - let Some(msg) = self.channel.recv_new().await? else { + + let Some(session) = self.channel.listen().await? else { eprintln!("Channel closed by remote"); break; }; - if msg.request { - let child_channel = self.channel.clone(); - tokio::spawn(async move { - if let Err(e) = - Sha256Runner::handle_query(child_channel, msg.id, msg.query, msg.key).await - { - eprintln!("handle_query failed: {e}"); - }; - }); - } else { - return Err(anyhow!("Did not expect a response-type message here")); - } + + tokio::spawn(async move { + if let Err(e) = handle_session(session).await { + eprintln!("handle_session failed: {e}"); + }; + }); } + Ok(()) } } #[derive(Debug)] -struct RandDataPlugin { - pub schema: Schema, +struct Sha256Plugin { + schema: GetQuerySchemasResponse, } -impl RandDataPlugin { - pub fn new() -> Self { - let schema = Schema { - query_name: "sha256".to_owned(), - key_schema: SHA256_KEY_SCHEMA.to_owned(), - output_schema: SHA256_OUTPUT_SCHEMA.to_owned(), - }; - RandDataPlugin { schema } + +impl Sha256Plugin { + fn new() -> Self { + Sha256Plugin { + schema: GetQuerySchemasResponse { + query_name: "sha256".to_owned(), + key_schema: SHA256_KEY_SCHEMA.to_owned(), + output_schema: SHA256_OUTPUT_SCHEMA.to_owned(), + }, + } } } #[tonic::async_trait] -impl Plugin for RandDataPlugin { +impl PluginService for Sha256Plugin { type GetQuerySchemasStream = - Pin> + Send + 'static>>; - type InitiateQueryProtocolStream = ReceiverStream>; + Pin> + Send + 'static>>; + + type InitiateQueryProtocolStream = + ReceiverStream>; + async fn get_query_schemas( &self, - _request: Request, + _request: Request, ) -> Result, Status> { Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok(self .schema .clone())])))) } + async fn set_configuration( &self, - request: Request, - ) -> Result, Status> { - Ok(Response::new(ConfigurationResult { - status: ConfigurationStatus::ErrorNone as i32, + _request: Request, + ) -> Result, Status> { + Ok(Response::new(SetConfigurationResponse { + status: ConfigurationStatus::None as i32, message: "".to_owned(), })) } + async fn get_default_policy_expression( &self, - request: Request, - ) -> Result, Status> { - Ok(Response::new(PolicyExpression { + _request: Request, + ) -> Result, Status> { + Ok(Response::new(GetDefaultPolicyExpressionResponse { policy_expression: "".to_owned(), })) } + async fn initiate_query_protocol( &self, - request: Request>, + request: Request>, ) -> Result, Status> { let rx = request.into_inner(); - let (tx, out_rx) = mpsc::channel::>(4); + let (tx, out_rx) = mpsc::channel::>(4); + tokio::spawn(async move { - let channel = HcTransport::new(rx, tx); + let channel = HcSessionSocket::new(tx, rx); + if let Err(e) = Sha256Runner::new(channel).run().await { eprintln!("sha256 plugin ended in error: {e}"); } }); + Ok(Response::new(ReceiverStream::new(out_rx))) } } @@ -170,12 +196,11 @@ struct Args { #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { let args = Args::try_parse().map_err(Box::new)?; - let addr = format!("127.0.0.1:{}", args.port); - let plugin = RandDataPlugin::new(); - let svc = PluginServer::new(plugin); - Server::builder() - .add_service(svc) - .serve(addr.parse().unwrap()) - .await?; + + let service = PluginServiceServer::new(Sha256Plugin::new()); + let host = format!("127.0.0.1:{}", args.port).parse().unwrap(); + + Server::builder().add_service(service).serve(host).await?; + Ok(()) } diff --git a/plugins/dummy_sha256/src/transport.rs b/plugins/dummy_sha256/src/transport.rs new file mode 100644 index 00000000..7a26145f --- /dev/null +++ b/plugins/dummy_sha256/src/transport.rs @@ -0,0 +1,365 @@ +use crate::proto::{ + InitiateQueryProtocolRequest, InitiateQueryProtocolResponse, Query as PluginQuery, QueryState, +}; +use anyhow::{anyhow, Result}; +use futures::Stream; +use serde_json::Value; +use std::{ + collections::{HashMap, VecDeque}, + future::poll_fn, + ops::Not as _, + pin::Pin, +}; +use tokio::sync::mpsc::{self, error::TrySendError}; +use tonic::{Status, Streaming}; + +#[derive(Debug)] +pub struct Query { + pub direction: QueryDirection, + pub publisher: String, + pub plugin: String, + pub query: String, + pub key: Value, + pub output: Value, +} + +#[derive(Debug, PartialEq, Eq)] +pub enum QueryDirection { + Request, + Response, +} + +impl TryFrom for QueryDirection { + type Error = anyhow::Error; + + fn try_from(value: QueryState) -> std::result::Result { + match value { + QueryState::Unspecified => { + Err(anyhow!("unspecified error; query is in an invalid state")) + } + QueryState::Submit => Ok(QueryDirection::Request), + QueryState::ReplyInProgress => Err(anyhow!("invalid state QueryReplyInProgress")), + QueryState::ReplyComplete => Ok(QueryDirection::Response), + } + } +} + +impl From for QueryState { + fn from(value: QueryDirection) -> Self { + match value { + QueryDirection::Request => QueryState::Submit, + QueryDirection::Response => QueryState::ReplyComplete, + } + } +} + +impl TryFrom for Query { + type Error = anyhow::Error; + + fn try_from(value: PluginQuery) -> Result { + Ok(Query { + direction: QueryDirection::try_from(value.state())?, + publisher: value.publisher_name, + plugin: value.plugin_name, + query: value.query_name, + key: serde_json::from_str(value.key.as_str())?, + output: serde_json::from_str(value.output.as_str())?, + }) + } +} + +type SessionTracker = HashMap>>; + +pub struct QuerySession { + id: usize, + tx: mpsc::Sender>, + rx: mpsc::Receiver>, + // So that we can remove ourselves when we get dropped + drop_tx: mpsc::Sender, +} + +impl QuerySession { + pub fn id(&self) -> usize { + self.id + } + + // Roughly equivalent to TryFrom, but the `id` field value + // comes from the QuerySession + fn convert(&self, value: Query) -> Result { + let state: QueryState = value.direction.into(); + let key = serde_json::to_string(&value.key)?; + let output = serde_json::to_string(&value.output)?; + + Ok(PluginQuery { + id: self.id() as i32, + state: state as i32, + publisher_name: value.publisher, + plugin_name: value.plugin, + query_name: value.query, + key, + output, + }) + } + + async fn recv_raw(&mut self) -> Result>> { + let mut out = VecDeque::new(); + + eprintln!("RAND-session: awaiting raw rx recv"); + + let opt_first = self + .rx + .recv() + .await + .ok_or(anyhow!("session channel closed unexpectedly"))?; + + let Some(first) = opt_first else { + // Underlying gRPC channel closed + return Ok(None); + }; + eprintln!("RAND-session: got first msg"); + out.push_back(first); + + // If more messages in the queue, opportunistically read more + loop { + eprintln!("RAND-session: trying to get additional msg"); + + match self.rx.try_recv() { + Ok(Some(msg)) => { + out.push_back(msg); + } + Ok(None) => { + eprintln!("warning: None received, gRPC channel closed. we may not close properly if None is not returned again"); + break; + } + // Whether empty or disconnected, we return what we have + Err(_) => { + break; + } + } + } + + eprintln!("RAND-session: got {} msgs", out.len()); + Ok(Some(out)) + } + + pub async fn send(&self, query: Query) -> Result<()> { + eprintln!("RAND-session: sending query"); + + let query = InitiateQueryProtocolResponse { + query: Some(self.convert(query)?), + }; + + self.tx.send(Ok(query)).await?; + + Ok(()) + } + + pub async fn recv(&mut self) -> Result> { + use QueryState::*; + + eprintln!("RAND-session: calling recv_raw"); + let Some(mut msg_chunks) = self.recv_raw().await? else { + return Ok(None); + }; + + let mut raw = msg_chunks.pop_front().unwrap(); + eprintln!("RAND-session: recv got raw {raw:?}"); + + let mut state: QueryState = raw.state.try_into()?; + + // If response is the first of a set of chunks, handle + if matches!(state, ReplyInProgress) { + while matches!(state, ReplyInProgress) { + // We expect another message. Pull it off the existing queue, + // or get a new one if we have run out + let next = match msg_chunks.pop_front() { + Some(msg) => msg, + None => { + // We ran out of messages, get a new batch + match self.recv_raw().await? { + Some(x) => { + msg_chunks = x; + } + None => { + return Ok(None); + } + }; + msg_chunks.pop_front().unwrap() + } + }; + + // By now we have our "next" message + state = next.state.try_into()?; + match state { + Unspecified => return Err(anyhow!("unspecified error from plugin")), + Submit => { + return Err(anyhow!( + "plugin sent QuerySubmit state when reply chunk expected" + )) + } + ReplyInProgress | ReplyComplete => { + raw.output.push_str(next.output.as_str()); + } + }; + } + + // Sanity check - after we've left this loop, there should be no left over message + if msg_chunks.is_empty().not() { + return Err(anyhow!( + "received additional messages for id '{}' after QueryComplete status message", + self.id + )); + } + } + + raw.try_into().map(Some) + } +} + +impl Drop for QuerySession { + // Notify to have self removed from session tracker + fn drop(&mut self) { + while let Err(e) = self.drop_tx.try_send(self.id as i32) { + match e { + TrySendError::Closed(_) => { + break; + } + TrySendError::Full(_) => (), + } + } + } +} + +pub struct HcSessionSocket { + tx: mpsc::Sender>, + rx: Streaming, + drop_tx: mpsc::Sender, + drop_rx: mpsc::Receiver, + sessions: SessionTracker, +} + +// This is implemented manually since the stream trait object +// can't impl `Debug`. +impl std::fmt::Debug for HcSessionSocket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HcSessionSocket") + .field("tx", &self.tx) + .field("rx", &"") + .field("drop_tx", &self.drop_tx) + .field("drop_rx", &self.drop_rx) + .field("sessions", &self.sessions) + .finish() + } +} + +impl HcSessionSocket { + pub fn new( + tx: mpsc::Sender>, + rx: Streaming, + ) -> Self { + // channel for QuerySession objects to notify us they dropped + // @Todo - make this configurable + let (drop_tx, drop_rx) = mpsc::channel(10); + + Self { + tx, + rx, + drop_tx, + drop_rx, + sessions: HashMap::new(), + } + } + + /// Clean up completed sessions by going through all drop messages. + fn cleanup_sessions(&mut self) { + while let Ok(id) = self.drop_rx.try_recv() { + match self.sessions.remove(&id) { + Some(_) => eprintln!("Cleaned up session {id}"), + None => eprintln!( + "WARNING: HcSessionSocket got request to drop a session that does not exist" + ), + } + } + } + + async fn message(&mut self) -> Result, Status> { + let fut = poll_fn(|cx| Pin::new(&mut self.rx).poll_next(cx)); + + match fut.await { + Some(Ok(m)) => Ok(m.query), + Some(Err(e)) => Err(e), + None => Ok(None), + } + } + + pub async fn listen(&mut self) -> Result> { + loop { + eprintln!("RAND: listening"); + + let Some(raw) = self.message().await? else { + return Ok(None); + }; + let id = raw.id; + + // While we were waiting for a message, some session objects may have + // dropped, handle them before we look at the ID of this message. + // The downside of this strategy is that once we receive our last message, + // we won't clean up any sessions that close after + self.cleanup_sessions(); + + match self.decide_action(&raw) { + Ok(HandleAction::ForwardMsgToExistingSession(tx)) => { + eprintln!("RAND-listen: forwarding message to session {id}"); + + if let Err(_e) = tx.send(Some(raw)).await { + eprintln!("Error forwarding msg to session {id}"); + self.sessions.remove(&id); + }; + } + Ok(HandleAction::CreateSession) => { + eprintln!("RAND-listen: creating new session {id}"); + + let (in_tx, rx) = mpsc::channel::>(10); + let tx = self.tx.clone(); + + let session = QuerySession { + id: id as usize, + tx, + rx, + drop_tx: self.drop_tx.clone(), + }; + + in_tx.send(Some(raw)).await.expect( + "Failed sending message to newly created Session, should never happen", + ); + + eprintln!("RAND-listen: adding new session {id} to tracker"); + self.sessions.insert(id, in_tx); + + return Ok(Some(session)); + } + Err(e) => eprintln!("error: {}", e), + } + } + } + + fn decide_action(&mut self, query: &PluginQuery) -> Result> { + if let Some(tx) = self.sessions.get_mut(&query.id) { + return Ok(HandleAction::ForwardMsgToExistingSession(tx)); + } + + if query.state() == QueryState::Submit { + return Ok(HandleAction::CreateSession); + } + + Err(anyhow!( + "Got query with id {}, does not match existing session and is not new QuerySubmit", + query.id + )) + } +} + +enum HandleAction<'s> { + ForwardMsgToExistingSession(&'s mut mpsc::Sender>), + CreateSession, +} diff --git a/hipcheck/proto/hipcheck.proto b/proto/hipcheck/v1/hipcheck.proto similarity index 70% rename from hipcheck/proto/hipcheck.proto rename to proto/hipcheck/v1/hipcheck.proto index 724514e2..ba93a508 100644 --- a/hipcheck/proto/hipcheck.proto +++ b/proto/hipcheck/v1/hipcheck.proto @@ -1,8 +1,8 @@ syntax = "proto3"; -package hipcheck; +package hipcheck.v1; -service Plugin { +service PluginService { /** * Get schemas for all supported queries by the plugin. * @@ -13,19 +13,22 @@ service Plugin { * - That requests sent to the plugin and data returned by the plugin * match the schema during execution. */ - rpc GetQuerySchemas (Empty) returns (stream Schema); + rpc GetQuerySchemas (GetQuerySchemasRequest) + returns (stream GetQuerySchemasResponse); /** * Hipcheck sends all child nodes for the plugin from the user's policy * file to configure the plugin. */ - rpc SetConfiguration (Configuration) returns (ConfigurationResult); + rpc SetConfiguration (SetConfigurationRequest) + returns (SetConfigurationResponse); /** * Get the default policy for a plugin, which may additionally depend on * the plugin's configuration. */ - rpc GetDefaultPolicyExpression (Empty) returns (PolicyExpression); + rpc GetDefaultPolicyExpression (GetDefaultPolicyExpressionRequest) + returns (GetDefaultPolicyExpressionResponse); /** * Open a bidirectional streaming RPC to enable a request/response @@ -38,36 +41,73 @@ service Plugin { * Hipcheck will respond with the cached result of that prior matching * query rather than running the query again. */ - rpc InitiateQueryProtocol (stream Query) returns (stream Query); + rpc InitiateQueryProtocol (stream InitiateQueryProtocolRequest) + returns (stream InitiateQueryProtocolResponse); } -message Configuration { +/*=========================================================================== + * GetQuerySchemas RPC Types + */ + +message GetQuerySchemasRequest { + Empty empty = 1; +} + +message GetQuerySchemasResponse { + // The name of the query being described by the schemas provided. + // + // If either the key and/or output schemas result in a message which is + // too big, they may be chunked across multiple replies in the stream. + // Replies with matching query names should have their fields concatenated + // in the order received to reconstruct the chunks. + string query_name = 1; + + // The key schema, in JSON Schema format. + string key_schema = 2; + + // The output schema, in JSON Schema format. + string output_schema = 3; +} + +/*=========================================================================== + * SetConfiguration RPC Types + */ + +message SetConfigurationRequest { // JSON string containing configuration data expected by the plugin, // pulled from the user's policy file. string configuration = 1; } +message SetConfigurationResponse { + // The status of the configuration call. + ConfigurationStatus status = 1; + // An optional error message, if there was an error. + string message = 2; +} + enum ConfigurationStatus { // An unknown error occured. - ERROR_UNKNOWN = 0; + CONFIGURATION_STATUS_UNSPECIFIED = 0; // No error; the operation was successful. - ERROR_NONE = 1; + CONFIGURATION_STATUS_NONE = 1; // The user failed to provide a required configuration item. - ERROR_MISSING_REQUIRED_CONFIGURATION = 2; + CONFIGURATION_STATUS_MISSING_REQUIRED_CONFIGURATION = 2; // The user provided a configuration item whose name was not recognized. - ERROR_UNRECOGNIZED_CONFIGURATION = 3; + CONFIGURATION_STATUS_UNRECOGNIZED_CONFIGURATION = 3; // The user provided a configuration item whose value is invalid. - ERROR_INVALID_CONFIGURATION_VALUE = 4; + CONFIGURATION_STATUS_INVALID_CONFIGURATION_VALUE = 4; } -message ConfigurationResult { - // The status of the configuration call. - ConfigurationStatus status = 1; - // An optional error message, if there was an error. - string message = 2; +/*=========================================================================== + * GetDefaultPolicyExpression RPC Types + */ + +message GetDefaultPolicyExpressionRequest { + Empty empty = 1; } -message PolicyExpression { +message GetDefaultPolicyExpressionResponse { // A policy expression, if the plugin has a default policy. // This MUST be filled in with any default values pulled from the plugin's // configuration. Hipcheck will only request the default policy _after_ @@ -75,37 +115,17 @@ message PolicyExpression { string policy_expression = 1; } -message Schema { - // The name of the query being described by the schemas provided. - // - // If either the key and/or output schemas result in a message which is - // too big, they may be chunked across multiple replies in the stream. - // Replies with matching query names should have their fields concatenated - // in the order received to reconstruct the chunks. - string query_name = 1; - // The key schema, in JSON Schema format. - string key_schema = 2; +/*=========================================================================== + * Query Protocol RPC Types + */ - // The output schema, in JSON Schema format. - string output_schema = 3; +message InitiateQueryProtocolRequest { + Query query = 1; } -enum QueryState { - // Something has gone wrong. - QUERY_UNSPECIFIED = 0; - - // We are submitting a new query. - QUERY_SUBMIT = 1; - - // We are replying to a query and expect more chunks. - QUERY_REPLY_IN_PROGRESS = 2; - - // We are closing a reply to a query. If a query response is in one chunk, - // just send this. If a query is in more than one chunk, send this with - // the last message in the reply. This tells the receiver that all chunks - // have been received. - QUERY_REPLY_COMPLETE = 3; +message InitiateQueryProtocolResponse { + Query query = 1; } message Query { @@ -142,4 +162,26 @@ message Query { string output = 7; } +enum QueryState { + // Something has gone wrong. + QUERY_STATE_UNSPECIFIED = 0; + + // We are submitting a new query. + QUERY_STATE_SUBMIT = 1; + + // We are replying to a query and expect more chunks. + QUERY_STATE_REPLY_IN_PROGRESS = 2; + + // We are closing a reply to a query. If a query response is in one chunk, + // just send this. If a query is in more than one chunk, send this with + // the last message in the reply. This tells the receiver that all chunks + // have been received. + QUERY_STATE_REPLY_COMPLETE = 3; +} + + +/*=========================================================================== + * Helper Types + */ + message Empty {}