Skip to content

Commit

Permalink
🔒🌯Provide LLVMTypeWrapper trait, add more wrappers
Browse files Browse the repository at this point in the history
Instead of keeping raw pointer fields (of types like `LLVMValueRef`,
`LLVMMetadataRef`) public and defining constructors with different
names, provide the `LLVMTypeWrapper` trait with `from_ptr()` and
`as_ptr()` methods.

This will allow to convert all safe wrappers from and to raw pointers
with one method, which is going to be helpful for building macros for
them.

On top of that, add more wrappers:

* `BasicBlock`
* `GlobalAlias`
* `GlobalVariable`

Use wrappers in iterators
  • Loading branch information
vadorovsky committed Dec 20, 2024
1 parent 1adc09e commit 06eb87f
Show file tree
Hide file tree
Showing 7 changed files with 786 additions and 378 deletions.
69 changes: 41 additions & 28 deletions src/linker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@ use std::{
use ar::Archive;
use llvm_sys::{
bit_writer::LLVMWriteBitcodeToFile,
core::{
LLVMContextCreate, LLVMContextDispose, LLVMContextSetDiagnosticHandler, LLVMDisposeModule,
LLVMGetTarget,
},
core::{LLVMContextCreate, LLVMContextDispose, LLVMContextSetDiagnosticHandler, LLVMGetTarget},
error_handling::{LLVMEnablePrettyStackTrace, LLVMInstallFatalErrorHandler},
prelude::{LLVMContextRef, LLVMModuleRef},
prelude::LLVMContextRef,
target_machine::{LLVMCodeGenFileType, LLVMDisposeTargetMachine, LLVMTargetMachineRef},
};
use thiserror::Error;
use tracing::{debug, error, info, warn};

use crate::llvm;
use crate::llvm::{self, LLVMTypeError, LLVMTypeWrapper, Module};

/// Linker error
#[derive(Debug, Error)]
Expand Down Expand Up @@ -77,6 +74,10 @@ pub enum LinkerError {
/// The input object file does not have embedded bitcode.
#[error("no bitcode section found in {0}")]
MissingBitcodeSection(PathBuf),

/// Instantiating of an LLVM type failed.
#[error(transparent)]
LLVMType(#[from] LLVMTypeError),
}

/// BPF Cpu type
Expand Down Expand Up @@ -224,21 +225,26 @@ pub struct LinkerOptions {
}

/// BPF Linker
pub struct Linker {
pub struct Linker<'ctx> {
options: LinkerOptions,
context: LLVMContextRef,
module: LLVMModuleRef,
module: Module<'ctx>,
target_machine: LLVMTargetMachineRef,
diagnostic_handler: DiagnosticHandler,
}

impl Linker {
impl Linker<'_> {
/// Create a new linker instance with the given options.
pub fn new(options: LinkerOptions) -> Self {
let context = unsafe { LLVMContextCreate() };
let module = Module::new(
options.output.file_stem().unwrap().to_str().unwrap(),
context,
);
Linker {
options,
context: ptr::null_mut(),
module: ptr::null_mut(),
context,
module,
target_machine: ptr::null_mut(),
diagnostic_handler: DiagnosticHandler::new(),
}
Expand Down Expand Up @@ -365,7 +371,7 @@ impl Linker {
Archive => panic!("nested archives not supported duh"),
};

if unsafe { !llvm::link_bitcode_buffer(self.context, self.module, &bitcode) } {
if unsafe { !llvm::link_bitcode_buffer(self.context, self.module.as_ptr(), &bitcode) } {
return Err(LinkerError::LinkModuleError(path.to_owned()));
}

Expand Down Expand Up @@ -407,11 +413,11 @@ impl Linker {
})
}
None => {
let c_triple = unsafe { LLVMGetTarget(*module) };
let c_triple = unsafe { LLVMGetTarget(module.as_ptr()) };
let triple = unsafe { CStr::from_ptr(c_triple) }.to_str().unwrap();
if triple.starts_with("bpf") {
// case 2
(triple, unsafe { llvm::target_from_module(*module) })
(triple, unsafe { llvm::target_from_module(module.as_ptr()) })
} else {
// case 3.
info!("detected non-bpf input target {} and no explicit output --target specified, selecting `bpf'", triple);
Expand Down Expand Up @@ -452,17 +458,18 @@ impl Linker {

if self.options.btf {
// if we want to emit BTF, we need to sanitize the debug information
llvm::DISanitizer::new(self.context, self.module).run(&self.options.export_symbols);
llvm::DISanitizer::new(self.context, &mut self.module)
.run(&mut self.module, &self.options.export_symbols)?;
} else {
// if we don't need BTF emission, we can strip DI
let ok = unsafe { llvm::strip_debug_info(self.module) };
let ok = unsafe { llvm::strip_debug_info(self.module.as_ptr()) };
debug!("Stripping DI, changed={}", ok);
}

unsafe {
llvm::optimize(
self.target_machine,
self.module,
&mut self.module,
self.options.optimize,
self.options.ignore_inline_never,
&self.options.export_symbols,
Expand All @@ -486,7 +493,7 @@ impl Linker {
fn write_bitcode(&mut self, output: &CStr) -> Result<(), LinkerError> {
info!("writing bitcode to {:?}", output);

if unsafe { LLVMWriteBitcodeToFile(self.module, output.as_ptr()) } == 1 {
if unsafe { LLVMWriteBitcodeToFile(self.module.as_ptr(), output.as_ptr()) } == 1 {
return Err(LinkerError::WriteBitcodeError);
}

Expand All @@ -496,14 +503,21 @@ impl Linker {
fn write_ir(&mut self, output: &CStr) -> Result<(), LinkerError> {
info!("writing IR to {:?}", output);

unsafe { llvm::write_ir(self.module, output) }.map_err(LinkerError::WriteIRError)
unsafe { llvm::write_ir(self.module.as_ptr(), output) }.map_err(LinkerError::WriteIRError)
}

fn emit(&mut self, output: &CStr, output_type: LLVMCodeGenFileType) -> Result<(), LinkerError> {
info!("emitting {:?} to {:?}", output_type, output);

unsafe { llvm::codegen(self.target_machine, self.module, output, output_type) }
.map_err(LinkerError::EmitCodeError)
unsafe {
llvm::codegen(
self.target_machine,
self.module.as_ptr(),
output,
output_type,
)
}
.map_err(LinkerError::EmitCodeError)
}

fn llvm_init(&mut self) {
Expand Down Expand Up @@ -542,24 +556,23 @@ impl Linker {
);
LLVMInstallFatalErrorHandler(Some(llvm::fatal_error));
LLVMEnablePrettyStackTrace();
self.module = llvm::create_module(
self.module = Module::new(
self.options.output.file_stem().unwrap().to_str().unwrap(),
self.context,
)
.unwrap();
);
}
}
}

impl Drop for Linker {
impl Drop for Linker<'_> {
fn drop(&mut self) {
unsafe {
if !self.target_machine.is_null() {
LLVMDisposeTargetMachine(self.target_machine);
}
if !self.module.is_null() {
LLVMDisposeModule(self.module);
}
// if !self.module.is_null() {
// LLVMDisposeModule(self.module);
// }
if !self.context.is_null() {
LLVMContextDispose(self.context);
}
Expand Down
94 changes: 51 additions & 43 deletions src/llvm/di.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,23 @@ use std::{
collections::{hash_map::DefaultHasher, HashMap, HashSet},
ffi::c_char,
hash::Hasher,
ptr,
ptr::{self, NonNull},
};

use gimli::{DW_TAG_pointer_type, DW_TAG_structure_type, DW_TAG_variant_part};
use llvm_sys::{core::*, debuginfo::*, prelude::*};
use tracing::{span, trace, warn, Level};

use super::types::{
di::DIType,
ir::{Function, MDNode, Metadata, Value},
use crate::llvm::{
iter::*,
types::{
di::{DISubprogram, DIType},
ir::{Function, Instruction, MDNode, Metadata, Module, Value},
LLVMTypeError, LLVMTypeWrapper,
},
};
use crate::llvm::{iter::*, types::di::DISubprogram};

use super::types::ir::{Argument, GlobalAlias, GlobalVariable};

// KSYM_NAME_LEN from linux kernel intentionally set
// to lower value found accross kernel versions to ensure
Expand All @@ -23,7 +28,6 @@ const MAX_KSYM_NAME_LEN: usize = 128;

pub struct DISanitizer {
context: LLVMContextRef,
module: LLVMModuleRef,
builder: LLVMDIBuilderRef,
visited_nodes: HashSet<u64>,
replace_operands: HashMap<u64, LLVMMetadataRef>,
Expand Down Expand Up @@ -59,11 +63,11 @@ fn sanitize_type_name<T: AsRef<str>>(name: T) -> String {
}

impl DISanitizer {
pub fn new(context: LLVMContextRef, module: LLVMModuleRef) -> DISanitizer {
pub fn new(context: LLVMContextRef, module: &Module<'_>) -> Self {
let builder = unsafe { LLVMCreateDIBuilder(module.as_ptr()) };
DISanitizer {
context,
module,
builder: unsafe { LLVMCreateDIBuilder(module) },
builder,
visited_nodes: HashSet::new(),
replace_operands: HashMap::new(),
skipped_types: Vec::new(),
Expand Down Expand Up @@ -227,7 +231,7 @@ impl DISanitizer {
// An operand with no value is valid and means that the operand is
// not set
(v, Item::Operand { .. }) if v.is_null() => return,
(v, _) if !v.is_null() => Value::new(v),
(v, _) if !v.is_null() => Value::from_ptr(NonNull::new(v).unwrap()).unwrap(),
// All other items should have values
(_, item) => panic!("{item:?} has no value"),
};
Expand Down Expand Up @@ -283,16 +287,18 @@ impl DISanitizer {
}
}

pub fn run(mut self, exported_symbols: &HashSet<Cow<'static, str>>) {
let module = self.module;

self.replace_operands = self.fix_subprogram_linkage(exported_symbols);
pub fn run(
mut self,
module: &mut Module<'_>,
exported_symbols: &HashSet<Cow<'static, str>>,
) -> Result<(), LLVMTypeError> {
self.replace_operands = self.fix_subprogram_linkage(module, exported_symbols)?;

for value in module.globals_iter() {
self.visit_item(Item::GlobalVariable(value));
for global in module.globals_iter() {
self.visit_item(Item::GlobalVariable(global));
}
for value in module.global_aliases_iter() {
self.visit_item(Item::GlobalAlias(value));
for alias in module.global_aliases_iter() {
self.visit_item(Item::GlobalAlias(alias));
}

for function in module.functions_iter() {
Expand All @@ -307,6 +313,8 @@ impl DISanitizer {
}

unsafe { LLVMDisposeDIBuilder(self.builder) };

Ok(())
}

// Make it so that only exported symbols (programs marked as #[no_mangle]) get BTF
Expand All @@ -324,16 +332,13 @@ impl DISanitizer {
// See tests/btf/assembly/exported-symbols.rs .
fn fix_subprogram_linkage(
&mut self,
module: &mut Module<'_>,
export_symbols: &HashSet<Cow<'static, str>>,
) -> HashMap<u64, LLVMMetadataRef> {
) -> Result<HashMap<u64, LLVMMetadataRef>, LLVMTypeError> {
let mut replace = HashMap::new();

for mut function in self
.module
.functions_iter()
.map(|value| unsafe { Function::from_value_ref(value) })
{
if export_symbols.contains(function.name()) {
for mut function in module.functions_iter() {
if export_symbols.contains(&function.name()) {
continue;
}

Expand Down Expand Up @@ -370,7 +375,10 @@ impl DISanitizer {
// replace retained nodes manually below.
LLVMDIBuilderFinalizeSubprogram(self.builder, new_program);

DISubprogram::from_value_ref(LLVMMetadataAsValue(self.context, new_program))
let new_program = LLVMMetadataAsValue(self.context, new_program);
let new_program =
NonNull::new(new_program).expect("new program should not be null");
DISubprogram::from_ptr(new_program)?
};

// Point the function to the new subprogram.
Expand All @@ -396,23 +404,23 @@ impl DISanitizer {
unsafe { LLVMMDNodeInContext2(self.context, core::ptr::null_mut(), 0) };
subprogram.set_retained_nodes(empty_node);

let ret = replace.insert(subprogram.value_ref as u64, unsafe {
LLVMValueAsMetadata(new_program.value_ref)
let ret = replace.insert(subprogram.as_ptr() as u64, unsafe {
LLVMValueAsMetadata(new_program.as_ptr())
});
assert!(ret.is_none());
}

replace
Ok(replace)
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
enum Item {
GlobalVariable(LLVMValueRef),
GlobalAlias(LLVMValueRef),
Function(LLVMValueRef),
FunctionParam(LLVMValueRef),
Instruction(LLVMValueRef),
enum Item<'ctx> {
GlobalVariable(GlobalVariable<'ctx>),
GlobalAlias(GlobalAlias<'ctx>),
Function(Function<'ctx>),
FunctionParam(Argument<'ctx>),
Instruction(Instruction<'ctx>),
Operand(Operand),
MetadataEntry(LLVMValueRef, u32, usize),
}
Expand All @@ -437,16 +445,16 @@ impl Operand {
}
}

impl Item {
impl Item<'_> {
fn value_ref(&self) -> LLVMValueRef {
match self {
Item::GlobalVariable(value)
| Item::GlobalAlias(value)
| Item::Function(value)
| Item::FunctionParam(value)
| Item::Instruction(value)
| Item::Operand(Operand { value, .. })
| Item::MetadataEntry(value, _, _) => *value,
Item::GlobalVariable(global) => global.as_ptr(),
Item::GlobalAlias(global) => global.as_ptr(),
Item::Function(function) => function.as_ptr(),
Item::FunctionParam(function_param) => function_param.as_ptr(),
Item::Instruction(instruction) => instruction.as_ptr(),
Item::Operand(Operand { value, .. }) => *value,
Item::MetadataEntry(value, _, _) => *value,
}
}

Expand Down
Loading

0 comments on commit 06eb87f

Please sign in to comment.