Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔒🌯 Provide LLVMTypeWrapper trait, add more safe wrappers #223

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 41 additions & 28 deletions src/linker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@ use std::{
use ar::Archive;
use llvm_sys::{
bit_writer::LLVMWriteBitcodeToFile,
core::{
LLVMContextCreate, LLVMContextDispose, LLVMContextSetDiagnosticHandler, LLVMDisposeModule,
LLVMGetTarget,
},
core::{LLVMContextCreate, LLVMContextDispose, LLVMContextSetDiagnosticHandler, LLVMGetTarget},
error_handling::{LLVMEnablePrettyStackTrace, LLVMInstallFatalErrorHandler},
prelude::{LLVMContextRef, LLVMModuleRef},
prelude::LLVMContextRef,
target_machine::{LLVMCodeGenFileType, LLVMDisposeTargetMachine, LLVMTargetMachineRef},
};
use thiserror::Error;
use tracing::{debug, error, info, warn};

use crate::llvm;
use crate::llvm::{self, LLVMTypeError, LLVMTypeWrapper, Module};

/// Linker error
#[derive(Debug, Error)]
Expand Down Expand Up @@ -77,6 +74,10 @@ pub enum LinkerError {
/// The input object file does not have embedded bitcode.
#[error("no bitcode section found in {0}")]
MissingBitcodeSection(PathBuf),

/// Instantiation of an LLVM type failed.
#[error(transparent)]
LLVMType(#[from] LLVMTypeError),
}

/// BPF Cpu type
Expand Down Expand Up @@ -224,21 +225,26 @@ pub struct LinkerOptions {
}

/// BPF Linker
pub struct Linker {
pub struct Linker<'ctx> {
options: LinkerOptions,
context: LLVMContextRef,
module: LLVMModuleRef,
module: Module<'ctx>,
target_machine: LLVMTargetMachineRef,
diagnostic_handler: DiagnosticHandler,
}

impl Linker {
impl Linker<'_> {
/// Create a new linker instance with the given options.
pub fn new(options: LinkerOptions) -> Self {
let context = unsafe { LLVMContextCreate() };
let module = Module::new(
options.output.file_stem().unwrap().to_str().unwrap(),
context,
);
Linker {
options,
context: ptr::null_mut(),
module: ptr::null_mut(),
context,
module,
target_machine: ptr::null_mut(),
diagnostic_handler: DiagnosticHandler::new(),
}
Expand Down Expand Up @@ -365,7 +371,7 @@ impl Linker {
Archive => panic!("nested archives not supported duh"),
};

if unsafe { !llvm::link_bitcode_buffer(self.context, self.module, &bitcode) } {
if unsafe { !llvm::link_bitcode_buffer(self.context, self.module.as_ptr(), &bitcode) } {
return Err(LinkerError::LinkModuleError(path.to_owned()));
}

Expand Down Expand Up @@ -407,11 +413,11 @@ impl Linker {
})
}
None => {
let c_triple = unsafe { LLVMGetTarget(*module) };
let c_triple = unsafe { LLVMGetTarget(module.as_ptr()) };
let triple = unsafe { CStr::from_ptr(c_triple) }.to_str().unwrap();
if triple.starts_with("bpf") {
// case 2
(triple, unsafe { llvm::target_from_module(*module) })
(triple, unsafe { llvm::target_from_module(module.as_ptr()) })
} else {
// case 3.
info!("detected non-bpf input target {} and no explicit output --target specified, selecting `bpf'", triple);
Expand Down Expand Up @@ -452,17 +458,18 @@ impl Linker {

if self.options.btf {
// if we want to emit BTF, we need to sanitize the debug information
llvm::DISanitizer::new(self.context, self.module).run(&self.options.export_symbols);
llvm::DISanitizer::new(self.context, &self.module)
.run(&mut self.module, &self.options.export_symbols)?;
} else {
// if we don't need BTF emission, we can strip DI
let ok = unsafe { llvm::strip_debug_info(self.module) };
let ok = unsafe { llvm::strip_debug_info(self.module.as_ptr()) };
debug!("Stripping DI, changed={}", ok);
}

unsafe {
llvm::optimize(
self.target_machine,
self.module,
&mut self.module,
self.options.optimize,
self.options.ignore_inline_never,
&self.options.export_symbols,
Expand All @@ -486,7 +493,7 @@ impl Linker {
fn write_bitcode(&mut self, output: &CStr) -> Result<(), LinkerError> {
info!("writing bitcode to {:?}", output);

if unsafe { LLVMWriteBitcodeToFile(self.module, output.as_ptr()) } == 1 {
if unsafe { LLVMWriteBitcodeToFile(self.module.as_ptr(), output.as_ptr()) } == 1 {
return Err(LinkerError::WriteBitcodeError);
}

Expand All @@ -496,14 +503,21 @@ impl Linker {
fn write_ir(&mut self, output: &CStr) -> Result<(), LinkerError> {
info!("writing IR to {:?}", output);

unsafe { llvm::write_ir(self.module, output) }.map_err(LinkerError::WriteIRError)
unsafe { llvm::write_ir(self.module.as_ptr(), output) }.map_err(LinkerError::WriteIRError)
}

fn emit(&mut self, output: &CStr, output_type: LLVMCodeGenFileType) -> Result<(), LinkerError> {
info!("emitting {:?} to {:?}", output_type, output);

unsafe { llvm::codegen(self.target_machine, self.module, output, output_type) }
.map_err(LinkerError::EmitCodeError)
unsafe {
llvm::codegen(
self.target_machine,
self.module.as_ptr(),
output,
output_type,
)
}
.map_err(LinkerError::EmitCodeError)
}

fn llvm_init(&mut self) {
Expand Down Expand Up @@ -542,24 +556,23 @@ impl Linker {
);
LLVMInstallFatalErrorHandler(Some(llvm::fatal_error));
LLVMEnablePrettyStackTrace();
self.module = llvm::create_module(
self.module = Module::new(
self.options.output.file_stem().unwrap().to_str().unwrap(),
self.context,
)
.unwrap();
);
}
}
}

impl Drop for Linker {
impl Drop for Linker<'_> {
fn drop(&mut self) {
unsafe {
if !self.target_machine.is_null() {
LLVMDisposeTargetMachine(self.target_machine);
}
if !self.module.is_null() {
LLVMDisposeModule(self.module);
}
// if !self.module.is_null() {
// LLVMDisposeModule(self.module);
// }
if !self.context.is_null() {
LLVMContextDispose(self.context);
}
Expand Down
95 changes: 52 additions & 43 deletions src/llvm/di.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,24 @@ use std::{
collections::{hash_map::DefaultHasher, HashMap, HashSet},
ffi::c_char,
hash::Hasher,
ptr,
ptr::{self, NonNull},
};

use gimli::{DW_TAG_pointer_type, DW_TAG_structure_type, DW_TAG_variant_part};
use llvm_sys::{core::*, debuginfo::*, prelude::*};
use tracing::{span, trace, warn, Level};

use super::types::{
di::DIType,
ir::{Function, MDNode, Metadata, Value},
use crate::llvm::{
iter::*,
types::{
di::{DISubprogram, DIType},
ir::{
Argument, Function, GlobalAlias, GlobalVariable, Instruction, MDNode, Metadata, Module,
Value,
},
LLVMTypeError, LLVMTypeWrapper,
},
};
use crate::llvm::{iter::*, types::di::DISubprogram};

// KSYM_NAME_LEN from the Linux kernel, intentionally set
// to the lower value found across kernel versions to ensure
Expand All @@ -23,7 +29,6 @@ const MAX_KSYM_NAME_LEN: usize = 128;

pub struct DISanitizer {
context: LLVMContextRef,
module: LLVMModuleRef,
builder: LLVMDIBuilderRef,
visited_nodes: HashSet<u64>,
replace_operands: HashMap<u64, LLVMMetadataRef>,
Expand Down Expand Up @@ -59,11 +64,11 @@ fn sanitize_type_name<T: AsRef<str>>(name: T) -> String {
}

impl DISanitizer {
pub fn new(context: LLVMContextRef, module: LLVMModuleRef) -> DISanitizer {
pub fn new(context: LLVMContextRef, module: &Module<'_>) -> Self {
let builder = unsafe { LLVMCreateDIBuilder(module.as_ptr()) };
DISanitizer {
context,
module,
builder: unsafe { LLVMCreateDIBuilder(module) },
builder,
visited_nodes: HashSet::new(),
replace_operands: HashMap::new(),
skipped_types: Vec::new(),
Expand Down Expand Up @@ -227,7 +232,7 @@ impl DISanitizer {
// An operand with no value is valid and means that the operand is
// not set
(v, Item::Operand { .. }) if v.is_null() => return,
(v, _) if !v.is_null() => Value::new(v),
(v, _) if !v.is_null() => Value::from_ptr(NonNull::new(v).unwrap()).unwrap(),
// All other items should have values
(_, item) => panic!("{item:?} has no value"),
};
Expand Down Expand Up @@ -283,16 +288,18 @@ impl DISanitizer {
}
}

pub fn run(mut self, exported_symbols: &HashSet<Cow<'static, str>>) {
let module = self.module;
pub fn run(
mut self,
module: &mut Module<'_>,
exported_symbols: &HashSet<Cow<'static, str>>,
) -> Result<(), LLVMTypeError> {
self.replace_operands = self.fix_subprogram_linkage(module, exported_symbols)?;

self.replace_operands = self.fix_subprogram_linkage(exported_symbols);

for value in module.globals_iter() {
self.visit_item(Item::GlobalVariable(value));
for global in module.globals_iter() {
self.visit_item(Item::GlobalVariable(global));
}
for value in module.global_aliases_iter() {
self.visit_item(Item::GlobalAlias(value));
for alias in module.global_aliases_iter() {
self.visit_item(Item::GlobalAlias(alias));
}

for function in module.functions_iter() {
Expand All @@ -307,6 +314,8 @@ impl DISanitizer {
}

unsafe { LLVMDisposeDIBuilder(self.builder) };

Ok(())
}

// Make it so that only exported symbols (programs marked as #[no_mangle]) get BTF
Expand All @@ -324,16 +333,13 @@ impl DISanitizer {
// See tests/btf/assembly/exported-symbols.rs .
fn fix_subprogram_linkage(
&mut self,
module: &mut Module<'_>,
export_symbols: &HashSet<Cow<'static, str>>,
) -> HashMap<u64, LLVMMetadataRef> {
) -> Result<HashMap<u64, LLVMMetadataRef>, LLVMTypeError> {
let mut replace = HashMap::new();

for mut function in self
.module
.functions_iter()
.map(|value| unsafe { Function::from_value_ref(value) })
{
if export_symbols.contains(function.name()) {
for mut function in module.functions_iter() {
if export_symbols.contains(&function.name()) {
continue;
}

Expand Down Expand Up @@ -370,7 +376,10 @@ impl DISanitizer {
// replace retained nodes manually below.
LLVMDIBuilderFinalizeSubprogram(self.builder, new_program);

DISubprogram::from_value_ref(LLVMMetadataAsValue(self.context, new_program))
let new_program = LLVMMetadataAsValue(self.context, new_program);
let new_program =
NonNull::new(new_program).expect("new program should not be null");
DISubprogram::from_ptr(new_program)?
};

// Point the function to the new subprogram.
Expand All @@ -396,23 +405,23 @@ impl DISanitizer {
unsafe { LLVMMDNodeInContext2(self.context, core::ptr::null_mut(), 0) };
subprogram.set_retained_nodes(empty_node);

let ret = replace.insert(subprogram.value_ref as u64, unsafe {
LLVMValueAsMetadata(new_program.value_ref)
let ret = replace.insert(subprogram.as_ptr() as u64, unsafe {
LLVMValueAsMetadata(new_program.as_ptr())
});
assert!(ret.is_none());
}

replace
Ok(replace)
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
enum Item {
GlobalVariable(LLVMValueRef),
GlobalAlias(LLVMValueRef),
Function(LLVMValueRef),
FunctionParam(LLVMValueRef),
Instruction(LLVMValueRef),
enum Item<'ctx> {
GlobalVariable(GlobalVariable<'ctx>),
GlobalAlias(GlobalAlias<'ctx>),
Function(Function<'ctx>),
FunctionParam(Argument<'ctx>),
Instruction(Instruction<'ctx>),
Operand(Operand),
MetadataEntry(LLVMValueRef, u32, usize),
}
Expand All @@ -437,16 +446,16 @@ impl Operand {
}
}

impl Item {
impl Item<'_> {
fn value_ref(&self) -> LLVMValueRef {
match self {
Item::GlobalVariable(value)
| Item::GlobalAlias(value)
| Item::Function(value)
| Item::FunctionParam(value)
| Item::Instruction(value)
| Item::Operand(Operand { value, .. })
| Item::MetadataEntry(value, _, _) => *value,
Item::GlobalVariable(global) => global.as_ptr(),
Item::GlobalAlias(global) => global.as_ptr(),
Item::Function(function) => function.as_ptr(),
Item::FunctionParam(function_param) => function_param.as_ptr(),
Item::Instruction(instruction) => instruction.as_ptr(),
Item::Operand(Operand { value, .. }) => *value,
Item::MetadataEntry(value, _, _) => *value,
}
}

Expand Down
Loading
Loading