Skip to content

Commit

Permalink
feat: add --arch flag to hc check subcommand to allow architecture …
Browse files Browse the repository at this point in the history
…detection override

We had previously encountered problems with cross-compilation to
niche-but-valid architectures, where the compile-time target detection
would fail, causing `CURRENT_ARCH` to be `()` instead of a valid arch
enum variant. We changed this field to be type `Option<SupportedArch>`
and also support a user-provided `--arch` flag to override the detected
arch at runtime if the user has good reason to do so.
  • Loading branch information
j-lanson authored and alilleybrinker committed Sep 20, 2024
1 parent 36a18a6 commit 5817dd4
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 32 deletions.
4 changes: 4 additions & 0 deletions hipcheck/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
error::Context,
error::Result,
hc_error,
plugin::SupportedArch,
session::pm,
shell::{color_choice::ColorChoice, verbosity::Verbosity},
source,
Expand Down Expand Up @@ -416,6 +417,9 @@ pub struct CheckArgs {
#[clap(subcommand)]
command: Option<CheckCommand>,

#[arg(long = "arch")]
pub arch: Option<SupportedArch>,

#[arg(short = 't', long = "target")]
pub target_type: Option<TargetType>,
#[arg(
Expand Down
10 changes: 6 additions & 4 deletions hipcheck/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use crate::{
cache::plugin::HcPluginCache,
hc_error,
plugin::{
get_plugin_key, retrieve_plugins, Plugin, PluginManifest, PluginResponse, QueryResult,
CURRENT_ARCH,
get_plugin_key, retrieve_plugins, try_get_current_arch, Plugin, PluginManifest,
PluginResponse, QueryResult,
},
policy::PolicyFile,
util::fs::{find_file_by_name, read_string},
Expand Down Expand Up @@ -225,6 +225,8 @@ pub fn start_plugins(
/* jitter_percent */ 10,
)?;

let current_arch = try_get_current_arch()?;

// retrieve, verify and extract all required plugins
let required_plugin_names = retrieve_plugins(&policy_file.plugins.0, plugin_cache)?;

Expand All @@ -241,11 +243,11 @@ pub fn start_plugins(
let contents = read_string(&plugin_kdl)?;
let plugin_manifest = PluginManifest::from_str(contents.as_str())?;
let entrypoint = plugin_manifest
.get_entrypoint(CURRENT_ARCH)
.get_entrypoint(current_arch)
.ok_or_else(|| {
hc_error!(
"Could not find {} entrypoint for {}/{} {}",
CURRENT_ARCH,
current_arch,
plugin_id.publisher.0,
plugin_id.name.0,
plugin_id.version.0
Expand Down
9 changes: 8 additions & 1 deletion hipcheck/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::{
cli::Format,
config::WeightTreeProvider,
error::{Context as _, Error, Result},
plugin::{Plugin, PluginExecutor, PluginWithConfig},
plugin::{try_set_arch, Plugin, PluginExecutor, PluginWithConfig},
report::report_builder::{build_report, Report},
session::Session,
setup::{resolve_and_transform_source, SourceType},
Expand Down Expand Up @@ -118,6 +118,13 @@ fn main() -> ExitCode {

/// Run the `check` command.
fn cmd_check(args: &CheckArgs, config: &CliConfig) -> ExitCode {
// Before we do any analysis, set the user-provided arch
if let Some(arch) = args.arch {
if let Err(e) = try_set_arch(arch) {
Shell::print_error(&e, Format::Human);
return ExitCode::FAILURE;
}
}
let target = match args.to_target_seed() {
Ok(target) => target,
Err(e) => {
Expand Down
6 changes: 4 additions & 2 deletions hipcheck/src/plugin/download_manifest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
plugin::{
retrieval::{download_plugin, extract_plugin},
supported_arch::SupportedArch,
CURRENT_ARCH,
try_get_current_arch,
},
util::kdl::{extract_data, ParseKdlNode},
util::{
Expand Down Expand Up @@ -288,6 +288,8 @@ impl DownloadManifestEntry {
version: &PluginVersion,
downloaded_plugins: &'a mut HashSet<PluginId>,
) -> Result<&'a HashSet<PluginId>, Error> {
let current_arch = try_get_current_arch()?;

let plugin_id = PluginId::new(publisher.clone(), name.clone(), version.clone());

if downloaded_plugins.contains(&plugin_id) {
Expand Down Expand Up @@ -323,7 +325,7 @@ impl DownloadManifestEntry {
publisher.0,
name.0,
version.0,
CURRENT_ARCH
current_arch,
)
})?;

Expand Down
2 changes: 1 addition & 1 deletion hipcheck/src/plugin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub use plugin_manifest::{PluginManifest, PluginName, PluginPublisher, PluginVer
pub use retrieval::retrieve_plugins;
use serde_json::Value;
use std::collections::HashMap;
pub use supported_arch::CURRENT_ARCH;
pub use supported_arch::{try_get_current_arch, try_set_arch, SupportedArch};
use tokio::sync::Mutex;

pub async fn initialize_plugins(
Expand Down
84 changes: 60 additions & 24 deletions hipcheck/src/plugin/supported_arch.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
// SPDX-License-Identifier: Apache-2.0

use crate::error::Result;
use crate::hc_error;
use std::{fmt::Display, str::FromStr};
use clap::ValueEnum;
use std::{fmt::Display, result::Result as StdResult, str::FromStr, sync::OnceLock};

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, ValueEnum)]
/// Officially supported target triples, as of RFD #0004
///
/// NOTE: these architectures correspond to the offically supported Rust platforms
Expand All @@ -18,36 +20,70 @@ pub enum SupportedArch {
X86_64UnknownLinuxGnu,
}

/// Architecture `hc` was built for
pub const CURRENT_ARCH: SupportedArch = {
#[cfg(target_arch = "x86_64")]
{
#[cfg(target_os = "macos")]
{
SupportedArch::X86_64AppleDarwin
pub const DETECTED_ARCH: Option<SupportedArch> = {
if cfg!(target_arch = "x86_64") {
if cfg!(target_os = "macos") {
Some(SupportedArch::X86_64AppleDarwin)
} else if cfg!(target_os = "linux") {
Some(SupportedArch::X86_64UnknownLinuxGnu)
} else if cfg!(target_os = "windows") {
Some(SupportedArch::X86_64PcWindowsMsvc)
} else {
None
}
#[cfg(target_os = "linux")]
{
SupportedArch::X86_64UnknownLinuxGnu
}
#[cfg(target_os = "windows")]
{
SupportedArch::X86_64PcWindowsMsvc
}
}
#[cfg(target_arch = "aarch64")]
{
#[cfg(target_os = "macos")]
{
SupportedArch::Aarch64AppleDarwin
} else if cfg!(target_arch = "aarch64") {
if cfg!(target_os = "macos") {
Some(SupportedArch::Aarch64AppleDarwin)
} else {
None
}
} else {
None
}
};

pub static USER_PROVIDED_ARCH: OnceLock<SupportedArch> = OnceLock::new();

/// Get the target architecture for plugins. If the user provided a target,
/// return that. Otherwise, if the `hc` binary was compiled for a supported
/// architecture, return that. Otherwise return None.
pub fn get_current_arch() -> Option<SupportedArch> {
if let Some(arch) = USER_PROVIDED_ARCH.get() {
Some(*arch)
} else if DETECTED_ARCH.is_some() {
DETECTED_ARCH
} else {
None
}
}

/// Like `get_current_arch()`, but returns an error message suggesting the
/// user specifies a target on the CLI
pub fn try_get_current_arch() -> Result<SupportedArch> {
if let Some(arch) = get_current_arch() {
Ok(arch)
} else {
Err(hc_error!("Could not resolve the current machine to one of the Hipcheck supported architectures. Please specify --arch on the commandline."))
}
}

pub fn try_set_arch(arch: SupportedArch) -> Result<()> {
let set_arch = USER_PROVIDED_ARCH.get_or_init(|| arch);
if *set_arch == arch {
Ok(())
} else {
Err(hc_error!(
"Architecture could not be set to {}, has already been set to {}",
arch,
set_arch
))
}
}

impl FromStr for SupportedArch {
type Err = crate::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
fn from_str(s: &str) -> StdResult<Self, Self::Err> {
match s {
"aarch64-apple-darwin" => Ok(Self::Aarch64AppleDarwin),
"x86_64-apple-darwin" => Ok(Self::X86_64AppleDarwin),
Expand Down

0 comments on commit 5817dd4

Please sign in to comment.