From bbdd162dac5c5f775f0c1480875150228637462a Mon Sep 17 00:00:00 2001 From: Andrew Lilley Brinker Date: Thu, 8 Aug 2024 11:26:55 -0700 Subject: [PATCH] chore: Integrated policy expressions code This is still a work-in-progress, and nothing is actually used yet, but this at least moves over the draft implementation of policy expressions from the external experimental repo I'd made to enable integration with Hipcheck. The other thing I did was organize the Cargo.toml file for the hipcheck crate, particularly to put all dependencies in alphabetical order, as they had been at one point previously. Signed-off-by: Andrew Lilley Brinker --- Cargo.lock | 169 +++++- hipcheck/Cargo.toml | 127 +++-- hipcheck/src/main.rs | 15 +- hipcheck/src/policy_exprs/bridge.rs | 347 +++++++++++++ hipcheck/src/policy_exprs/env.rs | 774 ++++++++++++++++++++++++++++ hipcheck/src/policy_exprs/error.rs | 97 ++++ hipcheck/src/policy_exprs/expr.rs | 280 ++++++++++ hipcheck/src/policy_exprs/mod.rs | 141 +++++ hipcheck/src/policy_exprs/token.rs | 164 ++++++ 9 files changed, 2051 insertions(+), 63 deletions(-) create mode 100644 hipcheck/src/policy_exprs/bridge.rs create mode 100644 hipcheck/src/policy_exprs/env.rs create mode 100644 hipcheck/src/policy_exprs/error.rs create mode 100644 hipcheck/src/policy_exprs/expr.rs create mode 100644 hipcheck/src/policy_exprs/mod.rs create mode 100644 hipcheck/src/policy_exprs/token.rs diff --git a/Cargo.lock b/Cargo.lock index 0d30a08c..2a57c9cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -241,6 +241,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "beef" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a8241f3ebb85c056b509d4327ad0358fbbba6ffb340bf388f26350aeda225b1" + [[package]] name = "bitflags" version = "1.3.2" @@ -1115,8 +1121,10 @@ dependencies = [ "indexmap 2.2.6", 
"indextree", "indicatif", + "itertools", "kdl", "log", + "logos", "maplit", "nom", "num-traits", @@ -1145,6 +1153,8 @@ dependencies = [ "tar", "tempfile", "term_size", + "test-log", + "thiserror", "tokio", "toml", "tonic", @@ -1542,6 +1552,39 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "logos" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff1ceb190eb9bdeecdd8f1ad6a71d6d632a50905948771718741b5461fb01e13" +dependencies = [ + "logos-derive", +] + +[[package]] +name = "logos-codegen" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90be66cb7bd40cb5cc2e9cfaf2d1133b04a3d93b72344267715010a466e0915a" +dependencies = [ + "beef", + "fnv", + "lazy_static", + "proc-macro2", + "quote", + "regex-syntax 0.8.4", + "syn 2.0.75", +] + +[[package]] +name = "logos-derive" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45154231e8e96586b39494029e58f12f8ffcb5ecf80333a603a13aa205ea8cbd" +dependencies = [ + "logos-codegen", +] + [[package]] name = "lzma-rs" version = "0.3.0" @@ -1569,6 +1612,15 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e2e65a1a2e43cfcb47a895c4c8b10d1f4a61097f9f254f183aee60cad9c651d" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "matchit" version = "0.7.3" @@ -1627,9 +1679,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" +checksum = 
"80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ "hermit-abi", "libc", @@ -1653,6 +1705,16 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -1771,6 +1833,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "packageurl" version = "0.4.0" @@ -2149,8 +2217,17 @@ checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.7", + "regex-syntax 0.8.4", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", ] [[package]] @@ -2161,9 +2238,15 @@ checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.4", ] +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.8.4" @@ -2448,6 +2531,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + 
"lazy_static", +] + [[package]] name = "shared_child" version = "1.0.0" @@ -2695,6 +2787,28 @@ dependencies = [ "winapi", ] +[[package]] +name = "test-log" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dffced63c2b5c7be278154d76b479f9f9920ed34e7574201407f0b14e2bbb93" +dependencies = [ + "env_logger", + "test-log-macros", + "tracing-subscriber", +] + +[[package]] +name = "test-log-macros" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5999e24eaa32083191ba4e425deb75cdf25efefabe5aaccb7446dd0d4122a3f5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.75", +] + [[package]] name = "thiserror" version = "1.0.63" @@ -2715,6 +2829,16 @@ dependencies = [ "syn 2.0.75", ] +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + [[package]] name = "time" version = "0.3.36" @@ -2961,6 +3085,35 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "sharded-slab", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", ] [[package]] @@ -3066,6 +3219,12 @@ dependencies = [ "getrandom", ] 
+[[package]] +name = "valuable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" + [[package]] name = "vcpkg" version = "0.2.15" diff --git a/hipcheck/Cargo.toml b/hipcheck/Cargo.toml index 26de5e5f..bca911ba 100644 --- a/hipcheck/Cargo.toml +++ b/hipcheck/Cargo.toml @@ -1,47 +1,100 @@ [package] name = "hipcheck" -description = "Automatically assess and score software repositories for supply chain risk" +description = """ +Automatically assess and score software packages for supply chain risk. +""" keywords = ["security", "sbom"] categories = ["command-line-utilities", "development-tools"] readme = "../README.md" version = "3.5.0" edition = "2021" license = "Apache-2.0" +homepage = "https://mitre.github.io/hipcheck" repository = "https://github.com/mitre/hipcheck" include = ["src/**/*", "../LICENSE", "../README.md"] -[features] -# Print timings feature is used to print timing information throughout hipchecks runtime. -print-timings = ["benchmarking"] -# Benchmarking enables the benchmarking module, containing special utilities for benchmarking. -benchmarking = [] - +# Rename the binary from the default "hipcheck" (based on the package name) +# to "hc". [[bin]] name = "hc" path = "src/main.rs" +[features] + +# Print timings feature is used to print timing information throughout +# Hipcheck's runtime. +print-timings = ["benchmarking"] + +# Benchmarking enables the benchmarking module, containing special utilities +# for benchmarking. 
+benchmarking = [] + [dependencies] +async-stream = "0.3.5" +base64 = "0.22.1" content_inspector = "0.2.4" cyclonedx-bom = "0.7.0" dotenv = "0.15.0" chrono = { version = "0.4.19", features = ["alloc", "serde"] } clap = { version = "4.5.13", features = ["derive"] } +console = { version = "0.15.8", features = ["windows-console-colors"] } +dashmap = { version = "6.0.1", features = ["rayon", "inline"] } +dialoguer = "0.11.0" dirs = "5.0.1" duct = "0.13.5" env_logger = { version = "0.11.5" } +finl_unicode = { version = "1.2.0", default-features = false, features = [ + "grapheme_clusters", +] } +fs_extra = "1.3.0" +futures = "0.3.30" +# Vendor libgit2 and openssl so that they will be statically included +# and not cause problems on certain systems that might not have one or +# the other. +git2 = { version = "0.19.0", features = [ + "vendored-libgit2", + "vendored-openssl", +] } graphql_client = "0.14.0" +# Include with both a `path` and `version` reference. +# Local builds will use the `path` dependency, which may be a newer +# version than the one published to Crates.io. +# People building from Crates.io will use the published version. 
+# +# See: https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#multiple-locations hipcheck-macros = { path = "../hipcheck-macros", version = "0.3.1" } +http = "1.1.0" +indexmap = "2.2.6" +indextree = "4.6.1" +indicatif = { version = "0.17.8", features = ["rayon"] } +itertools = "0.13.0" +kdl = "4.6.0" log = "0.4.22" +logos = "0.14.0" maplit = "1.0.2" nom = "7.1.3" +num-traits = "0.2.19" +num_enum = "0.7.3" once_cell = "1.10.0" ordered-float = { version = "4.2.2", features = ["serde"] } packageurl = "0.4.0" paste = "1.0.7" pathbuf = "1.0.0" petgraph = { version = "0.6.0", features = ["serde-1"] } -regex = "1.10.6" +prost = "0.13.1" +rand = "0.8.5" +rayon = "1.10.0" +regex = "1.10.5" +# Exactly matching the version of rustls used by ureq +# Get rid of default features since we don't use the AWS backed crypto +# provider (we use ring) and it breaks stuff on windows. +rustls = { version = "0.23.10", default-features = false, features = [ + "logging", + "std", + "tls12", + "ring", +] } rustls-native-certs = "0.7.1" salsa = "0.16.1" schemars = { version = "0.8.21", default-features = false, features = [ @@ -55,7 +108,18 @@ serde_derive = "1.0.137" serde_json = "1.0.122" smart-default = "0.7.1" spdx-rs = "0.5.0" +tabled = "0.15.0" +tar = "0.4.41" +term_size = "0.3.2" +tokio = { version = "1.39.3", features = [ + "rt", + "rt-multi-thread", + "sync", + "time", +] } toml = "0.8.19" +tonic = "0.12.1" +thiserror = "1.0.63" unicode-normalization = "0.1.19" ureq = { version = "2.10.0", default-features = false, features = [ "json", @@ -64,58 +128,21 @@ ureq = { version = "2.10.0", default-features = false, features = [ url = "2.5.1" walkdir = "2.5.0" which = { version = "6.0.1", default-features = false } -xml-rs = "0.8.21" -rayon = "1.10.0" -indexmap = "2.2.6" -dashmap = { version = "6.0.1", features = ["rayon", "inline"] } -# Vendor libgit2 and openssl so that they will be statically included and not cause problems on certain systems that might not have one 
or the other. -git2 = { version = "0.19.0", features = ["vendored-libgit2", "vendored-openssl"]} -indicatif = { version = "0.17.8", features = ["rayon"] } -finl_unicode = { version = "1.2.0", default-features = false, features = [ - "grapheme_clusters", -] } -tar = "0.4.41" -zip = "2.1.6" +xml-rs = "0.8.20" xz2 = "0.1.7" -indextree = "4.7.2" -num-traits = "0.2.19" -console = { version = "0.15.8", features = ["windows-console-colors"] } -term_size = "0.3.2" -base64 = "0.22.1" -http = "1.1.0" -dialoguer = "0.11.0" -tabled = "0.15.0" -fs_extra = "1.3.0" -tonic = "0.12.1" -prost = "0.13.1" -rand = "0.8.5" -kdl = "4.6.0" -tokio = { version = "1.39.2", features = ["rt", "rt-multi-thread", "sync", "time"] } -futures = "0.3.30" -async-stream = "0.3.5" -num_enum = "0.7.3" - -# Exactly matching the version of rustls used by ureq -# Get rid of default features since we don't use the AWS backed crypto provider (we use ring). -# and it breaks stuff on windows. -[dependencies.rustls] -version = "0.23.10" -default-features = false -features = [ - "logging", - "std", - "tls12", - "ring" -] +zip = "2.1.6" [build-dependencies] + anyhow = "1.0.86" tonic-build = "0.12.1" which = { version = "6.0.1", default-features = false } [dev-dependencies] + dirs = "5.0.1" tempfile = "3.12.0" +test-log = "0.2.16" [package.metadata.dist] diff --git a/hipcheck/src/main.rs b/hipcheck/src/main.rs index 5688d974..69ed58d8 100644 --- a/hipcheck/src/main.rs +++ b/hipcheck/src/main.rs @@ -17,6 +17,7 @@ mod log_bridge; mod metric; #[allow(unused)] mod plugin; +mod policy_exprs; mod report; mod session; mod setup; @@ -84,14 +85,6 @@ use target::{RemoteGitRepo, TargetSeed, TargetSeedKind, ToTargetSeed}; use util::fs::create_dir_all; use which::which; -fn init_logging() -> std::result::Result<(), log::SetLoggerError> { - let env = Env::new().filter("HC_LOG").write_style("HC_LOG_STYLE"); - - let logger = env_logger::Builder::from_env(env).build(); - - log_bridge::LogWrapper(logger).try_init() -} - /// Entry 
point for Hipcheck. fn main() -> ExitCode { // Initialize the global shell with normal verbosity by default. @@ -160,6 +153,12 @@ fn main() -> ExitCode { ExitCode::SUCCESS } +fn init_logging() -> std::result::Result<(), log::SetLoggerError> { + let env = Env::new().filter("HC_LOG").write_style("HC_LOG_STYLE"); + let logger = env_logger::Builder::from_env(env).build(); + log_bridge::LogWrapper(logger).try_init() +} + /// Run the `check` command. fn cmd_check(args: &CheckArgs, config: &CliConfig) -> ExitCode { let target = match args.to_target_seed() { diff --git a/hipcheck/src/policy_exprs/bridge.rs b/hipcheck/src/policy_exprs/bridge.rs new file mode 100644 index 00000000..24adef9f --- /dev/null +++ b/hipcheck/src/policy_exprs/bridge.rs @@ -0,0 +1,347 @@ +// The following code is copied from the `logos-nom-bridge` crate, which uses +// an outdated version of `logos` and thus can't be used directly here. +// +// The original code which we have copied and modified is MIT licensed, and +// used under the terms of that license here. + +//! # logos-nom-bridge +//! +//! A [`logos::Lexer`] wrapper than can be used as an input for +//! [nom](https://docs.rs/nom/7.0.0/nom/index.html). +//! + +use core::fmt; +use logos::{Lexer, Logos, Span, SpannedIter}; +use nom::{InputIter, InputLength, InputTake}; + +/// A [`logos::Lexer`] wrapper than can be used as an input for +/// [nom](https://docs.rs/nom/7.0.0/nom/index.html). +/// +/// You can find an example in the [module-level docs](..). +pub struct Tokens<'i, T> +where + T: Logos<'i>, +{ + lexer: Lexer<'i, T>, +} + +impl<'i, T> Clone for Tokens<'i, T> +where + T: Logos<'i> + Clone, + T::Extras: Clone, +{ + fn clone(&self) -> Self { + Self { + lexer: self.lexer.clone(), + } + } +} + +// Helper type returned by the logos parser. +type ParseResult<'i, T> = Result>::Error>; + +impl<'i, T> Tokens<'i, T> +where + T: Logos<'i, Source = str> + Clone, + T::Extras: Default + Clone, +{ + /// Create a new token parser. 
+ pub fn new(input: &'i str) -> Self { + Tokens { + lexer: Lexer::new(input), + } + } + + /// Get the length of the remaining source to parse. + pub fn len(&self) -> usize { + self.lexer.source().len() - self.lexer.span().end + } + + /// See if the remaining length to parse is empty. + #[allow(unused)] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Peek at the next token, possibly with a parsing error. + pub fn peek(&self) -> Option<(ParseResult<'i, T>, &'i str)> { + let mut iter = self.lexer.clone().spanned(); + iter.next().map(|(t, span)| (t, &self.lexer.source()[span])) + } + + /// Advance the parser one step. + pub fn advance(mut self) -> Self { + self.lexer.next(); + self + } + + /// Get the underlying lexer. + pub fn lexer(&self) -> &Lexer<'i, T> { + &self.lexer + } +} + +impl<'i, T> PartialEq for Tokens<'i, T> +where + T: PartialEq + Logos<'i> + Clone, + T::Extras: Clone, +{ + fn eq(&self, other: &Self) -> bool { + Iterator::eq(self.lexer.clone(), other.lexer.clone()) + } +} + +impl<'i, T> Eq for Tokens<'i, T> +where + T: Eq + Logos<'i> + Clone, + T::Extras: Clone, +{ +} + +impl<'i, T> fmt::Debug for Tokens<'i, T> +where + T: fmt::Debug + Logos<'i, Source = str>, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let source = self.lexer.source(); + let start = self.lexer.span().start; + f.debug_tuple("Tokens").field(&&source[start..]).finish() + } +} + +impl<'i, T> fmt::Display for Tokens<'i, T> +where + T: fmt::Debug + fmt::Display + Logos<'i, Source = str>, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (self as &dyn fmt::Debug).fmt(f) + } +} + +impl<'i, T> Default for Tokens<'i, T> +where + T: Logos<'i, Source = str>, + T::Extras: Default, +{ + fn default() -> Self { + Tokens { + lexer: Lexer::new(""), + } + } +} + +/// An iterator, that (similarly to [`std::iter::Enumerate`]) produces byte offsets of the tokens. 
+pub struct IndexIterator<'i, T> +where + T: Logos<'i>, +{ + logos: Lexer<'i, T>, +} + +impl<'i, T> Iterator for IndexIterator<'i, T> +where + T: Logos<'i>, +{ + type Item = (usize, (ParseResult<'i, T>, Span)); + + fn next(&mut self) -> Option { + self.logos.next().map(|t| { + let span = self.logos.span(); + (span.start, (t, span)) + }) + } +} + +impl<'i, T> InputIter for Tokens<'i, T> +where + T: Logos<'i, Source = str> + Clone, + T::Extras: Default + Clone, +{ + type Item = (ParseResult<'i, T>, Span); + type Iter = IndexIterator<'i, T>; + type IterElem = SpannedIter<'i, T>; + + fn iter_indices(&self) -> Self::Iter { + IndexIterator { + logos: self.lexer.clone(), + } + } + + fn iter_elements(&self) -> Self::IterElem { + self.lexer.clone().spanned() + } + + fn position

(&self, predicate: P) -> Option + where + P: Fn(Self::Item) -> bool, + { + let mut iter = self.lexer.clone().spanned(); + iter.find(|t| predicate(t.clone())) + .map(|(_, span)| span.start) + } + + fn slice_index(&self, count: usize) -> Result { + let mut cnt = 0; + for (_, span) in self.lexer.clone().spanned() { + if cnt == count { + return Ok(span.start); + } + cnt += 1; + } + if cnt == count { + return Ok(self.len()); + } + Err(nom::Needed::Unknown) + } +} + +impl<'i, T> InputLength for Tokens<'i, T> +where + T: Logos<'i, Source = str> + Clone, + T::Extras: Default + Clone, +{ + fn input_len(&self) -> usize { + self.len() + } +} + +impl<'i, T> InputTake for Tokens<'i, T> +where + T: Logos<'i, Source = str>, + T::Extras: Default, +{ + fn take(&self, count: usize) -> Self { + Tokens { + lexer: Lexer::new(&self.lexer.source()[..count]), + } + } + + fn take_split(&self, count: usize) -> (Self, Self) { + let (a, b) = self.lexer.source().split_at(count); + ( + Tokens { + lexer: Lexer::new(a), + }, + Tokens { + lexer: Lexer::new(b), + }, + ) + } +} + +#[macro_export] +#[doc(hidden)] +macro_rules! token_parser { + ( + token: $token_ty:ty $(,)? + ) => { + $crate::token_parser!( + token: $token_ty, + error<'source>(input, token): ::nom::error::Error<$crate::policy_exprs::Tokens<'source, $token_ty>> = + nom::error::Error::new(input, nom::error::ErrorKind::IsA), + ); + }; + + ( + token: $token_ty:ty, + error: $error_ty:ty = $error:expr $(,)? + ) => { + $crate::token_parser!( + token: $token_ty, + error<'source>(input, token): $error_ty = $error, + ); + }; + + ( + token: $token_ty:ty, + error<$lt:lifetime>($input:ident, $token:ident): $error_ty:ty = $error:expr $(,)? 
+ ) => { + #[allow(unused)] + impl<$lt> ::nom::Parser< + $crate::policy_exprs::Tokens<$lt, $token_ty>, + &$lt str, + $error_ty, + > for $token_ty { + fn parse( + &mut self, + $input: $crate::policy_exprs::Tokens<$lt, $token_ty>, + ) -> ::nom::IResult< + $crate::policy_exprs::Tokens<$lt, $token_ty>, + &$lt str, + $error_ty, + > { + match $input.peek() { + ::std::option::Option::Some((::std::result::Result::Ok(__token), __s)) if __token == *self => { + ::std::result::Result::Ok(($input.advance(), __s)) + } + ::std::option::Option::Some((::std::result::Result::Err(__err), __s)) => { + // Technically this could just be the subsequent case as well, but I am + // deciding to distinguish it here. + ::std::result::Result::Err(::nom::Err::Error($error)) + } + _ => { + // This was in the original code. It appears to be unused, but I am leaving it here + // as a sort of Chesterton's Fence situation. + let $token = self; + ::std::result::Result::Err(::nom::Err::Error($error)) + }, + } + } + } + }; +} + +/// Generates a nom parser function to parse an enum variant that contains data. +#[macro_export] +#[doc(hidden)] +macro_rules! data_variant_parser { + ( + fn $fn_name:ident($input:ident) -> Result<$ok_ty:ty>; + + pattern = $type:ident :: $variant:ident $data:tt => $res:expr; + ) => { + $crate::data_variant_parser! { + fn $fn_name<'src>($input) -> Result< + $ok_ty, + ::nom::error::Error<$crate::policy_exprs::Tokens<'src, $type>>, + >; + + pattern = $type :: $variant $data => $res; + error = ::nom::error::Error::new($input, ::nom::error::ErrorKind::IsA); + } + }; + + ( + fn $fn_name:ident($input:ident) -> Result<$ok_ty:ty, $error_ty:ty $(,)?>; + + pattern = $type:ident :: $variant:ident $data:tt => $res:expr; + error = $error:expr; + ) => { + $crate::data_variant_parser! 
{ + fn $fn_name<'src>($input) -> Result<$ok_ty, $error_ty>; + + pattern = $type :: $variant $data => $res; + error = $error; + } + }; + + ( + fn $fn_name:ident<$lt:lifetime>($input:ident) -> Result<$ok_ty:ty, $error_ty:ty $(,)?>; + + pattern = $type:ident :: $variant:ident $data:tt => $res:expr; + error = $error:expr; + ) => { + fn $fn_name<$lt>($input: $crate::policy_exprs::Tokens<$lt, $type>) -> ::nom::IResult< + $crate::policy_exprs::Tokens<$lt, $type>, + $ok_ty, + $error_ty, + > { + match $input.peek() { + ::std::option::Option::Some((::std::result::Result::Ok($type::$variant $data), _)) => { + Ok(($input.advance(), $res)) + } + _ => ::std::result::Result::Err(::nom::Err::Error($error)), + } + } + }; +} diff --git a/hipcheck/src/policy_exprs/env.rs b/hipcheck/src/policy_exprs/env.rs new file mode 100644 index 00000000..bb22184c --- /dev/null +++ b/hipcheck/src/policy_exprs/env.rs @@ -0,0 +1,774 @@ +use crate::policy_exprs::eval; +use crate::policy_exprs::Error; +use crate::policy_exprs::Expr; +use crate::policy_exprs::Ident; +use crate::policy_exprs::Primitive; +use crate::policy_exprs::Result; +use crate::policy_exprs::F64; +use itertools::Itertools as _; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::ops::Not as _; +use Expr::*; +use Primitive::*; + +/// Environment, containing bindings of names to functions and variables. +pub struct Env<'parent> { + /// Map of bindings,. + bindings: HashMap, + + /// Possible pointer to parent, for lexical scope. + parent: Option<&'parent Env<'parent>>, +} + +/// A binding in the environment. +#[derive(Clone)] +pub enum Binding { + /// A function. + Fn(Op), + + /// A primitive value. + Var(Primitive), +} + +/// Helper type for operation function pointer. +type Op = fn(&Env, &[Expr]) -> Result; + +impl<'parent> Env<'parent> { + /// Create an empty environment. + fn empty() -> Self { + Env { + bindings: HashMap::new(), + parent: None, + } + } + + /// Create the standard environment. 
+ pub fn std() -> Self { + let mut env = Env::empty(); + + // Comparison functions. + env.add_fn("gt", gt); + env.add_fn("lt", lt); + env.add_fn("gte", gte); + env.add_fn("lte", lte); + env.add_fn("eq", eq); + env.add_fn("neq", neq); + + // Math functions. + env.add_fn("add", add); + env.add_fn("sub", sub); + + // Logical functions. + env.add_fn("and", and); + env.add_fn("or", or); + env.add_fn("not", not); + + // Array math functions. + env.add_fn("max", max); + env.add_fn("min", min); + env.add_fn("avg", avg); + env.add_fn("median", median); + env.add_fn("count", count); + + // Array logic functions. + env.add_fn("all", all); + env.add_fn("nall", nall); + env.add_fn("some", some); + env.add_fn("none", none); + + // Array higher-order functions. + env.add_fn("filter", filter); + env.add_fn("foreach", foreach); + + // Debugging functions. + env.add_fn("dbg", dbg); + + env + } + + /// Create a child environment. + pub fn child(&self) -> Env<'_> { + Env { + bindings: HashMap::new(), + parent: Some(self), + } + } + + /// Add a variable to the environment. + pub fn add_var(&mut self, name: &str, value: Primitive) -> Option { + self.bindings.insert(name.to_owned(), Binding::Var(value)) + } + + /// Add a function to the environment. + pub fn add_fn(&mut self, name: &str, op: Op) -> Option { + self.bindings.insert(name.to_owned(), Binding::Fn(op)) + } + + /// Get a binding from the environment, walking up the scopes. + pub fn get(&self, name: &str) -> Option { + self.bindings + .get(name) + .cloned() + .or_else(|| self.parent.and_then(|parent| parent.get(name))) + } +} + +/// Check the number of args provided to the function. 
+fn check_num_args(name: &str, args: &[Expr], expected: usize) -> Result<()> { + let given = args.len(); + + match expected.cmp(&given) { + Ordering::Equal => Ok(()), + Ordering::Less => Err(Error::TooManyArgs { + name: name.to_string(), + expected, + given, + }), + Ordering::Greater => Err(Error::NotEnoughArgs { + name: name.to_string(), + expected, + given, + }), + } +} + +/// Partially evaluate a binary operation on primitives. +fn partially_evaluate(fn_name: &'static str, arg: Expr) -> Result { + let var_name = "x"; + let var = Ident(String::from(var_name)); + let func = Ident(String::from(fn_name)); + let op = Function(func, vec![Primitive(Identifier(var.clone())), arg]); + let lambda = Lambda(var, Box::new(op)); + Ok(lambda) +} + +/// Define binary operations on primitives. +fn binary_primitive_op(name: &'static str, env: &Env, args: &[Expr], op: F) -> Result +where + F: FnOnce(Primitive, Primitive) -> Result, +{ + if args.len() == 1 { + return partially_evaluate(name, args[0].clone()); + } + + check_num_args(name, args, 2)?; + + let arg_1 = match eval(env, &args[0])? { + Primitive(p) => p, + _ => return Err(Error::BadType(name)), + }; + + let arg_2 = match eval(env, &args[1])? { + Primitive(p) => p, + _ => return Err(Error::BadType(name)), + }; + + let primitive = match (&arg_1, &arg_2) { + (Int(_), Int(_)) | (Float(_), Float(_)) | (Bool(_), Bool(_)) => op(arg_1, arg_2)?, + _ => return Err(Error::BadType(name)), + }; + + Ok(Primitive(primitive)) +} + +/// Define unary operations on primitives. +fn unary_primitive_op(name: &'static str, env: &Env, args: &[Expr], op: F) -> Result +where + F: FnOnce(Primitive) -> Result, +{ + check_num_args(name, args, 1)?; + + let primitive = match eval(env, &args[0])? { + Primitive(arg) => arg, + _ => return Err(Error::BadType(name)), + }; + + Ok(Expr::Primitive(op(primitive)?)) +} + +/// Define unary operations on arrays. 
+fn unary_array_op(name: &'static str, env: &Env, args: &[Expr], op: F) -> Result +where + F: FnOnce(ArrayType) -> Result, +{ + check_num_args(name, args, 1)?; + + let arr = match eval(env, &args[0])? { + Array(arg) => array_type(&arg[..])?, + _ => return Err(Error::BadType(name)), + }; + + op(arr) +} + +/// Define a higher-order operation over arrays. +fn higher_order_array_op(name: &'static str, env: &Env, args: &[Expr], op: F) -> Result +where + F: FnOnce(ArrayType, Ident, Box) -> Result, +{ + check_num_args(name, args, 2)?; + + let (ident, body) = match eval(env, &args[0])? { + Lambda(ident, body) => (ident, body), + _ => return Err(Error::BadType(name)), + }; + + let arr = match eval(env, &args[1])? { + Array(arr) => array_type(&arr[..])?, + _ => return Err(Error::BadType(name)), + }; + + op(arr, ident, body) +} + +/// A fully-typed array. +enum ArrayType { + /// An array of ints. + Int(Vec), + + /// An array of floats. + Float(Vec), + + /// An array of bools. + Bool(Vec), + + /// An empty array (no type hints). + Empty, +} + +/// Process an array into a singular type, or error out. 
+/// Determine the element type of an array and unpack its values.
+///
+/// An empty slice yields `ArrayType::Empty`. Otherwise the first element
+/// selects the variant, and every element must be of that same kind or
+/// `Error::InconsistentArrayTypes` is returned.
+///
+/// # Panics
+///
+/// Panics via `unimplemented!` if the first element is an identifier.
+fn array_type(arr: &[Primitive]) -> Result<ArrayType> {
+    if arr.is_empty() {
+        return Ok(ArrayType::Empty);
+    }
+
+    match &arr[0] {
+        Int(_) => {
+            let mut result: Vec<i64> = Vec::with_capacity(arr.len());
+            for elem in arr {
+                if let Int(val) = elem {
+                    result.push(*val);
+                } else {
+                    return Err(Error::InconsistentArrayTypes);
+                }
+            }
+            Ok(ArrayType::Int(result))
+        }
+        Float(_) => {
+            let mut result: Vec<F64> = Vec::with_capacity(arr.len());
+            for elem in arr {
+                if let Float(val) = elem {
+                    result.push(*val);
+                } else {
+                    return Err(Error::InconsistentArrayTypes);
+                }
+            }
+            Ok(ArrayType::Float(result))
+        }
+        Bool(_) => {
+            let mut result: Vec<bool> = Vec::with_capacity(arr.len());
+            for elem in arr {
+                if let Bool(val) = elem {
+                    result.push(*val);
+                } else {
+                    return Err(Error::InconsistentArrayTypes);
+                }
+            }
+            Ok(ArrayType::Bool(result))
+        }
+        Identifier(_) => unimplemented!("we don't currently support idents in arrays"),
+    }
+}
+
+/// Evaluate the lambda, injecting into the environment.
+///
+/// Binds `val` to `ident` in a child of `env` and evaluates `body` there.
+/// Returns `Error::AlreadyBound` if `add_var` reports an existing binding.
+fn eval_lambda(env: &Env, ident: &Ident, val: Primitive, body: Expr) -> Result<Expr> {
+    let mut child = env.child();
+
+    if child.add_var(&ident.0, val).is_some() {
+        return Err(Error::AlreadyBound);
+    }
+
+    eval(&child, &body)
+}
+
+// NOTE(review): in the binary builtins below, the `_ => unreachable!()` arm
+// presumably relies on `binary_primitive_op` rejecting mixed-kind argument
+// pairs before `op` runs -- confirm against that helper.
+
+/// Builtin `gt`: `>` over two ints, two floats, or two bools.
+#[allow(clippy::bool_comparison)]
+fn gt(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "gt";
+
+    let op = |arg_1, arg_2| match (arg_1, arg_2) {
+        (Int(arg_1), Int(arg_2)) => Ok(Bool(arg_1 > arg_2)),
+        (Float(arg_1), Float(arg_2)) => Ok(Bool(arg_1 > arg_2)),
+        (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 > arg_2)),
+        _ => unreachable!(),
+    };
+
+    binary_primitive_op(name, env, args, op)
+}
+
+/// Builtin `lt`: `<` over two ints, two floats, or two bools.
+#[allow(clippy::bool_comparison)]
+fn lt(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "lt";
+
+    let op = |arg_1, arg_2| match (arg_1, arg_2) {
+        (Int(arg_1), Int(arg_2)) => Ok(Bool(arg_1 < arg_2)),
+        (Float(arg_1), Float(arg_2)) => Ok(Bool(arg_1 < arg_2)),
+        (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 < arg_2)),
+        _ => unreachable!(),
+    };
+
+    binary_primitive_op(name, env, args, op)
+}
+
+/// Builtin `gte`: `>=` over two ints, two floats, or two bools.
+#[allow(clippy::bool_comparison)]
+fn gte(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "gte";
+
+    let op = |arg_1, arg_2| match (arg_1, arg_2) {
+        (Int(arg_1), Int(arg_2)) => Ok(Bool(arg_1 >= arg_2)),
+        (Float(arg_1), Float(arg_2)) => Ok(Bool(arg_1 >= arg_2)),
+        (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 >= arg_2)),
+        _ => unreachable!(),
+    };
+
+    binary_primitive_op(name, env, args, op)
+}
+
+/// Builtin `lte`: `<=` over two ints, two floats, or two bools.
+#[allow(clippy::bool_comparison)]
+fn lte(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "lte";
+
+    let op = |arg_1, arg_2| match (arg_1, arg_2) {
+        (Int(arg_1), Int(arg_2)) => Ok(Bool(arg_1 <= arg_2)),
+        (Float(arg_1), Float(arg_2)) => Ok(Bool(arg_1 <= arg_2)),
+        (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 <= arg_2)),
+        _ => unreachable!(),
+    };
+
+    binary_primitive_op(name, env, args, op)
+}
+
+/// Builtin `eq`: `==` over two ints, two floats, or two bools.
+#[allow(clippy::bool_comparison)]
+fn eq(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "eq";
+
+    let op = |arg_1, arg_2| match (arg_1, arg_2) {
+        (Int(arg_1), Int(arg_2)) => Ok(Bool(arg_1 == arg_2)),
+        (Float(arg_1), Float(arg_2)) => Ok(Bool(arg_1 == arg_2)),
+        (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 == arg_2)),
+        _ => unreachable!(),
+    };
+
+    binary_primitive_op(name, env, args, op)
+}
+
+/// Builtin `neq`: `!=` over two ints, two floats, or two bools.
+#[allow(clippy::bool_comparison)]
+fn neq(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "neq";
+
+    let op = |arg_1, arg_2| match (arg_1, arg_2) {
+        (Int(arg_1), Int(arg_2)) => Ok(Bool(arg_1 != arg_2)),
+        (Float(arg_1), Float(arg_2)) => Ok(Bool(arg_1 != arg_2)),
+        (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 != arg_2)),
+        _ => unreachable!(),
+    };
+
+    binary_primitive_op(name, env, args, op)
+}
+
+/// Builtin `add`: sum of two ints or two floats; bools are `Error::BadType`.
+fn add(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "add";
+
+    let op = |arg_1, arg_2| match (arg_1, arg_2) {
+        (Int(arg_1), Int(arg_2)) => Ok(Int(arg_1 + arg_2)),
+        (Float(arg_1), Float(arg_2)) => Ok(Float(arg_1 + arg_2)),
+        (Bool(_), Bool(_)) => Err(Error::BadType(name)),
+        _ => unreachable!(),
+    };
+
+    binary_primitive_op(name, env, args, op)
+}
+
+/// Builtin `sub`: difference of two ints or two floats; bools are
+/// `Error::BadType`.
+fn sub(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "sub";
+
+    let op = |arg_1, arg_2| match (arg_1, arg_2) {
+        (Int(arg_1), Int(arg_2)) => Ok(Int(arg_1 - arg_2)),
+        (Float(arg_1), Float(arg_2)) => Ok(Float(arg_1 - arg_2)),
+        (Bool(_), Bool(_)) => Err(Error::BadType(name)),
+        _ => unreachable!(),
+    };
+
+    binary_primitive_op(name, env, args, op)
+}
+
+/// Builtin `and`: logical AND of two bools; ints and floats are
+/// `Error::BadType`.
+fn and(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "and";
+
+    let op = |arg_1, arg_2| match (arg_1, arg_2) {
+        (Int(_), Int(_)) => Err(Error::BadType(name)),
+        (Float(_), Float(_)) => Err(Error::BadType(name)),
+        (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 && arg_2)),
+        _ => unreachable!(),
+    };
+
+    binary_primitive_op(name, env, args, op)
+}
+
+/// Builtin `or`: logical OR of two bools; ints and floats are
+/// `Error::BadType`.
+fn or(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "or";
+
+    let op = |arg_1, arg_2| match (arg_1, arg_2) {
+        (Int(_), Int(_)) => Err(Error::BadType(name)),
+        (Float(_), Float(_)) => Err(Error::BadType(name)),
+        (Bool(arg_1), Bool(arg_2)) => Ok(Bool(arg_1 || arg_2)),
+        _ => unreachable!(),
+    };
+
+    binary_primitive_op(name, env, args, op)
+}
+
+/// Builtin `not`: logical negation of a bool; ints and floats are
+/// `Error::BadType`.
+fn not(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "not";
+
+    let op = |arg| match arg {
+        Int(_) => Err(Error::BadType(name)),
+        Float(_) => Err(Error::BadType(name)),
+        Bool(arg) => Ok(Primitive::Bool(arg.not())),
+        Identifier(_) => unreachable!("no idents should be here"),
+    };
+
+    unary_primitive_op(name, env, args, op)
+}
+
+/// Builtin `max`: largest element of an int or float array.
+///
+/// Bool arrays are `Error::BadType`; an empty array is `Error::NoMax`.
+fn max(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "max";
+
+    let op = |arg| match arg {
+        ArrayType::Int(ints) => ints
+            .iter()
+            .copied()
+            .max()
+            .ok_or(Error::NoMax)
+            .map(|m| Primitive(Int(m))),
+
+        ArrayType::Float(floats) => floats
+            .iter()
+            .copied()
+            .max()
+            .ok_or(Error::NoMax)
+            .map(|m| Primitive(Float(m))),
+
+        ArrayType::Bool(_) => Err(Error::BadType(name)),
+        ArrayType::Empty => Err(Error::NoMax),
+    };
+
+    unary_array_op(name, env, args, op)
+}
+
+/// Builtin `min`: smallest element of an int or float array.
+///
+/// Bool arrays are `Error::BadType`; an empty array is `Error::NoMin`.
+fn min(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "min";
+
+    let op = |arg| match arg {
+        ArrayType::Int(ints) => ints
+            .iter()
+            .copied()
+            .min()
+            .ok_or(Error::NoMin)
+            .map(|m| Primitive(Int(m))),
+
+        ArrayType::Float(floats) => floats
+            .iter()
+            .copied()
+            .min()
+            .ok_or(Error::NoMin)
+            .map(|m| Primitive(Float(m))),
+
+        ArrayType::Bool(_) => Err(Error::BadType(name)),
+        ArrayType::Empty => Err(Error::NoMin),
+    };
+
+    unary_array_op(name, env, args, op)
+}
+
+/// Builtin `avg`: arithmetic mean of an int or float array, always a float.
+///
+/// Bool arrays are `Error::BadType`; an empty array is `Error::NoAvg`. A NaN
+/// quotient is reported through the `FloatIsNan` conversion into `Error`.
+fn avg(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "avg";
+
+    let op = |arg| match arg {
+        ArrayType::Int(ints) => {
+            let count = ints.len() as i64;
+            let sum = ints.iter().copied().sum::<i64>();
+            Ok(Primitive(Float(F64::new(sum as f64 / count as f64)?)))
+        }
+
+        ArrayType::Float(floats) => {
+            let count = floats.len() as i64;
+            let sum = floats.iter().copied().sum::<F64>();
+            Ok(Primitive(Float(F64::new(sum.into_inner() / count as f64)?)))
+        }
+
+        ArrayType::Bool(_) => Err(Error::BadType(name)),
+        ArrayType::Empty => Err(Error::NoAvg),
+    };
+
+    unary_array_op(name, env, args, op)
+}
+
+/// Builtin `median`: middle element of a sorted int or float array.
+///
+/// The element at index `len / 2` is returned, so an even-length array yields
+/// the upper of its two middle values rather than their mean. Bool arrays are
+/// `Error::BadType`; an empty array is `Error::NoMedian`.
+fn median(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "median";
+
+    let op = |arg| match arg {
+        ArrayType::Int(mut ints) => {
+            ints.sort();
+            let mid = ints.len() / 2;
+            Ok(Primitive(Int(ints[mid])))
+        }
+        ArrayType::Float(mut floats) => {
+            floats.sort();
+            let mid = floats.len() / 2;
+            Ok(Primitive(Float(floats[mid])))
+        }
+        ArrayType::Bool(_) => Err(Error::BadType(name)),
+        ArrayType::Empty => Err(Error::NoMedian),
+    };
+
+    unary_array_op(name, env, args, op)
+}
+
+/// Builtin `count`: number of elements in an array of any kind, as an int.
+fn count(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "count";
+
+    let op = |arg| match arg {
+        ArrayType::Int(ints) => Ok(Primitive(Int(ints.len() as i64))),
+        ArrayType::Float(floats) => Ok(Primitive(Int(floats.len() as i64))),
+        ArrayType::Bool(bools) => Ok(Primitive(Int(bools.len() as i64))),
+        ArrayType::Empty => Ok(Primitive(Int(0))),
+    };
+
+    unary_array_op(name, env, args, op)
+}
+
+/// Builtin `all`: true when the lambda evaluates to `#t` for every element.
+///
+/// An empty array yields true. Lambda evaluation errors are propagated.
+fn all(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "all";
+
+    let op = |arr, ident: Ident, body: Box<Expr>| {
+        let result = match arr {
+            ArrayType::Int(ints) => ints
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone()))
+                .process_results(|mut iter| {
+                    iter.all(|expr| matches!(expr, Primitive(Bool(true))))
+                })?,
+            ArrayType::Float(floats) => floats
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone()))
+                .process_results(|mut iter| {
+                    iter.all(|expr| matches!(expr, Primitive(Bool(true))))
+                })?,
+            ArrayType::Bool(bools) => bools
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone()))
+                .process_results(|mut iter| {
+                    iter.all(|expr| matches!(expr, Primitive(Bool(true))))
+                })?,
+            ArrayType::Empty => true,
+        };
+
+        Ok(Primitive(Bool(result)))
+    };
+
+    higher_order_array_op(name, env, args, op)
+}
+
+/// Builtin `nall`: negation of `all` -- true when at least one element does
+/// not evaluate to `#t`.
+///
+/// An empty array yields false. Lambda evaluation errors are propagated.
+fn nall(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "nall";
+
+    let op = |arr, ident: Ident, body: Box<Expr>| {
+        let result = match arr {
+            ArrayType::Int(ints) => ints
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone()))
+                .process_results(|mut iter| {
+                    iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not()
+                })?,
+            ArrayType::Float(floats) => floats
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone()))
+                .process_results(|mut iter| {
+                    iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not()
+                })?,
+            ArrayType::Bool(bools) => bools
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone()))
+                .process_results(|mut iter| {
+                    iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not()
+                })?,
+            ArrayType::Empty => false,
+        };
+
+        Ok(Primitive(Bool(result)))
+    };
+
+    higher_order_array_op(name, env, args, op)
+}
+
+/// Builtin `some`: true when the lambda evaluates to `#t` for any element.
+///
+/// An empty array yields false. Lambda evaluation errors are propagated.
+fn some(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "some";
+
+    let op = |arr, ident: Ident, body: Box<Expr>| {
+        let result = match arr {
+            ArrayType::Int(ints) => ints
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone()))
+                .process_results(|mut iter| {
+                    iter.any(|expr| matches!(expr, Primitive(Bool(true))))
+                })?,
+            ArrayType::Float(floats) => floats
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone()))
+                .process_results(|mut iter| {
+                    iter.any(|expr| matches!(expr, Primitive(Bool(true))))
+                })?,
+            ArrayType::Bool(bools) => bools
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone()))
+                .process_results(|mut iter| {
+                    iter.any(|expr| matches!(expr, Primitive(Bool(true))))
+                })?,
+            ArrayType::Empty => false,
+        };
+
+        Ok(Primitive(Bool(result)))
+    };
+
+    higher_order_array_op(name, env, args, op)
+}
+
+/// Builtin `none`: negation of `some` -- true when no element evaluates to
+/// `#t`.
+///
+/// An empty array yields true. Lambda evaluation errors are propagated.
+fn none(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "none";
+
+    let op = |arr, ident: Ident, body: Box<Expr>| {
+        let result = match arr {
+            ArrayType::Int(ints) => ints
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone()))
+                .process_results(|mut iter| {
+                    iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not()
+                })?,
+            ArrayType::Float(floats) => floats
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone()))
+                .process_results(|mut iter| {
+                    iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not()
+                })?,
+            ArrayType::Bool(bools) => bools
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone()))
+                .process_results(|mut iter| {
+                    iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not()
+                })?,
+            ArrayType::Empty => true,
+        };
+
+        Ok(Primitive(Bool(result)))
+    };
+
+    higher_order_array_op(name, env, args, op)
+}
+
+/// Builtin `filter`: keep the elements for which the lambda evaluates to `#t`.
+///
+/// Elements whose lambda result is anything other than `#t` are dropped. An
+/// error raised while evaluating the lambda aborts the whole call and is
+/// returned to the caller instead of being treated as "not matched".
+fn filter(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "filter";
+
+    let op = |arr, ident: Ident, body: Box<Expr>| {
+        let arr = match arr {
+            ArrayType::Int(ints) => ints
+                .iter()
+                .filter_map(
+                    |val| match eval_lambda(env, &ident, Int(*val), (*body).clone()) {
+                        Ok(Primitive(Bool(true))) => Some(Ok(Primitive::Int(*val))),
+                        Ok(_) => None,
+                        // Surface the failure; `collect` stops at the first `Err`.
+                        Err(err) => Some(Err(err)),
+                    },
+                )
+                .collect::<Result<Vec<_>>>()?,
+            ArrayType::Float(floats) => floats
+                .iter()
+                .filter_map(
+                    |val| match eval_lambda(env, &ident, Float(*val), (*body).clone()) {
+                        Ok(Primitive(Bool(true))) => Some(Ok(Primitive::Float(*val))),
+                        Ok(_) => None,
+                        Err(err) => Some(Err(err)),
+                    },
+                )
+                .collect::<Result<Vec<_>>>()?,
+            ArrayType::Bool(bools) => bools
+                .iter()
+                .filter_map(
+                    |val| match eval_lambda(env, &ident, Bool(*val), (*body).clone()) {
+                        Ok(Primitive(Bool(true))) => Some(Ok(Primitive::Bool(*val))),
+                        Ok(_) => None,
+                        Err(err) => Some(Err(err)),
+                    },
+                )
+                .collect::<Result<Vec<_>>>()?,
+            ArrayType::Empty => Vec::new(),
+        };
+
+        Ok(Array(arr))
+    };
+
+    higher_order_array_op(name, env, args, op)
+}
+
+/// Builtin `foreach`: apply the lambda to every element and collect the
+/// results into a new array.
+///
+/// Each lambda result must be a primitive, otherwise `Error::BadType` is
+/// returned. Lambda evaluation errors are propagated.
+fn foreach(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "foreach";
+
+    let op = |arr, ident: Ident, body: Box<Expr>| {
+        let arr = match arr {
+            ArrayType::Int(ints) => ints
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone()))
+                .map(|expr| match expr {
+                    Ok(Primitive(inner)) => Ok(inner),
+                    Ok(_) => Err(Error::BadType(name)),
+                    Err(err) => Err(err),
+                })
+                .collect::<Result<Vec<_>>>()?,
+            ArrayType::Float(floats) => floats
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone()))
+                .map(|expr| match expr {
+                    Ok(Primitive(inner)) => Ok(inner),
+                    Ok(_) => Err(Error::BadType(name)),
+                    Err(err) => Err(err),
+                })
+                .collect::<Result<Vec<_>>>()?,
+            ArrayType::Bool(bools) => bools
+                .iter()
+                .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone()))
+                .map(|expr| match expr {
+                    Ok(Primitive(inner)) => Ok(inner),
+                    Ok(_) => Err(Error::BadType(name)),
+                    Err(err) => Err(err),
+                })
+                .collect::<Result<Vec<_>>>()?,
+            ArrayType::Empty => Vec::new(),
+        };
+
+        Ok(Array(arr))
+    };
+
+    higher_order_array_op(name, env, args, op)
+}
+
+/// Builtin `dbg`: evaluate its single argument, log `arg = result` at debug
+/// level, and return the result unchanged.
+fn dbg(env: &Env, args: &[Expr]) -> Result<Expr> {
+    let name = "dbg";
+    check_num_args(name, args, 1)?;
+    let arg = &args[0];
+    let result = eval(env, arg)?;
+    log::debug!("{arg} = {result}");
+    Ok(result)
+}
diff --git a/hipcheck/src/policy_exprs/error.rs b/hipcheck/src/policy_exprs/error.rs
new file mode 100644
index 00000000..586fb3b7
--- /dev/null
+++ b/hipcheck/src/policy_exprs/error.rs
@@ -0,0 +1,97 @@
+use crate::policy_exprs::{Expr, Ident, LexingError};
+use nom::{error::ErrorKind, Needed};
+use ordered_float::FloatIsNan;
+
+/// `Result` which uses [`Error`].
+pub type Result<T> = std::result::Result<T, Error>;
+
+/// An error arising during program execution.
+#[derive(Debug, thiserror::Error)]
+pub enum Error {
+    #[error("missing open paren")]
+    MissingOpenParen,
+
+    #[error("missing close paren")]
+    MissingCloseParen,
+
+    #[error("missing ident")]
+    MissingIdent,
+
+    #[error("wrong type in ident spot")]
+    WrongTypeInIdentSpot,
+
+    #[error("missing args")]
+    MissingArgs,
+
+    /// A lexing failure, displayed using the underlying error's message.
+    #[error(transparent)]
+    Lex(#[from] LexingError),
+
+    #[error("expression returned '{0:?}', not a boolean")]
+    DidNotReturnBool(Expr),
+
+    #[error("tried to call unknown function '{0}'")]
+    UnknownFunction(String),
+
+    #[error("ident '{0}' resolved to a variable, not a function")]
+    FoundVarExpectedFunc(String),
+
+    #[error("parsing did not consume the entire input {}", needed_str(.0))]
+    IncompleteParse(Needed),
+
+    #[error("parse failed with kind '{kind:?}', with '{remaining}' remaining")]
+    Parse { remaining: String, kind: ErrorKind },
+
+    /// A float computation produced NaN, which [`crate::policy_exprs::F64`]
+    /// cannot hold.
+    #[error(transparent)]
+    FloatIsNan(#[from] FloatIsNan),
+
+    #[error("too many args to '{name}'; expected {expected}, got {given}")]
+    TooManyArgs {
+        name: String,
+        expected: usize,
+        given: usize,
+    },
+
+    #[error("not enough args to '{name}'; expected {expected}, got {given}")]
+    NotEnoughArgs {
+        name: String,
+        expected: usize,
+        given: usize,
+    },
+
+    #[error("called '{0}' with mismatched types")]
+    BadType(&'static str),
+
+    #[error("no max value found in array")]
+    NoMax,
+
+    #[error("no min value found in array")]
+    NoMin,
+
+    #[error("no avg value found for array")]
+    NoAvg,
+
+    #[error("no median value found for array")]
+    NoMedian,
+
+    #[error("array mixing multiple primitive types")]
+    InconsistentArrayTypes,
+
+    #[error("variable '{0}' is not bound")]
+    UnboundVar(Ident),
+
+    #[error("variable '{0}' conflicts with function")]
+    VarConflictsWithFunc(Ident),
+
+    #[error("variable '{checked}' resolves to another variable '{found}'")]
+    VarResolvesToVar { checked: Ident, found: Ident },
+
+    #[error("variable is already bound")]
+    AlreadyBound,
+}
+
+/// Render the optional "how much more input was needed" suffix used by the
+/// [`Error::IncompleteParse`] message.
+///
+/// Returns an empty string when the amount is unknown.
+fn needed_str(needed: &Needed) -> String {
+    match needed {
+        Needed::Unknown => String::from(""),
+        Needed::Size(bytes) => format!(", needed {} more bytes", bytes),
+    }
+}
diff --git a/hipcheck/src/policy_exprs/expr.rs b/hipcheck/src/policy_exprs/expr.rs
new file mode 100644
index 00000000..a04c3d98
--- /dev/null
+++ b/hipcheck/src/policy_exprs/expr.rs
@@ -0,0 +1,280 @@
+use crate::policy_exprs::env::Binding;
+use crate::policy_exprs::env::Env;
+use crate::policy_exprs::token::Token;
+use crate::policy_exprs::Error;
+use crate::policy_exprs::Result;
+use crate::policy_exprs::Tokens;
+use itertools::Itertools;
+use nom::branch::alt;
+use nom::combinator::all_consuming;
+use nom::combinator::map;
+use nom::multi::many0;
+use nom::sequence::tuple;
+use nom::Finish as _;
+use nom::IResult;
+use ordered_float::NotNan;
+use std::fmt::Display;
+use std::ops::Deref;
+
+/// A `deke` expression to evaluate.
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub enum Expr {
+    /// Primitive data (ints, floats, bool).
+    Primitive(Primitive),
+
+    /// An array of primitive data.
+    Array(Vec<Primitive>),
+
+    /// Stores the name of the function, followed by the args.
+    Function(Ident, Vec<Expr>),
+
+    /// Stores the name of the input variable, followed by the lambda body.
+    Lambda(Ident, Box<Expr>),
+}
+
+/// Primitive data.
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub enum Primitive {
+    /// Identifier in a lambda, to be substituted.
+    Identifier(Ident),
+
+    /// Signed 64-bit integer.
+    Int(i64),
+
+    /// 64-bit float, not allowed to be NaN.
+    Float(F64),
+
+    /// Boolean.
+    Bool(bool),
+}
+
+/// A variable or function identifier.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Ident(pub String);
+
+/// A non-NaN 64-bit floating point number.
+pub type F64 = NotNan<f64>;
+
+/// Renders an expression back into `deke` surface syntax.
+impl Display for Expr {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Expr::Primitive(primitive) => write!(f, "{}", primitive),
+            Expr::Array(array) => {
+                write!(f, "[{}]", array.iter().map(ToString::to_string).join(" "))
+            }
+            Expr::Function(ident, args) => {
+                let args = args.iter().map(ToString::to_string).join(" ");
+                write!(f, "({} {})", ident, args)
+            }
+            // The closing paren keeps the rendered lambda balanced.
+            Expr::Lambda(arg, body) => write!(f, "(lambda ({}) {})", arg, body),
+        }
+    }
+}
+
+/// Renders a primitive; booleans are written as `#t` / `#f`.
+impl Display for Primitive {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Primitive::Identifier(ident) => write!(f, "{}", ident),
+            Primitive::Int(i) => write!(f, "{}", i),
+            Primitive::Float(fl) => write!(f, "{}", fl),
+            Primitive::Bool(b) => write!(f, "{}", if *b { "#t" } else { "#f" }),
+        }
+    }
+}
+
+impl Primitive {
+    /// Resolve this primitive to a concrete value using `env`.
+    ///
+    /// Non-identifier primitives are returned as a clone. An identifier is
+    /// looked up in `env` and must be bound to a non-identifier variable.
+    ///
+    /// # Errors
+    ///
+    /// - `Error::VarResolvesToVar` if the binding is itself an identifier.
+    /// - `Error::VarConflictsWithFunc` if the name is bound to a function.
+    /// - `Error::UnboundVar` if the name has no binding.
+    pub fn resolve(&self, env: &Env) -> Result<Primitive> {
+        match self {
+            Primitive::Identifier(ident) => match env.get(ident) {
+                Some(Binding::Var(Primitive::Identifier(found))) => Err(Error::VarResolvesToVar {
+                    checked: ident.clone(),
+                    found,
+                }),
+                Some(Binding::Var(var)) => Ok(var),
+                Some(Binding::Fn(_)) => Err(Error::VarConflictsWithFunc(ident.clone())),
+                None => Err(Error::UnboundVar(ident.clone())),
+            },
+            _ => Ok(self.clone()),
+        }
+    }
+}
+
+impl Display for Ident {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+/// Lets an `Ident` be used wherever a `&str` is expected.
+impl Deref for Ident {
+    type Target = str;
+
+    fn deref(&self) -> &Self::Target {
+        self.0.deref()
+    }
+}
+
+crate::token_parser!(token: Token);
+
+crate::data_variant_parser! {
+    fn parse_integer(input) -> Result<Primitive>;
+    pattern = Token::Integer(n) => Primitive::Int(n);
+}
+
+crate::data_variant_parser! {
+    fn parse_float(input) -> Result<Primitive>;
+    pattern = Token::Float(f) => Primitive::Float(f);
+}
+
+crate::data_variant_parser! {
+    fn parse_bool(input) -> Result<Primitive>;
+    pattern = Token::Bool(b) => Primitive::Bool(b);
+}
+
+crate::data_variant_parser! {
+    fn parse_ident(input) -> Result<String>;
+    pattern = Token::Ident(s) => s.to_owned();
+}
+
+// Helper type for token parsing.
+pub type Input<'source> = Tokens<'source, Token>;
+
+/// Parse a single piece of primitive data.
+fn parse_primitive(input: Input<'_>) -> IResult<Input<'_>, Primitive> {
+    alt((parse_integer, parse_float, parse_bool))(input)
+}
+
+/// Parse an array.
+fn parse_array(input: Input<'_>) -> IResult<Input<'_>, Expr> {
+    let parser = tuple((Token::OpenBrace, many0(parse_primitive), Token::CloseBrace));
+    let mut parser = map(parser, |(_, inner, _)| Expr::Array(inner));
+    parser(input)
+}
+
+/// Parse an expression.
+fn parse_expr(input: Input<'_>) -> IResult<Input<'_>, Expr> {
+    let primitive = map(parse_primitive, Expr::Primitive);
+    alt((primitive, parse_array, parse_function))(input)
+}
+
+/// Parse a function call.
+fn parse_function(input: Input<'_>) -> IResult<Input<'_>, Expr> {
+    let parser = tuple((
+        Token::OpenParen,
+        parse_ident,
+        many0(parse_expr),
+        Token::CloseParen,
+    ));
+    let mut parser = map(parser, |(_, ident, args, _)| {
+        Expr::Function(Ident(ident), args)
+    });
+    parser(input)
+}
+
+/// Parse a complete program, which must be a single function call.
+///
+/// # Errors
+///
+/// Returns `Error::Parse` with the unconsumed lexer slice and the `nom` error
+/// kind when parsing fails, or `Error::IncompleteParse` if input is left over
+/// after a successful parse.
+pub fn parse(input: &str) -> Result<Expr> {
+    let tokens = Tokens::new(input);
+    let mut parser = all_consuming(parse_function);
+
+    match parser(tokens).finish() {
+        Ok((rest, expr)) if rest.is_empty() => Ok(expr),
+        Ok(_) => Err(Error::IncompleteParse(nom::Needed::Unknown)),
+        Err(err) => {
+            let remaining = err.input.lexer().slice().to_string();
+            let kind = err.code;
+            Err(Error::Parse { remaining, kind })
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use test_log::test;
+
+    /// Lets the `func` helper accept either primitives or full expressions.
+    trait IntoExpr {
+        fn into_expr(self) -> Expr;
+    }
+
+    impl IntoExpr for Expr {
+        fn into_expr(self) -> Expr {
+            self
+        }
+    }
+
+    impl IntoExpr for Primitive {
+        fn into_expr(self) -> Expr {
+            Expr::Primitive(self)
+        }
+    }
+
+    fn func<I: IntoExpr>(name: &str, args: Vec<I>) -> Expr {
+        let args = args.into_iter().map(|arg| arg.into_expr()).collect();
+        Expr::Function(Ident(String::from(name)), args)
+    }
+
+    fn int(val: i64) -> Primitive {
+        Primitive::Int(val)
+    }
+
+    fn float(val: f64) -> Primitive {
+        Primitive::Float(F64::new(val).unwrap())
+    }
+
+    #[allow(unused)]
+    fn boolean(val: bool) -> Primitive {
+        Primitive::Bool(val)
+    }
+
+    fn array(vals: Vec<Primitive>) -> Expr {
+        Expr::Array(vals)
+    }
+
+    #[test]
+    fn parse_function() {
+        let input = "(add 2 3)";
+        let expected = func("add", vec![int(2), int(3)]);
+        let result = parse(input).unwrap();
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    fn parse_nested_function() {
+        let input = "(add (add 1 2) 3)";
+        let expected = func(
+            "add",
+            vec![func("add", vec![int(1), int(2)]), int(3).into_expr()],
+        );
+        let result = parse(input).unwrap();
+        assert_eq!(result, expected);
+    }
+
+    #[test]
+    fn parse_array() {
+        let input = "(eq 0 (count (filter (gt 8.0) [1.0 2.0 10.0 20.0 30.0])))";
+
+        let expected = func(
+            "eq",
+            vec![
+                int(0).into_expr(),
+                func(
+                    "count",
+                    vec![func(
+                        "filter",
+                        vec![
+                            func("gt", vec![float(8.0)]),
+                            array(vec![
+                                float(1.0),
+                                float(2.0),
+                                float(10.0),
+                                float(20.0),
+                                float(30.0),
+                            ]),
+                        ],
+                    )],
+                ),
+            ],
+        );
+
+        let result = parse(input).unwrap();
+        assert_eq!(result, expected);
+    }
+}
diff --git a/hipcheck/src/policy_exprs/mod.rs b/hipcheck/src/policy_exprs/mod.rs
new file mode 100644
index 00000000..57c6b91a
--- /dev/null
+++ b/hipcheck/src/policy_exprs/mod.rs
@@ -0,0 +1,141 @@
+#![allow(unused)]
+
+mod bridge;
+mod env;
+mod error;
+mod expr;
+mod token;
+
+pub(crate) use crate::policy_exprs::bridge::Tokens;
+use crate::policy_exprs::env::Env;
+pub use crate::policy_exprs::error::Error;
+pub use crate::policy_exprs::error::Result;
+pub use crate::policy_exprs::expr::Expr;
+pub use crate::policy_exprs::expr::Ident;
+pub(crate) use crate::policy_exprs::expr::F64;
+pub use crate::policy_exprs::token::LexingError;
+use env::Binding;
+use expr::parse;
+pub use expr::Primitive;
+use std::ops::Deref;
+
+/// Evaluates `deke` expressions.
+pub struct Executor {
+    env: Env<'static>,
+}
+
+impl Executor {
+    /// Create an `Executor` with the standard set of functions defined.
+    pub fn std() -> Self {
+        Executor { env: Env::std() }
+    }
+
+    /// Run a `deke` program.
+    ///
+    /// # Errors
+    ///
+    /// Returns `Error::DidNotReturnBool` if the program evaluates to anything
+    /// other than a boolean primitive, and forwards any parse or evaluation
+    /// error.
+    pub fn run(&self, raw_program: &str) -> Result<bool> {
+        match self.parse_and_eval(raw_program)? {
+            Expr::Primitive(Primitive::Bool(b)) => Ok(b),
+            result => Err(Error::DidNotReturnBool(result)),
+        }
+    }
+
+    /// Run a `deke` program, but don't try to convert the result to a `bool`.
+    pub fn parse_and_eval(&self, raw_program: &str) -> Result<Expr> {
+        let program = parse(raw_program)?;
+        let expr = eval(&self.env, &program)?;
+        Ok(expr)
+    }
+}
+
+/// Evaluate the `Expr` in `env`, returning the resulting `Expr`.
+///
+/// - Primitives are resolved against `env`.
+/// - Arrays are returned as a clone, unchanged.
+/// - Function calls look the name up in `env` and invoke the bound function
+///   with the unevaluated argument expressions.
+/// - Lambdas return a clone of their body without evaluating it.
+///
+/// # Errors
+///
+/// Returns `Error::UnknownFunction` if the function name has no binding and
+/// `Error::FoundVarExpectedFunc` if it is bound to a variable.
+pub(crate) fn eval(env: &Env, program: &Expr) -> Result<Expr> {
+    let output = match program {
+        Expr::Primitive(primitive) => Ok(Expr::Primitive(primitive.resolve(env)?)),
+        Expr::Array(_) => Ok(program.clone()),
+        Expr::Function(name, args) => {
+            let binding = env
+                .get(name)
+                .ok_or_else(|| Error::UnknownFunction(name.deref().to_owned()))?;
+
+            if let Binding::Fn(op) = binding {
+                op(env, args)
+            } else {
+                Err(Error::FoundVarExpectedFunc(name.deref().to_owned()))
+            }
+        }
+        Expr::Lambda(_, body) => Ok((**body).clone()),
+    };
+
+    log::debug!("input: {program:?}, output: {output:?}");
+
+    output
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use test_log::test;
+
+    #[test]
+    fn run_basic() {
+        let program = "(eq (add 1 2) 3)";
+        let is_true = Executor::std().run(program).unwrap();
+        assert!(is_true);
+    }
+
+    #[test]
+    fn eval_basic() {
+        let program = "(add 1 2)";
+        let result = Executor::std().parse_and_eval(program).unwrap();
+        assert_eq!(result, Expr::Primitive(Primitive::Int(3)));
+    }
+
+    #[test]
+    fn eval_bools() {
+        let program = "(neq 1 2)";
+        let result = Executor::std().parse_and_eval(program).unwrap();
+        assert_eq!(result, Expr::Primitive(Primitive::Bool(true)));
+    }
+
+    #[test]
+    fn eval_array() {
+        let program = "(max [1 4 6 10 2 3 0])";
+        let result = Executor::std().parse_and_eval(program).unwrap();
+        assert_eq!(result, Expr::Primitive(Primitive::Int(10)));
+    }
+
+    #[test]
+    fn run_array() {
+        let program = "(eq 7 (count [1 4 6 10 2 3 0]))";
+        let is_true = Executor::std().run(program).unwrap();
+        assert!(is_true);
+    }
+
+    #[test]
+    fn eval_higher_order_func() {
+        let program = "(eq 3 (count (filter (gt 8.0) [1.0 2.0 10.0 20.0 30.0])))";
+        let result = Executor::std().parse_and_eval(program).unwrap();
+        assert_eq!(result, Expr::Primitive(Primitive::Bool(true)));
+    }
+
+    #[test]
+    fn eval_foreach() {
+        let program =
+            "(eq 3 (count (filter (gt 8.0) (foreach (sub 1.0) [1.0 2.0 10.0 20.0 30.0]))))";
+        let result = Executor::std().parse_and_eval(program).unwrap();
+        assert_eq!(result, Expr::Primitive(Primitive::Bool(true)));
+    }
+
+    #[test]
+    fn eval_basic_filter() {
+        let program = "(filter (eq 0) [1 0 1 0 0 1 2])";
+        let result = Executor::std().parse_and_eval(program).unwrap();
+        assert_eq!(
+            result,
+            Expr::Array(vec![
+                Primitive::Int(0),
+                Primitive::Int(0),
+                Primitive::Int(0)
+            ])
+        );
+    }
+}
diff --git a/hipcheck/src/policy_exprs/token.rs b/hipcheck/src/policy_exprs/token.rs
new file mode 100644
index 00000000..0d8b60ca
--- /dev/null
+++ b/hipcheck/src/policy_exprs/token.rs
@@ -0,0 +1,164 @@
+use crate::policy_exprs::F64;
+use logos::Lexer;
+use logos::Logos;
+use ordered_float::FloatIsNan;
+use std::fmt::Display;
+use std::num::ParseFloatError;
+use std::num::ParseIntError;
+
+/// Lexer-local `Result` which uses [`LexingError`].
+type Result<T> = std::result::Result<T, LexingError>;
+
+/// A single lexical token of a `deke` program.
+///
+/// Runs of spaces, tabs, newlines and form feeds between tokens are skipped.
+#[derive(Logos, Clone, Debug, PartialEq)]
+#[logos(skip r"[ \t\n\f]+", error = LexingError)]
+pub enum Token {
+    #[token("(")]
+    OpenParen,
+
+    #[token(")")]
+    CloseParen,
+
+    #[token("[")]
+    OpenBrace,
+
+    #[token("]")]
+    CloseBrace,
+
+    #[regex(r"\#[tf]", lex_bool)]
+    Bool(bool),
+
+    #[regex(r"-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?", lex_float)]
+    Float(F64),
+
+    // NOTE(review): every part of this pattern is optional, so it can match
+    // the empty string, and it has no sign, so a negative integer such as
+    // `-1` presumably lexes as `Float` via the pattern above -- confirm intent.
+    #[regex(r"([1-9]?[0-9]*)", lex_integer, priority = 20)]
+    Integer(i64),
+
+    #[regex("([a-zA-Z]+)", lex_ident)]
+    Ident(String),
+}
+
+/// Lex a single boolean.
+fn lex_bool(input: &mut Lexer<'_, Token>) -> Result<bool> {
+    match input.slice() {
+        "#t" => Ok(true),
+        "#f" => Ok(false),
+        value => Err(LexingError::InvalidBool(String::from(value))),
+    }
+}
+
+/// Lex a single integer.
+fn lex_integer(input: &mut Lexer<'_, Token>) -> Result<i64> {
+    let s = input.slice();
+    let i = s
+        .parse::<i64>()
+        .map_err(|err| LexingError::InvalidInteger(s.to_string(), err))?;
+    Ok(i)
+}
+
+/// Lex a single float.
+///
+/// Fails with `LexingError::FloatIsNan` if the parsed value is NaN.
+fn lex_float(input: &mut Lexer<'_, Token>) -> Result<F64> {
+    let s = input.slice();
+    let f = s
+        .parse::<f64>()
+        .map_err(|err| LexingError::InvalidFloat(s.to_string(), err))?;
+    Ok(F64::new(f)?)
+}
+
+/// Lex a single identifier.
+fn lex_ident(input: &mut Lexer<'_, Token>) -> Result<String> {
+    Ok(input.slice().to_owned())
+}
+
+/// Renders a token back into its source form.
+impl Display for Token {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Token::OpenParen => write!(f, "("),
+            Token::CloseParen => write!(f, ")"),
+            Token::OpenBrace => write!(f, "["),
+            Token::CloseBrace => write!(f, "]"),
+            Token::Bool(true) => write!(f, "#t"),
+            Token::Bool(false) => write!(f, "#f"),
+            Token::Integer(i) => write!(f, "{i}"),
+            Token::Float(fl) => write!(f, "{fl}"),
+            Token::Ident(i) => write!(f, "{i}"),
+        }
+    }
+}
+
+/// Error arising during lexing.
+#[derive(Default, Debug, Clone, PartialEq, thiserror::Error)]
+pub enum LexingError {
+    /// Fallback used when the lexer fails without a more specific cause.
+    #[error("an unknown lexing error occurred")]
+    #[default]
+    UnknownError,
+
+    /// Holds the offending text and the underlying integer parse error.
+    #[error("failed to parse integer")]
+    InvalidInteger(String, ParseIntError),
+
+    /// Holds the offending text and the underlying float parse error.
+    #[error("failed to parse float")]
+    InvalidFloat(String, ParseFloatError),
+
+    #[error("float is not a number")]
+    FloatIsNan(#[from] FloatIsNan),
+
+    #[error("invalid boolean, found '{0}'")]
+    InvalidBool(String),
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::policy_exprs::token::Token;
+    use crate::policy_exprs::Result;
+    use crate::policy_exprs::F64;
+    use logos::Logos as _;
+    use test_log::test;
+
+    // Helper function for running the lexer to get all tokens.
+    fn lex(input: &str) -> Result<Vec<Token>> {
+        let tokens = Token::lexer(input)
+            .map(|res| res.map_err(Into::into))
+            .collect::<Result<Vec<_>>>()?;
+        Ok(tokens)
+    }
+
+    #[test]
+    fn basic_lexing() {
+        let raw_program = "(add 1 2)";
+        let expected = vec![
+            Token::OpenParen,
+            Token::Ident(String::from("add")),
+            Token::Integer(1),
+            Token::Integer(2),
+            Token::CloseParen,
+        ];
+        let tokens = lex(raw_program).unwrap();
+        assert_eq!(tokens, expected);
+    }
+
+    #[test]
+    fn basic_lexing_with_floats() {
+        let raw_program = "(add 1.0 2.0)";
+        let expected = vec![
+            Token::OpenParen,
+            Token::Ident(String::from("add")),
+            Token::Float(F64::new(1.0).unwrap()),
+            Token::Float(F64::new(2.0).unwrap()),
+            Token::CloseParen,
+        ];
+        let tokens = lex(raw_program).unwrap();
+        assert_eq!(tokens, expected);
+    }
+
+    #[test]
+    fn basic_lexing_with_bools() {
+        let raw_program = "(eq #t #f)";
+        let expected = vec![
+            Token::OpenParen,
+            Token::Ident(String::from("eq")),
+            Token::Bool(true),
+            Token::Bool(false),
+            Token::CloseParen,
+        ];
+        let tokens = lex(raw_program).unwrap();
+        assert_eq!(tokens, expected);
+    }
+}