diff --git a/src/linker.rs b/src/linker.rs index 8c183f68..a93459e9 100644 --- a/src/linker.rs +++ b/src/linker.rs @@ -14,18 +14,15 @@ use std::{ use ar::Archive; use llvm_sys::{ bit_writer::LLVMWriteBitcodeToFile, - core::{ - LLVMContextCreate, LLVMContextDispose, LLVMContextSetDiagnosticHandler, LLVMDisposeModule, - LLVMGetTarget, - }, + core::{LLVMContextCreate, LLVMContextDispose, LLVMContextSetDiagnosticHandler, LLVMGetTarget}, error_handling::{LLVMEnablePrettyStackTrace, LLVMInstallFatalErrorHandler}, - prelude::{LLVMContextRef, LLVMModuleRef}, + prelude::LLVMContextRef, target_machine::{LLVMCodeGenFileType, LLVMDisposeTargetMachine, LLVMTargetMachineRef}, }; use thiserror::Error; use tracing::{debug, error, info, warn}; -use crate::llvm; +use crate::llvm::{self, LLVMTypeError, LLVMTypeWrapper, Module}; /// Linker error #[derive(Debug, Error)] @@ -77,6 +74,10 @@ pub enum LinkerError { /// The input object file does not have embedded bitcode. #[error("no bitcode section found in {0}")] MissingBitcodeSection(PathBuf), + + /// Instantiating of an LLVM type failed. + #[error(transparent)] + LLVMType(#[from] LLVMTypeError), } /// BPF Cpu type @@ -224,21 +225,26 @@ pub struct LinkerOptions { } /// BPF Linker -pub struct Linker { +pub struct Linker<'ctx> { options: LinkerOptions, context: LLVMContextRef, - module: LLVMModuleRef, + module: Module<'ctx>, target_machine: LLVMTargetMachineRef, diagnostic_handler: DiagnosticHandler, } -impl Linker { +impl Linker<'_> { /// Create a new linker instance with the given options. pub fn new(options: LinkerOptions) -> Self { + let context = unsafe { LLVMContextCreate() }; + let module = Module::new( + options.output.file_stem().unwrap().to_str().unwrap(), + context, + ); Linker { options, - context: ptr::null_mut(), - module: ptr::null_mut(), + context, + module, target_machine: ptr::null_mut(), diagnostic_handler: DiagnosticHandler::new(), } @@ -365,7 +371,7 @@ impl Linker { Archive => panic!("nested archives not supported duh"), }; - if unsafe { !llvm::link_bitcode_buffer(self.context, self.module, &bitcode) } { + if unsafe { !llvm::link_bitcode_buffer(self.context, self.module.as_ptr(), &bitcode) } { return Err(LinkerError::LinkModuleError(path.to_owned())); } @@ -407,11 +413,11 @@ impl Linker { }) } None => { - let c_triple = unsafe { LLVMGetTarget(*module) }; + let c_triple = unsafe { LLVMGetTarget(module.as_ptr()) }; let triple = unsafe { CStr::from_ptr(c_triple) }.to_str().unwrap(); if triple.starts_with("bpf") { // case 2 - (triple, unsafe { llvm::target_from_module(*module) }) + (triple, unsafe { llvm::target_from_module(module.as_ptr()) }) } else { // case 3. info!("detected non-bpf input target {} and no explicit output --target specified, selecting `bpf'", triple); @@ -452,17 +458,18 @@ impl Linker { if self.options.btf { // if we want to emit BTF, we need to sanitize the debug information - llvm::DISanitizer::new(self.context, self.module).run(&self.options.export_symbols); + llvm::DISanitizer::new(self.context, &self.module) + .run(&mut self.module, &self.options.export_symbols)?; } else { // if we don't need BTF emission, we can strip DI - let ok = unsafe { llvm::strip_debug_info(self.module) }; + let ok = unsafe { llvm::strip_debug_info(self.module.as_ptr()) }; debug!("Stripping DI, changed={}", ok); } unsafe { llvm::optimize( self.target_machine, - self.module, + &mut self.module, self.options.optimize, self.options.ignore_inline_never, &self.options.export_symbols, @@ -486,7 +493,7 @@ impl Linker { fn write_bitcode(&mut self, output: &CStr) -> Result<(), LinkerError> { info!("writing bitcode to {:?}", output); - if unsafe { LLVMWriteBitcodeToFile(self.module, output.as_ptr()) } == 1 { + if unsafe { LLVMWriteBitcodeToFile(self.module.as_ptr(), output.as_ptr()) } == 1 { return Err(LinkerError::WriteBitcodeError); } @@ -496,14 +503,21 @@ impl Linker { fn write_ir(&mut self, output: &CStr) -> Result<(), LinkerError> { info!("writing IR to {:?}", output); - unsafe { llvm::write_ir(self.module, output) }.map_err(LinkerError::WriteIRError) + unsafe { llvm::write_ir(self.module.as_ptr(), output) }.map_err(LinkerError::WriteIRError) } fn emit(&mut self, output: &CStr, output_type: LLVMCodeGenFileType) -> Result<(), LinkerError> { info!("emitting {:?} to {:?}", output_type, output); - unsafe { llvm::codegen(self.target_machine, self.module, output, output_type) } - .map_err(LinkerError::EmitCodeError) + unsafe { + llvm::codegen( + self.target_machine, + self.module.as_ptr(), + output, + output_type, + ) + } + .map_err(LinkerError::EmitCodeError) } fn llvm_init(&mut self) { @@ -542,24 +556,23 @@ impl Linker { ); LLVMInstallFatalErrorHandler(Some(llvm::fatal_error)); LLVMEnablePrettyStackTrace(); - self.module = llvm::create_module( + self.module = Module::new( self.options.output.file_stem().unwrap().to_str().unwrap(), self.context, - ) - .unwrap(); + ); } } } -impl Drop for Linker { +impl Drop for Linker<'_> { fn drop(&mut self) { unsafe { if !self.target_machine.is_null() { LLVMDisposeTargetMachine(self.target_machine); } - if !self.module.is_null() { - LLVMDisposeModule(self.module); - } + // if !self.module.is_null() { + // LLVMDisposeModule(self.module); + // } if !self.context.is_null() { LLVMContextDispose(self.context); } diff --git a/src/llvm/di.rs b/src/llvm/di.rs index c4a939d4..6387ba2e 100644 --- a/src/llvm/di.rs +++ b/src/llvm/di.rs @@ -3,18 +3,24 @@ use std::{ collections::{hash_map::DefaultHasher, HashMap, HashSet}, ffi::c_char, hash::Hasher, - ptr, + ptr::{self, NonNull}, }; use gimli::{DW_TAG_pointer_type, DW_TAG_structure_type, DW_TAG_variant_part}; use llvm_sys::{core::*, debuginfo::*, prelude::*}; use tracing::{span, trace, warn, Level}; -use super::types::{ - di::DIType, - ir::{Function, MDNode, Metadata, Value}, +use crate::llvm::{ + iter::*, + types::{ + di::{DISubprogram, DIType}, + ir::{ + Argument, Function, GlobalAlias, GlobalVariable, Instruction, MDNode, Metadata, Module, + Value, + }, + LLVMTypeError, LLVMTypeWrapper, + }, }; -use crate::llvm::{iter::*, types::di::DISubprogram}; // KSYM_NAME_LEN from linux kernel intentionally set // to lower value found accross kernel versions to ensure @@ -23,7 +29,6 @@ const MAX_KSYM_NAME_LEN: usize = 128; pub struct DISanitizer { context: LLVMContextRef, - module: LLVMModuleRef, builder: LLVMDIBuilderRef, visited_nodes: HashSet<u64>, replace_operands: HashMap<u64, LLVMMetadataRef>, @@ -59,11 +64,11 @@ fn sanitize_type_name<T: AsRef<str>>(name: T) -> String { } impl DISanitizer { - pub fn new(context: LLVMContextRef, module: LLVMModuleRef) -> DISanitizer { + pub fn new(context: LLVMContextRef, module: &Module<'_>) -> Self { + let builder = unsafe { LLVMCreateDIBuilder(module.as_ptr()) }; DISanitizer { context, - module, - builder: unsafe { LLVMCreateDIBuilder(module) }, + builder, visited_nodes: HashSet::new(), replace_operands: HashMap::new(), skipped_types: Vec::new(), @@ -227,7 +232,7 @@ impl DISanitizer { // An operand with no value is valid and means that the operand is // not set (v, Item::Operand { .. }) if v.is_null() => return, - (v, _) if !v.is_null() => Value::new(v), + (v, _) if !v.is_null() => Value::from_ptr(NonNull::new(v).unwrap()).unwrap(), // All other items should have values (_, item) => panic!("{item:?} has no value"), }; @@ -283,16 +288,18 @@ impl DISanitizer { } } - pub fn run(mut self, exported_symbols: &HashSet<Cow<'static, str>>) { - let module = self.module; + pub fn run( + mut self, + module: &mut Module<'_>, + exported_symbols: &HashSet<Cow<'static, str>>, + ) -> Result<(), LLVMTypeError> { + self.replace_operands = self.fix_subprogram_linkage(module, exported_symbols)?; - self.replace_operands = self.fix_subprogram_linkage(exported_symbols); - - for value in module.globals_iter() { - self.visit_item(Item::GlobalVariable(value)); + for global in module.globals_iter() { + self.visit_item(Item::GlobalVariable(global)); } - for value in module.global_aliases_iter() { - self.visit_item(Item::GlobalAlias(value)); + for alias in module.global_aliases_iter() { + self.visit_item(Item::GlobalAlias(alias)); } for function in module.functions_iter() { @@ -307,6 +314,8 @@ impl DISanitizer { } unsafe { LLVMDisposeDIBuilder(self.builder) }; + + Ok(()) } // Make it so that only exported symbols (programs marked as #[no_mangle]) get BTF @@ -324,16 +333,13 @@ impl DISanitizer { // See tests/btf/assembly/exported-symbols.rs . fn fix_subprogram_linkage( &mut self, + module: &mut Module<'_>, export_symbols: &HashSet<Cow<'static, str>>, - ) -> HashMap<u64, LLVMMetadataRef> { + ) -> Result<HashMap<u64, LLVMMetadataRef>, LLVMTypeError> { let mut replace = HashMap::new(); - for mut function in self - .module - .functions_iter() - .map(|value| unsafe { Function::from_value_ref(value) }) - { - if export_symbols.contains(function.name()) { + for mut function in module.functions_iter() { + if export_symbols.contains(&function.name()) { continue; } @@ -370,7 +376,10 @@ impl DISanitizer { // replace retained nodes manually below. LLVMDIBuilderFinalizeSubprogram(self.builder, new_program); - DISubprogram::from_value_ref(LLVMMetadataAsValue(self.context, new_program)) + let new_program = LLVMMetadataAsValue(self.context, new_program); + let new_program = + NonNull::new(new_program).expect("new program should not be null"); + DISubprogram::from_ptr(new_program)? }; // Point the function to the new subprogram. @@ -396,23 +405,23 @@ impl DISanitizer { unsafe { LLVMMDNodeInContext2(self.context, core::ptr::null_mut(), 0) }; subprogram.set_retained_nodes(empty_node); - let ret = replace.insert(subprogram.value_ref as u64, unsafe { - LLVMValueAsMetadata(new_program.value_ref) + let ret = replace.insert(subprogram.as_ptr() as u64, unsafe { + LLVMValueAsMetadata(new_program.as_ptr()) }); assert!(ret.is_none()); } - replace + Ok(replace) } } #[derive(Clone, Debug, Eq, PartialEq)] -enum Item { - GlobalVariable(LLVMValueRef), - GlobalAlias(LLVMValueRef), - Function(LLVMValueRef), - FunctionParam(LLVMValueRef), - Instruction(LLVMValueRef), +enum Item<'ctx> { + GlobalVariable(GlobalVariable<'ctx>), + GlobalAlias(GlobalAlias<'ctx>), + Function(Function<'ctx>), + FunctionParam(Argument<'ctx>), + Instruction(Instruction<'ctx>), Operand(Operand), MetadataEntry(LLVMValueRef, u32, usize), } @@ -437,16 +446,16 @@ impl Operand { } } -impl Item { +impl Item<'_> { fn value_ref(&self) -> LLVMValueRef { match self { - Item::GlobalVariable(value) - | Item::GlobalAlias(value) - | Item::Function(value) - | Item::FunctionParam(value) - | Item::Instruction(value) - | Item::Operand(Operand { value, .. }) - | Item::MetadataEntry(value, _, _) => *value, + Item::GlobalVariable(global) => global.as_ptr(), + Item::GlobalAlias(global) => global.as_ptr(), + Item::Function(function) => function.as_ptr(), + Item::FunctionParam(function_param) => function_param.as_ptr(), + Item::Instruction(instruction) => instruction.as_ptr(), + Item::Operand(Operand { value, .. }) => *value, + Item::MetadataEntry(value, _, _) => *value, } } diff --git a/src/llvm/iter.rs b/src/llvm/iter.rs index e1d77198..21fc7720 100644 --- a/src/llvm/iter.rs +++ b/src/llvm/iter.rs @@ -3,55 +3,92 @@ use std::marker::PhantomData; use llvm_sys::{ core::{ LLVMGetFirstBasicBlock, LLVMGetFirstFunction, LLVMGetFirstGlobal, LLVMGetFirstGlobalAlias, - LLVMGetFirstInstruction, LLVMGetLastBasicBlock, LLVMGetLastFunction, LLVMGetLastGlobal, - LLVMGetLastGlobalAlias, LLVMGetLastInstruction, LLVMGetNextBasicBlock, LLVMGetNextFunction, - LLVMGetNextGlobal, LLVMGetNextGlobalAlias, LLVMGetNextInstruction, + LLVMGetFirstInstruction, LLVMGetNextBasicBlock, LLVMGetNextFunction, LLVMGetNextGlobal, + LLVMGetNextGlobalAlias, LLVMGetNextInstruction, LLVMGetPreviousBasicBlock, + LLVMGetPreviousFunction, LLVMGetPreviousGlobal, LLVMGetPreviousGlobalAlias, + LLVMGetPreviousInstruction, }, - prelude::{LLVMBasicBlockRef, LLVMModuleRef, LLVMValueRef}, + LLVMBasicBlock, LLVMValue, +}; + +use crate::llvm::types::ir::{ + BasicBlock, Function, GlobalAlias, GlobalVariable, Instruction, Module, }; macro_rules! llvm_iterator { - ($trait_name:ident, $iterator_name:ident, $iterable:ty, $method_name:ident, $item_ty:ty, $first:expr, $last:expr, $next:expr $(,)?) => { + ( + $trait_name:ident, + $iterator_name:ident, + $iterable:ident, + $method_name:ident, + $ptr_ty:ty, + $item_ty:ident, + $first:expr, + $last:expr, + $next:expr, + $prev:expr $(,)? + ) => { pub trait $trait_name { fn $method_name(&self) -> $iterator_name; } - pub struct $iterator_name<'a> { - lifetime: PhantomData<&'a $iterable>, - next: $item_ty, - last: $item_ty, + pub struct $iterator_name<'ctx> { + lifetime: PhantomData<&'ctx $iterable<'ctx>>, + current: Option<::std::ptr::NonNull<$ptr_ty>>, } - impl $trait_name for $iterable { + impl $trait_name for $iterable<'_> { fn $method_name(&self) -> $iterator_name { - let first = unsafe { $first(*self) }; - let last = unsafe { $last(*self) }; - assert_eq!(first.is_null(), last.is_null()); + #[allow(unused_imports)] + use $crate::llvm::types::LLVMTypeWrapper as _; + + let first = unsafe { $first(self.as_ptr()) }; + let first = ::std::ptr::NonNull::new(first); + $iterator_name { lifetime: PhantomData, - next: first, - last, + current: first, } } } - impl<'a> Iterator for $iterator_name<'a> { - type Item = $item_ty; + impl<'ctx> Iterator for $iterator_name<'ctx> { + type Item = $item_ty<'ctx>; fn next(&mut self) -> Option<Self::Item> { - let Self { - lifetime: _, - next, - last, - } = self; - if next.is_null() { - return None; + #[allow(unused_imports)] + use $crate::llvm::types::LLVMTypeWrapper as _; + + if let Some(item) = self.current { + let next = unsafe { $next(item.as_ptr()) }; + let next = std::ptr::NonNull::new(next); + self.current = next; + + let item = $item_ty::from_ptr(item).unwrap(); + + Some(item) + } else { + None + } + } + } + + impl<'ctx> DoubleEndedIterator for $iterator_name<'ctx> { + fn next_back(&mut self) -> Option<Self::Item> { + #[allow(unused_imports)] + use $crate::llvm::types::LLVMTypeWrapper as _; + + if let Some(item) = self.current { + let prev = unsafe { $prev(item.as_ptr()) }; + let prev = std::ptr::NonNull::new(prev); + self.current = prev; + + let item = $item_ty::from_ptr(item).unwrap(); + + Some(item) + } else { + None } - let last = *next == *last; - let item = *next; - *next = unsafe { $next(*next) }; - assert_eq!(next.is_null(), last); - Some(item) } } }; @@ -60,54 +97,64 @@ macro_rules! llvm_iterator { llvm_iterator! { IterModuleGlobals, GlobalsIter, - LLVMModuleRef, + Module, globals_iter, - LLVMValueRef, + LLVMValue, + GlobalVariable, LLVMGetFirstGlobal, LLVMGetLastGlobal, LLVMGetNextGlobal, + LLVMGetPreviousGlobal, } llvm_iterator! { IterModuleGlobalAliases, GlobalAliasesIter, - LLVMModuleRef, + Module, global_aliases_iter, - LLVMValueRef, + LLVMValue, + GlobalAlias, LLVMGetFirstGlobalAlias, LLVMGetLastGlobalAlias, LLVMGetNextGlobalAlias, + LLVMGetPreviousGlobalAlias, } llvm_iterator! { IterModuleFunctions, FunctionsIter, - LLVMModuleRef, + Module, functions_iter, - LLVMValueRef, + LLVMValue, + Function, LLVMGetFirstFunction, LLVMGetLastFunction, LLVMGetNextFunction, + LLVMGetPreviousFunction, } llvm_iterator!( IterBasicBlocks, BasicBlockIter, - LLVMValueRef, - basic_blocks_iter, - LLVMBasicBlockRef, + Function, + basic_blocks, + LLVMBasicBlock, + BasicBlock, LLVMGetFirstBasicBlock, LLVMGetLastBasicBlock, - LLVMGetNextBasicBlock + LLVMGetNextBasicBlock, + LLVMGetPreviousBasicBlock, ); llvm_iterator!( IterInstructions, InstructionsIter, - LLVMBasicBlockRef, + BasicBlock, instructions_iter, - LLVMValueRef, + LLVMValue, + Instruction, LLVMGetFirstInstruction, LLVMGetLastInstruction, - LLVMGetNextInstruction + LLVMGetNextInstruction, + LLVMGetPreviousInstruction ); diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs index 86de1b04..c1c13a49 100644 --- a/src/llvm/mod.rs +++ b/src/llvm/mod.rs @@ -1,7 +1,3 @@ -mod di; -mod iter; -mod types; - use std::{ borrow::Cow, collections::HashSet, @@ -18,9 +14,7 @@ use llvm_sys::{ core::{ LLVMCreateMemoryBufferWithMemoryRange, LLVMDisposeMemoryBuffer, LLVMDisposeMessage, LLVMGetDiagInfoDescription, LLVMGetDiagInfoSeverity, LLVMGetEnumAttributeKindForName, - LLVMGetMDString, LLVMGetModuleInlineAsm, LLVMGetTarget, LLVMGetValueName2, - LLVMModuleCreateWithNameInContext, LLVMPrintModuleToFile, LLVMRemoveEnumAttributeAtIndex, - LLVMSetLinkage, LLVMSetModuleInlineAsm2, LLVMSetVisibility, + LLVMGetMDString, LLVMGetTarget, LLVMPrintModuleToFile, LLVMRemoveEnumAttributeAtIndex, }, debuginfo::LLVMStripModuleDebugInfo, error::{ @@ -49,9 +43,15 @@ use llvm_sys::{ LLVMAttributeFunctionIndex, LLVMLinkage, LLVMVisibility, }; use tracing::{debug, error}; +use types::ir::{Function, GlobalValue}; use crate::OptLevel; +mod di; +mod iter; +mod types; +pub use types::{ir::Module, LLVMTypeError, LLVMTypeWrapper}; + pub unsafe fn init<T: AsRef<str>>(args: &[T], overview: &str) { LLVMInitializeBPFTarget(); LLVMInitializeBPFTargetMC(); @@ -73,17 +73,6 @@ unsafe fn parse_command_line_options<T: AsRef<str>>(args: &[T], overview: &str) LLVMParseCommandLineOptions(c_ptrs.len() as i32, c_ptrs.as_ptr(), overview.as_ptr()); } -pub unsafe fn create_module(name: &str, context: LLVMContextRef) -> Option<LLVMModuleRef> { - let c_name = CString::new(name).unwrap(); - let module = LLVMModuleCreateWithNameInContext(c_name.as_ptr(), context); - - if module.is_null() { - return None; - } - - Some(module) -} - pub unsafe fn find_embedded_bitcode( context: LLVMContextRef, data: &[u8], @@ -192,29 +181,32 @@ pub unsafe fn create_target_machine( pub unsafe fn optimize( tm: LLVMTargetMachineRef, - module: LLVMModuleRef, + module: &mut Module<'_>, opt_level: OptLevel, ignore_inline_never: bool, export_symbols: &HashSet<Cow<'static, str>>, ) -> Result<(), String> { if module_asm_is_probestack(module) { - LLVMSetModuleInlineAsm2(module, ptr::null_mut(), 0); + module.set_inline_asm(""); } - for sym in module.globals_iter() { - internalize(sym, symbol_name(sym), export_symbols); + for mut sym in module.globals_iter() { + let name = sym.name(); + internalize(&mut sym, &name, export_symbols); } - for sym in module.global_aliases_iter() { - internalize(sym, symbol_name(sym), export_symbols); + for mut sym in module.global_aliases_iter() { + let name = sym.name(); + internalize(&mut sym, &name, export_symbols); } - for function in module.functions_iter() { - let name = symbol_name(function); - if !name.starts_with("llvm.") { + for mut function in module.functions_iter() { + let name = function.name().into_owned(); + let to_internalize = !name.starts_with("llvm."); + if to_internalize { if ignore_inline_never { - remove_attribute(function, "noinline"); + remove_attribute(&mut function, "noinline"); } - internalize(function, name, export_symbols); + internalize(&mut function, &name, export_symbols); } } @@ -239,7 +231,7 @@ pub unsafe fn optimize( debug!("running passes: {passes}"); let passes = CString::new(passes).unwrap(); let options = LLVMCreatePassBuilderOptions(); - let error = LLVMRunPasses(module, passes.as_ptr(), tm, options); + let error = LLVMRunPasses(module.as_ptr(), passes.as_ptr(), tm, options); LLVMDisposePassBuilderOptions(options); // Handle the error and print it to stderr. if !error.is_null() { @@ -260,26 +252,14 @@ pub unsafe fn strip_debug_info(module: LLVMModuleRef) -> bool { LLVMStripModuleDebugInfo(module) != 0 } -unsafe fn module_asm_is_probestack(module: LLVMModuleRef) -> bool { - let mut len = 0; - let ptr = LLVMGetModuleInlineAsm(module, &mut len); - if ptr.is_null() { - return false; - } - - let asm = String::from_utf8_lossy(slice::from_raw_parts(ptr as *const c_uchar, len)); +fn module_asm_is_probestack(module: &Module) -> bool { + let asm = module.inline_asm(); asm.contains("__rust_probestack") } -fn symbol_name<'a>(value: *mut llvm_sys::LLVMValue) -> &'a str { - let mut name_len = 0; - let ptr = unsafe { LLVMGetValueName2(value, &mut name_len) }; - unsafe { str::from_utf8(slice::from_raw_parts(ptr as *const c_uchar, name_len)).unwrap() } -} - -unsafe fn remove_attribute(function: *mut llvm_sys::LLVMValue, name: &str) { +unsafe fn remove_attribute(function: &mut Function, name: &str) { let attr_kind = LLVMGetEnumAttributeKindForName(name.as_ptr() as *const c_char, name.len()); - LLVMRemoveEnumAttributeAtIndex(function, LLVMAttributeFunctionIndex, attr_kind); + LLVMRemoveEnumAttributeAtIndex(function.as_ptr(), LLVMAttributeFunctionIndex, attr_kind); } pub unsafe fn write_ir(module: LLVMModuleRef, output: &CStr) -> Result<(), String> { @@ -308,14 +288,14 @@ pub unsafe fn codegen( } } -pub unsafe fn internalize( - value: LLVMValueRef, +pub unsafe fn internalize<T: GlobalValue>( + value: &mut T, name: &str, export_symbols: &HashSet<Cow<'static, str>>, ) { if !name.starts_with("llvm.") && !export_symbols.contains(name) { - LLVMSetLinkage(value, LLVMLinkage::LLVMInternalLinkage); - LLVMSetVisibility(value, LLVMVisibility::LLVMDefaultVisibility); + value.set_linkage(LLVMLinkage::LLVMInternalLinkage); + value.set_visibility(LLVMVisibility::LLVMDefaultVisibility); } } diff --git a/src/llvm/types/di.rs b/src/llvm/types/di.rs index 0edc4c61..e9fc651b 100644 --- a/src/llvm/types/di.rs +++ b/src/llvm/types/di.rs @@ -11,14 +11,18 @@ use llvm_sys::{ debuginfo::{ LLVMDIFileGetFilename, LLVMDIFlags, LLVMDIScopeGetFile, LLVMDISubprogramGetLine, LLVMDITypeGetFlags, LLVMDITypeGetLine, LLVMDITypeGetName, LLVMDITypeGetOffsetInBits, - LLVMGetDINodeTag, + LLVMGetDINodeTag, LLVMGetMetadataKind, LLVMMetadataKind, }, - prelude::{LLVMContextRef, LLVMMetadataRef, LLVMValueRef}, + prelude::{LLVMContextRef, LLVMMetadataRef}, + LLVMOpaqueMetadata, LLVMValue, }; use crate::llvm::{ mdstring_to_str, - types::ir::{MDNode, Metadata}, + types::{ + ir::{MDNode, Metadata}, + LLVMTypeError, LLVMTypeWrapper, + }, }; /// Returns a DWARF tag for the given debug info node. @@ -41,26 +45,30 @@ unsafe fn di_node_tag(metadata_ref: LLVMMetadataRef) -> DwTag { /// A `DIFile` debug info node, which represents a given file, is referenced by /// other debug info nodes which belong to the file. pub struct DIFile<'ctx> { - pub(super) metadata_ref: LLVMMetadataRef, + metadata: NonNull<LLVMOpaqueMetadata>, _marker: PhantomData<&'ctx ()>, } -impl DIFile<'_> { - /// Constructs a new [`DIFile`] from the given `metadata`. - /// - /// # Safety - /// - /// This method assumes that the given `metadata` corresponds to a valid - /// instance of [LLVM `DIFile`](https://llvm.org/doxygen/classllvm_1_1DIFile.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any validation checks. - pub(crate) unsafe fn from_metadata_ref(metadata_ref: LLVMMetadataRef) -> Self { - Self { - metadata_ref, - _marker: PhantomData, +impl LLVMTypeWrapper for DIFile<'_> { + type Target = LLVMOpaqueMetadata; + + fn from_ptr(metadata: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> { + let metadata_kind = unsafe { LLVMGetMetadataKind(metadata.as_ptr()) }; + if !matches!(metadata_kind, LLVMMetadataKind::LLVMDIFileMetadataKind) { + return Err(LLVMTypeError::InvalidPointerType("DIFile")); } + Ok(Self { + metadata, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.metadata.as_ptr() } +} +impl DIFile<'_> { pub fn filename(&self) -> Option<&CStr> { let mut len = 0; // `LLVMDIFileGetName` doesn't allocate any memory, it just returns @@ -69,7 +77,7 @@ impl DIFile<'_> { // // Therefore, we don't need to call `LLVMDisposeMessage`. The memory // gets freed when calling `LLVMDisposeDIBuilder`. - let ptr = unsafe { LLVMDIFileGetFilename(self.metadata_ref, &mut len) }; + let ptr = unsafe { LLVMDIFileGetFilename(self.metadata.as_ptr(), &mut len) }; NonNull::new(ptr as *mut _).map(|ptr| unsafe { CStr::from_ptr(ptr.as_ptr()) }) } } @@ -109,39 +117,57 @@ unsafe fn di_type_name<'a>(metadata_ref: LLVMMetadataRef) -> Option<&'a CStr> { /// Represents the debug information for a primitive type in LLVM IR. pub struct DIType<'ctx> { - pub(super) metadata_ref: LLVMMetadataRef, - pub(super) value_ref: LLVMValueRef, + metadata: NonNull<LLVMOpaqueMetadata>, + value: NonNull<LLVMValue>, _marker: PhantomData<&'ctx ()>, } -impl DIType<'_> { - /// Constructs a new [`DIType`] from the given `value`. - /// - /// # Safety - /// - /// This method assumes that the given `value` corresponds to a valid - /// instance of [LLVM `DIType`](https://llvm.org/doxygen/classllvm_1_1DIType.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any validation checks. - pub unsafe fn from_value_ref(value_ref: LLVMValueRef) -> Self { - let metadata_ref = unsafe { LLVMValueAsMetadata(value_ref) }; - Self { - metadata_ref, - value_ref, - _marker: PhantomData, +impl LLVMTypeWrapper for DIType<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> { + let metadata = unsafe { LLVMValueAsMetadata(value.as_ptr()) }; + let metadata = NonNull::new(metadata).ok_or(LLVMTypeError::NullPointer)?; + let metadata_kind = unsafe { LLVMGetMetadataKind(metadata.as_ptr()) }; + // The children of `DIType` are: + // + // - `DIBasicType` + // - `DICompositeType` + // - `DIDerivedType` + // - `DIStringType` + // - `DISubroutimeType` + // + // https://llvm.org/doxygen/classllvm_1_1DIType.html + match metadata_kind { + LLVMMetadataKind::LLVMDIBasicTypeMetadataKind + | LLVMMetadataKind::LLVMDICompositeTypeMetadataKind + | LLVMMetadataKind::LLVMDIDerivedTypeMetadataKind + | LLVMMetadataKind::LLVMDIStringTypeMetadataKind + | LLVMMetadataKind::LLVMDISubroutineTypeMetadataKind => Ok(Self { + metadata, + value, + _marker: PhantomData, + }), + _ => Err(LLVMTypeError::InvalidPointerType("DIType")), } } + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +impl DIType<'_> { /// Returns the offset of the type in bits. This offset is used in case the /// type is a member of a composite type. pub fn offset_in_bits(&self) -> usize { - unsafe { LLVMDITypeGetOffsetInBits(self.metadata_ref) as usize } + unsafe { LLVMDITypeGetOffsetInBits(self.metadata.as_ptr()) as usize } } } impl<'ctx> From<DIDerivedType<'ctx>> for DIType<'ctx> { fn from(di_derived_type: DIDerivedType) -> Self { - unsafe { Self::from_value_ref(di_derived_type.value_ref) } + Self::from_ptr(di_derived_type.value).unwrap() } } @@ -160,34 +186,44 @@ enum DIDerivedTypeOperand { /// alternative name. The examples of derived types are pointers, references, /// typedefs, etc. pub struct DIDerivedType<'ctx> { - metadata_ref: LLVMMetadataRef, - value_ref: LLVMValueRef, + metadata: NonNull<LLVMOpaqueMetadata>, + value: NonNull<LLVMValue>, _marker: PhantomData<&'ctx ()>, } -impl DIDerivedType<'_> { - /// Constructs a new [`DIDerivedType`] from the given `value`. - /// - /// # Safety - /// - /// This method assumes that the provided `value` corresponds to a valid - /// instance of [LLVM `DIDerivedType`](https://llvm.org/doxygen/classllvm_1_1DIDerivedType.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any validation checks. - pub unsafe fn from_value_ref(value_ref: LLVMValueRef) -> Self { - let metadata_ref = LLVMValueAsMetadata(value_ref); - Self { - metadata_ref, - value_ref, - _marker: PhantomData, +impl LLVMTypeWrapper for DIDerivedType<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> { + let metadata = unsafe { LLVMValueAsMetadata(value.as_ptr()) }; + let metadata = NonNull::new(metadata).ok_or(LLVMTypeError::NullPointer)?; + let metadata_kind = unsafe { LLVMGetMetadataKind(metadata.as_ptr()) }; + if !matches!( + metadata_kind, + LLVMMetadataKind::LLVMDIDerivedTypeMetadataKind, + ) { + return Err(LLVMTypeError::InvalidPointerType("DIDerivedType")); } + Ok(Self { + metadata, + value, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() } +} +impl DIDerivedType<'_> { /// Returns the base type of this derived type. pub fn base_type(&self) -> Metadata { unsafe { - let value = LLVMGetOperand(self.value_ref, DIDerivedTypeOperand::BaseType as u32); - Metadata::from_value_ref(value) + let value = LLVMGetOperand(self.value.as_ptr(), DIDerivedTypeOperand::BaseType as u32); + let value = NonNull::new(value).expect("base type operand should not be null"); + Metadata::from_value(value) + .expect("base type pointer should be an instance of Metadata") } } @@ -198,12 +234,17 @@ impl DIDerivedType<'_> { /// Returns a `NulError` if the new name contains a NUL byte, as it cannot /// be converted into a `CString`. pub fn replace_name(&mut self, context: LLVMContextRef, name: &str) -> Result<(), NulError> { - super::ir::replace_name(self.value_ref, context, DITypeOperand::Name as u32, name) + super::ir::replace_name( + self.value.as_ptr(), + context, + DITypeOperand::Name as u32, + name, + ) } /// Returns a DWARF tag of the given derived type. pub fn tag(&self) -> DwTag { - unsafe { di_node_tag(self.metadata_ref) } + unsafe { di_node_tag(self.metadata.as_ptr()) } } } @@ -221,63 +262,75 @@ enum DICompositeTypeOperand { /// Composite type is a kind of type that can include other types, such as /// structures, enums, unions, etc. pub struct DICompositeType<'ctx> { - metadata_ref: LLVMMetadataRef, - value_ref: LLVMValueRef, + metadata: NonNull<LLVMOpaqueMetadata>, + value: NonNull<LLVMValue>, _marker: PhantomData<&'ctx ()>, } -impl DICompositeType<'_> { - /// Constructs a new [`DICompositeType`] from the given `value`. - /// - /// # Safety - /// - /// This method assumes that the provided `value` corresponds to a valid - /// instance of [LLVM `DICompositeType`](https://llvm.org/doxygen/classllvm_1_1DICompositeType.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any validation checks. - pub unsafe fn from_value_ref(value_ref: LLVMValueRef) -> Self { - let metadata_ref = LLVMValueAsMetadata(value_ref); - Self { - metadata_ref, - value_ref, - _marker: PhantomData, +impl LLVMTypeWrapper for DICompositeType<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> { + let metadata = unsafe { LLVMValueAsMetadata(value.as_ptr()) }; + let metadata = NonNull::new(metadata).expect("metadata pointer should not be null"); + let metadata_kind = unsafe { LLVMGetMetadataKind(metadata.as_ptr()) }; + if !matches!( + metadata_kind, + LLVMMetadataKind::LLVMDICompositeTypeMetadataKind, + ) { + return Err(LLVMTypeError::InvalidPointerType("DICompositeType")); } + Ok(Self { + metadata, + value, + _marker: PhantomData, + }) } + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +impl DICompositeType<'_> { /// Returns an iterator over elements (struct fields, enum variants, etc.) /// of the composite type. pub fn elements(&self) -> impl Iterator<Item = Metadata> { let elements = - unsafe { LLVMGetOperand(self.value_ref, DICompositeTypeOperand::Elements as u32) }; + unsafe { LLVMGetOperand(self.value.as_ptr(), DICompositeTypeOperand::Elements as u32) }; let operands = NonNull::new(elements) .map(|elements| unsafe { LLVMGetNumOperands(elements.as_ptr()) }) .unwrap_or(0); - (0..operands) - .map(move |i| unsafe { Metadata::from_value_ref(LLVMGetOperand(elements, i as u32)) }) + (0..operands).map(move |i| { + let operand = unsafe { LLVMGetOperand(elements, i as u32) }; + let operand = NonNull::new(operand).expect("element operand should not be null"); + Metadata::from_value(operand).expect("operands should be instances of Metadata") + }) } /// Returns the name of the composite type. pub fn name(&self) -> Option<&CStr> { - unsafe { di_type_name(self.metadata_ref) } + unsafe { di_type_name(self.metadata.as_ptr()) } } /// Returns the file that the composite type belongs to. pub fn file(&self) -> DIFile { unsafe { - let metadata = LLVMDIScopeGetFile(self.metadata_ref); - DIFile::from_metadata_ref(metadata) + let metadata = LLVMDIScopeGetFile(self.metadata.as_ptr()); + let metadata = NonNull::new(metadata).expect("metadata pointer should not be null"); + DIFile::from_ptr(metadata).expect("the pointer should be of type DIFile") } } /// Returns the flags associated with the composity type. pub fn flags(&self) -> LLVMDIFlags { - unsafe { LLVMDITypeGetFlags(self.metadata_ref) } + unsafe { LLVMDITypeGetFlags(self.metadata.as_ptr()) } } /// Returns the line number in the source code where the type is defined. pub fn line(&self) -> u32 { - unsafe { LLVMDITypeGetLine(self.metadata_ref) } + unsafe { LLVMDITypeGetLine(self.metadata.as_ptr()) } } /// Replaces the elements of the composite type with a new metadata node. @@ -287,9 +340,9 @@ impl DICompositeType<'_> { pub fn replace_elements(&mut self, mdnode: MDNode) { unsafe { LLVMReplaceMDNodeOperandWith( - self.value_ref, + self.value.as_ptr(), DICompositeTypeOperand::Elements as u32, - LLVMValueAsMetadata(mdnode.value_ref), + LLVMValueAsMetadata(mdnode.as_ptr()), ) } } @@ -301,12 +354,17 @@ impl DICompositeType<'_> { /// Returns a `NulError` if the new name contains a NUL byte, as it cannot /// be converted into a `CString`. pub fn replace_name(&mut self, context: LLVMContextRef, name: &str) -> Result<(), NulError> { - super::ir::replace_name(self.value_ref, context, DITypeOperand::Name as u32, name) + super::ir::replace_name( + self.value.as_ptr(), + context, + DITypeOperand::Name as u32, + name, + ) } /// Returns a DWARF tag of the given composite type. pub fn tag(&self) -> DwTag { - unsafe { di_node_tag(self.metadata_ref) } + unsafe { di_node_tag(self.metadata.as_ptr()) } } } @@ -324,58 +382,70 @@ enum DISubprogramOperand { /// Represents the debug information for a subprogram (function) in LLVM IR. pub struct DISubprogram<'ctx> { - pub value_ref: LLVMValueRef, + value: NonNull<LLVMValue>, _marker: PhantomData<&'ctx ()>, } -impl DISubprogram<'_> { - /// Constructs a new [`DISubprogram`] from the given `value`. - /// - /// # Safety - /// - /// This method assumes that the provided `value` corresponds to a valid - /// instance of [LLVM `DISubprogram`](https://llvm.org/doxygen/classllvm_1_1DISubprogram.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any validation checks. - pub(crate) unsafe fn from_value_ref(value_ref: LLVMValueRef) -> Self { - DISubprogram { - value_ref, - _marker: PhantomData, +impl LLVMTypeWrapper for DISubprogram<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> { + let metadata_ref = unsafe { LLVMValueAsMetadata(value.as_ptr()) }; + if metadata_ref.is_null() { + return Err(LLVMTypeError::NullPointer); + } + let metadata_kind = unsafe { LLVMGetMetadataKind(metadata_ref) }; + if !matches!( + metadata_kind, + LLVMMetadataKind::LLVMDISubprogramMetadataKind, + ) { + return Err(LLVMTypeError::InvalidPointerType("DISubprogram")); } + Ok(DISubprogram { + value, + _marker: PhantomData, + }) } + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +impl DISubprogram<'_> { /// Returns the name of the subprogram. pub fn name(&self) -> Option<&str> { - let operand = unsafe { LLVMGetOperand(self.value_ref, DISubprogramOperand::Name as u32) }; + let operand = + unsafe { LLVMGetOperand(self.value.as_ptr(), DISubprogramOperand::Name as u32) }; NonNull::new(operand).map(|_| mdstring_to_str(operand)) } /// Returns the linkage name of the subprogram. pub fn linkage_name(&self) -> Option<&str> { let operand = - unsafe { LLVMGetOperand(self.value_ref, DISubprogramOperand::LinkageName as u32) }; + unsafe { LLVMGetOperand(self.value.as_ptr(), DISubprogramOperand::LinkageName as u32) }; NonNull::new(operand).map(|_| mdstring_to_str(operand)) } pub fn ty(&self) -> LLVMMetadataRef { unsafe { LLVMValueAsMetadata(LLVMGetOperand( - self.value_ref, + self.value.as_ptr(), DISubprogramOperand::Ty as u32, )) } } pub fn file(&self) -> LLVMMetadataRef { - unsafe { LLVMDIScopeGetFile(LLVMValueAsMetadata(self.value_ref)) } + unsafe { LLVMDIScopeGetFile(LLVMValueAsMetadata(self.value.as_ptr())) } } pub fn line(&self) -> u32 { - unsafe { LLVMDISubprogramGetLine(LLVMValueAsMetadata(self.value_ref)) } + unsafe { LLVMDISubprogramGetLine(LLVMValueAsMetadata(self.value.as_ptr())) } } pub fn type_flags(&self) -> i32 { - unsafe { LLVMDITypeGetFlags(LLVMValueAsMetadata(self.value_ref)) } + unsafe { LLVMDITypeGetFlags(LLVMValueAsMetadata(self.value.as_ptr())) } } /// Replaces the name of the subprogram with a new name. @@ -386,7 +456,7 @@ impl DISubprogram<'_> { /// be converted into a `CString`. pub fn replace_name(&mut self, context: LLVMContextRef, name: &str) -> Result<(), NulError> { super::ir::replace_name( - self.value_ref, + self.value.as_ptr(), context, DISubprogramOperand::Name as u32, name, @@ -395,27 +465,34 @@ impl DISubprogram<'_> { pub fn scope(&self) -> Option<LLVMMetadataRef> { unsafe { - let operand = LLVMGetOperand(self.value_ref, DISubprogramOperand::Scope as u32); + let operand = LLVMGetOperand(self.value.as_ptr(), DISubprogramOperand::Scope as u32); NonNull::new(operand).map(|_| LLVMValueAsMetadata(operand)) } } pub fn unit(&self) -> Option<LLVMMetadataRef> { unsafe { - let operand = LLVMGetOperand(self.value_ref, DISubprogramOperand::Unit as u32); + let operand = LLVMGetOperand(self.value.as_ptr(), DISubprogramOperand::Unit as u32); NonNull::new(operand).map(|_| LLVMValueAsMetadata(operand)) } } pub fn set_unit(&mut self, unit: LLVMMetadataRef) { unsafe { - LLVMReplaceMDNodeOperandWith(self.value_ref, DISubprogramOperand::Unit as u32, unit) + LLVMReplaceMDNodeOperandWith( + self.value.as_ptr(), + DISubprogramOperand::Unit as u32, + unit, + ) }; } pub fn retained_nodes(&self) -> Option<LLVMMetadataRef> { unsafe { - let nodes = LLVMGetOperand(self.value_ref, DISubprogramOperand::RetainedNodes as u32); + let nodes = LLVMGetOperand( + self.value.as_ptr(), + DISubprogramOperand::RetainedNodes as u32, + ); NonNull::new(nodes).map(|_| LLVMValueAsMetadata(nodes)) } } @@ -423,7 +500,7 @@ impl DISubprogram<'_> { pub fn set_retained_nodes(&mut self, nodes: LLVMMetadataRef) { unsafe { LLVMReplaceMDNodeOperandWith( - self.value_ref, + self.value.as_ptr(), DISubprogramOperand::RetainedNodes as u32, nodes, ) diff --git a/src/llvm/types/ir.rs b/src/llvm/types/ir.rs index e68cd43a..8f42517a 100644 --- a/src/llvm/types/ir.rs +++ b/src/llvm/types/ir.rs @@ -1,31 +1,92 @@ use std::{ - ffi::{CString, NulError}, + borrow::Cow, + ffi::{c_uchar, CString, NulError}, marker::PhantomData, ptr::NonNull, + slice, }; use llvm_sys::{ core::{ - LLVMCountParams, LLVMDisposeValueMetadataEntries, LLVMGetNumOperands, LLVMGetOperand, - LLVMGetParam, LLVMGlobalCopyAllMetadata, LLVMIsAFunction, LLVMIsAGlobalObject, - LLVMIsAInstruction, LLVMIsAMDNode, LLVMIsAUser, LLVMMDNodeInContext2, - LLVMMDStringInContext2, LLVMMetadataAsValue, LLVMPrintValueToString, - LLVMReplaceMDNodeOperandWith, LLVMValueAsMetadata, LLVMValueMetadataEntriesGetKind, - LLVMValueMetadataEntriesGetMetadata, + LLVMCountParams, LLVMDisposeValueMetadataEntries, LLVMGetModuleInlineAsm, + LLVMGetNumOperands, LLVMGetOperand, LLVMGetParam, LLVMGetValueName2, + LLVMGlobalCopyAllMetadata, LLVMIsAArgument, LLVMIsAFunction, LLVMIsAGlobalAlias, + LLVMIsAGlobalObject, LLVMIsAGlobalVariable, LLVMIsAInstruction, LLVMIsAMDNode, LLVMIsAUser, + LLVMMDNodeInContext2, LLVMMDStringInContext2, LLVMMetadataAsValue, + LLVMModuleCreateWithNameInContext, LLVMPrintValueToString, LLVMReplaceMDNodeOperandWith, + LLVMSetLinkage, LLVMSetModuleInlineAsm2, LLVMSetVisibility, LLVMValueAsMetadata, + LLVMValueMetadataEntriesGetKind, LLVMValueMetadataEntriesGetMetadata, }, debuginfo::{LLVMGetMetadataKind, LLVMGetSubprogram, LLVMMetadataKind, LLVMSetSubprogram}, - prelude::{ - LLVMBasicBlockRef, LLVMContextRef, LLVMMetadataRef, LLVMValueMetadataEntry, LLVMValueRef, - }, + prelude::{LLVMContextRef, LLVMMetadataRef, LLVMValueMetadataEntry, LLVMValueRef}, + LLVMBasicBlock, LLVMLinkage, LLVMModule, LLVMValue, LLVMVisibility, }; use crate::llvm::{ - iter::IterBasicBlocks as _, - symbol_name, - types::di::{DICompositeType, DIDerivedType, DISubprogram, DIType}, + types::{ + di::{DICompositeType, DIDerivedType, DISubprogram, DIType}, + LLVMTypeError, LLVMTypeWrapper, + }, Message, }; +pub struct Module<'ctx> { + module: NonNull<LLVMModule>, + _marker: PhantomData<&'ctx ()>, +} + +impl LLVMTypeWrapper for Module<'_> { + type Target = LLVMModule; + + fn from_ptr(module: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> + where + Self: Sized, + { + Ok(Self { + module, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.module.as_ptr() + } +} + +impl Module<'_> { + pub fn new(name: &str, context: LLVMContextRef) -> Self { + let name = CString::new(name).unwrap(); + let module = unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), context) }; + let module = NonNull::new(module).expect(""); + Self { + module, + _marker: PhantomData, + } + } + + pub fn inline_asm(&self) -> Cow<'_, str> { + let mut len = 0; + let ptr = unsafe { LLVMGetModuleInlineAsm(self.module.as_ptr(), &mut len) }; + let asm = unsafe { slice::from_raw_parts(ptr as *const c_uchar, len) }; + String::from_utf8_lossy(asm) + } + + pub fn set_inline_asm(&mut self, asm: &str) { + let len = asm.len(); + let asm = CString::new(asm).unwrap(); + unsafe { + LLVMSetModuleInlineAsm2(self.module.as_ptr(), asm.as_ptr(), len); + } + } +} + +pub(crate) fn symbol_name<'a>(value: LLVMValueRef) -> Cow<'a, str> { + let mut len = 0; + let ptr = unsafe { LLVMGetValueName2(value, &mut len) }; + let symbol_name = unsafe { slice::from_raw_parts(ptr as *const c_uchar, len) }; + String::from_utf8_lossy(symbol_name) +} + pub(crate) fn replace_name( value_ref: LLVMValueRef, context: LLVMContextRef, @@ -42,7 +103,7 @@ pub(crate) fn replace_name( pub enum Value<'ctx> { MDNode(MDNode<'ctx>), Function(Function<'ctx>), - Other(LLVMValueRef), + Other(NonNull<LLVMValue>), } impl std::fmt::Debug for Value<'_> { @@ -60,45 +121,60 @@ impl std::fmt::Debug for Value<'_> { match self { Self::MDNode(node) => f .debug_struct("MDNode") - .field("value", &value_to_string(node.value_ref)) + .field("value", &value_to_string(node.value.as_ptr())) .finish(), Self::Function(fun) => f .debug_struct("Function") - .field("value", &value_to_string(fun.value_ref)) + .field("value", &value_to_string(fun.value.as_ptr())) .finish(), Self::Other(value) => f .debug_struct("Other") - .field("value", &value_to_string(*value)) + .field("value", &value_to_string(value.as_ptr())) .finish(), } } } -impl Value<'_> { - pub fn new(value: LLVMValueRef) -> Self { - if unsafe { !LLVMIsAMDNode(value).is_null() } { - let mdnode = unsafe { MDNode::from_value_ref(value) }; - return Value::MDNode(mdnode); - } else if unsafe { !LLVMIsAFunction(value).is_null() } { - return Value::Function(unsafe { Function::from_value_ref(value) }); +impl LLVMTypeWrapper for Value<'_> { + type Target = LLVMValue; + + fn from_ptr(value_ref: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> { + if unsafe { !LLVMIsAMDNode(value_ref.as_ptr()).is_null() } { + let mdnode = MDNode::from_ptr(value_ref)?; + Ok(Value::MDNode(mdnode)) + } else if unsafe { !LLVMIsAFunction(value_ref.as_ptr()).is_null() } { + Ok(Value::Function(Function::from_ptr(value_ref)?)) + } else { + Ok(Value::Other(value_ref)) } - Value::Other(value) } + fn as_ptr(&self) -> *mut Self::Target { + match self { + Value::MDNode(mdnode) => mdnode.as_ptr(), + Value::Function(f) => f.as_ptr(), + Value::Other(value) => value.as_ptr(), + } + } +} + +impl Value<'_> { pub fn metadata_entries(&self) -> Option<MetadataEntries> { let value = match self { - Value::MDNode(node) => node.value_ref, - Value::Function(f) => f.value_ref, - Value::Other(value) => *value, + Value::MDNode(node) => node.value.as_ptr(), + Value::Function(f) => f.value.as_ptr(), + Value::Other(value) => value.as_ptr(), }; MetadataEntries::new(value) } pub fn operands(&self) -> Option<impl Iterator<Item = LLVMValueRef>> { let value = match self { - Value::MDNode(node) => Some(node.value_ref), - Value::Function(f) => Some(f.value_ref), - Value::Other(value) if unsafe { !LLVMIsAUser(*value).is_null() } => Some(*value), + Value::MDNode(node) => Some(node.value.as_ptr()), + Value::Function(f) => Some(f.value.as_ptr()), + Value::Other(value) if unsafe { !LLVMIsAUser(value.as_ptr()).is_null() } => { + Some(value.as_ptr()) + } _ => None, }; @@ -112,7 +188,7 @@ pub enum Metadata<'ctx> { DICompositeType(DICompositeType<'ctx>), DIDerivedType(DIDerivedType<'ctx>), DISubprogram(DISubprogram<'ctx>), - Other(#[allow(dead_code)] LLVMValueRef), + Other(#[allow(dead_code)] NonNull<LLVMValue>), } impl Metadata<'_> { @@ -124,21 +200,21 @@ impl Metadata<'_> { /// instance of [LLVM `Metadata`](https://llvm.org/doxygen/classllvm_1_1Metadata.html). /// It's the caller's responsibility to ensure this invariant, as this /// method doesn't perform any valiation checks. - pub(crate) unsafe fn from_value_ref(value: LLVMValueRef) -> Self { - let metadata = LLVMValueAsMetadata(value); + pub(crate) fn from_value(value: NonNull<LLVMValue>) -> Result<Self, LLVMTypeError> { + let metadata = unsafe { LLVMValueAsMetadata(value.as_ptr()) }; match unsafe { LLVMGetMetadataKind(metadata) } { LLVMMetadataKind::LLVMDICompositeTypeMetadataKind => { - let di_composite_type = unsafe { DICompositeType::from_value_ref(value) }; - Metadata::DICompositeType(di_composite_type) + let di_composite_type = DICompositeType::from_ptr(value)?; + Ok(Metadata::DICompositeType(di_composite_type)) } LLVMMetadataKind::LLVMDIDerivedTypeMetadataKind => { - let di_derived_type = unsafe { DIDerivedType::from_value_ref(value) }; - Metadata::DIDerivedType(di_derived_type) + let di_derived_type = DIDerivedType::from_ptr(value)?; + Ok(Metadata::DIDerivedType(di_derived_type)) } LLVMMetadataKind::LLVMDISubprogramMetadataKind => { - let di_subprogram = unsafe { DISubprogram::from_value_ref(value) }; - Metadata::DISubprogram(di_subprogram) + let di_subprogram = DISubprogram::from_ptr(value)?; + Ok(Metadata::DISubprogram(di_subprogram)) } LLVMMetadataKind::LLVMDIGlobalVariableMetadataKind | LLVMMetadataKind::LLVMDICommonBlockMetadataKind @@ -172,62 +248,62 @@ impl Metadata<'_> { | LLVMMetadataKind::LLVMDIStringTypeMetadataKind | LLVMMetadataKind::LLVMDIGenericSubrangeMetadataKind | LLVMMetadataKind::LLVMDIArgListMetadataKind - | LLVMMetadataKind::LLVMDIAssignIDMetadataKind => Metadata::Other(value), + | LLVMMetadataKind::LLVMDIAssignIDMetadataKind => Ok(Metadata::Other(value)), } } } impl<'ctx> TryFrom<MDNode<'ctx>> for Metadata<'ctx> { - type Error = (); + type Error = LLVMTypeError; fn try_from(md_node: MDNode) -> Result<Self, Self::Error> { // FIXME: fail if md_node isn't a Metadata node - Ok(unsafe { Self::from_value_ref(md_node.value_ref) }) + Self::from_value(md_node.value) } } /// Represents a metadata node. #[derive(Clone)] pub struct MDNode<'ctx> { - pub(super) value_ref: LLVMValueRef, + value: NonNull<LLVMValue>, _marker: PhantomData<&'ctx ()>, } +impl LLVMTypeWrapper for MDNode<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> { + if unsafe { LLVMIsAMDNode(value.as_ptr()).is_null() } { + return Err(LLVMTypeError::InvalidPointerType("MDNode")); + } + Ok(Self { + value, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + impl MDNode<'_> { /// Constructs a new [`MDNode`] from the given `metadata`. - /// - /// # Safety - /// - /// This method assumes that the given `metadata` corresponds to a valid - /// instance of [LLVM `MDNode`](https://llvm.org/doxygen/classllvm_1_1MDNode.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any validation checks. - pub(crate) unsafe fn from_metadata_ref( + #[inline] + pub(crate) fn from_metadata_ref( context: LLVMContextRef, metadata: LLVMMetadataRef, - ) -> Self { - MDNode::from_value_ref(LLVMMetadataAsValue(context, metadata)) - } - - /// Constructs a new [`MDNode`] from the given `value`. - /// - /// # Safety - /// - /// This method assumes that the provided `value` corresponds to a valid - /// instance of [LLVM `MDNode`](https://llvm.org/doxygen/classllvm_1_1MDNode.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any valiation checks. - pub(crate) unsafe fn from_value_ref(value_ref: LLVMValueRef) -> Self { - Self { - value_ref, - _marker: PhantomData, - } + ) -> Result<Self, LLVMTypeError> { + let value_ref = unsafe { LLVMMetadataAsValue(context, metadata) }; + let value = NonNull::new(value_ref).ok_or(LLVMTypeError::NullPointer)?; + MDNode::from_ptr(value) } /// Constructs an empty metadata node. + #[inline] pub fn empty(context: LLVMContextRef) -> Self { let metadata = unsafe { LLVMMDNodeInContext2(context, core::ptr::null_mut(), 0) }; - unsafe { Self::from_metadata_ref(context, metadata) } + Self::from_metadata_ref(context, metadata).expect("expected a valid MDNode") } /// Constructs a new metadata node from an array of [`DIType`] elements. @@ -239,7 +315,7 @@ impl MDNode<'_> { let metadata = unsafe { let mut elements: Vec<LLVMMetadataRef> = elements .iter() - .map(|di_type| LLVMValueAsMetadata(di_type.value_ref)) + .map(|di_type| LLVMValueAsMetadata(di_type.as_ptr())) .collect(); LLVMMDNodeInContext2( context, @@ -247,7 +323,7 @@ impl MDNode<'_> { elements.len(), ) }; - unsafe { Self::from_metadata_ref(context, metadata) } + Self::from_metadata_ref(context, metadata).expect("expected a valid MDNode") } } @@ -289,51 +365,212 @@ impl Drop for MetadataEntries { } } -/// Represents a metadata node. -#[derive(Clone)] -pub struct Function<'ctx> { - pub value_ref: LLVMValueRef, +pub struct BasicBlock<'ctx> { + value: NonNull<LLVMBasicBlock>, _marker: PhantomData<&'ctx ()>, } -impl<'ctx> Function<'ctx> { - /// Constructs a new [`Function`] from the given `value`. - /// - /// # Safety - /// - /// This method assumes that the provided `value` corresponds to a valid - /// instance of [LLVM `Function`](https://llvm.org/doxygen/classllvm_1_1Function.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any valiation checks. - pub(crate) unsafe fn from_value_ref(value_ref: LLVMValueRef) -> Self { - Self { - value_ref, +impl LLVMTypeWrapper for BasicBlock<'_> { + type Target = LLVMBasicBlock; + + fn from_ptr(value: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> { + Ok(Self { + value, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +pub trait GlobalValue: LLVMTypeWrapper<Target = LLVMValue> { + fn set_linkage(&mut self, linkage: LLVMLinkage) { + unsafe { + LLVMSetLinkage(self.as_ptr(), linkage); + } + } + + fn set_visibility(&mut self, visibility: LLVMVisibility) { + unsafe { + LLVMSetVisibility(self.as_ptr(), visibility); + } + } +} + +/// Formal argument to a [`Function`]. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Argument<'ctx> { + value: NonNull<LLVMValue>, + _marker: PhantomData<&'ctx ()>, +} + +impl LLVMTypeWrapper for Argument<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> + where + Self: Sized, + { + if unsafe { LLVMIsAArgument(value.as_ptr()).is_null() } { + return Err(LLVMTypeError::InvalidPointerType("Argument")); + } + Ok(Self { + value, _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +/// Represents a function. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Function<'ctx> { + value: NonNull<LLVMValue>, + _marker: PhantomData<&'ctx ()>, +} + +impl LLVMTypeWrapper for Function<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> { + if unsafe { LLVMIsAFunction(value.as_ptr()).is_null() } { + return Err(LLVMTypeError::InvalidPointerType("Function")); } + Ok(Self { + value, + _marker: PhantomData, + }) } - pub(crate) fn name(&self) -> &str { - symbol_name(self.value_ref) + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() } +} - pub(crate) fn params(&self) -> impl Iterator<Item = LLVMValueRef> { - let params_count = unsafe { LLVMCountParams(self.value_ref) }; - let value = self.value_ref; - (0..params_count).map(move |i| unsafe { LLVMGetParam(value, i) }) +impl GlobalValue for Function<'_> {} + +impl<'ctx> Function<'ctx> { + pub(crate) fn name(&self) -> Cow<'_, str> { + symbol_name(self.value.as_ptr()) } - pub(crate) fn basic_blocks(&self) -> impl Iterator<Item = LLVMBasicBlockRef> + '_ { - self.value_ref.basic_blocks_iter() + pub(crate) fn params(&self) -> impl Iterator<Item = Argument> { + let params_count = unsafe { LLVMCountParams(self.value.as_ptr()) }; + let value = self.value.as_ptr(); + (0..params_count).map(move |i| { + let ptr = unsafe { LLVMGetParam(value, i) }; + Argument::from_ptr(NonNull::new(ptr).expect("an argument should not be null")).unwrap() + }) } pub(crate) fn subprogram(&self, context: LLVMContextRef) -> Option<DISubprogram<'ctx>> { - let subprogram = unsafe { LLVMGetSubprogram(self.value_ref) }; - NonNull::new(subprogram).map(|_| unsafe { - DISubprogram::from_value_ref(LLVMMetadataAsValue(context, subprogram)) - }) + let subprogram = unsafe { LLVMGetSubprogram(self.value.as_ptr()) }; + let subprogram = NonNull::new(subprogram)?; + let value = unsafe { LLVMMetadataAsValue(context, subprogram.as_ptr()) }; + let value = NonNull::new(value)?; + Some(DISubprogram::from_ptr(value).unwrap()) } pub(crate) fn set_subprogram(&mut self, subprogram: &DISubprogram) { - unsafe { LLVMSetSubprogram(self.value_ref, LLVMValueAsMetadata(subprogram.value_ref)) }; + unsafe { + LLVMSetSubprogram( + self.value.as_ptr(), + LLVMValueAsMetadata(subprogram.as_ptr()), + ) + }; + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct GlobalAlias<'ctx> { + value: NonNull<LLVMValue>, + _marker: PhantomData<&'ctx ()>, +} + +impl LLVMTypeWrapper for GlobalAlias<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> { + if unsafe { LLVMIsAGlobalAlias(value.as_ptr()).is_null() } { + return Err(LLVMTypeError::InvalidPointerType("GlobalAlias")); + } + Ok(Self { + value, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +impl GlobalValue for GlobalAlias<'_> {} + +impl GlobalAlias<'_> { + pub fn name<'a>(&self) -> Cow<'a, str> { + symbol_name(self.value.as_ptr()) + } +} + +/// Represents a global variable. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct GlobalVariable<'ctx> { + value: NonNull<LLVMValue>, + _marker: PhantomData<&'ctx ()>, +} + +impl LLVMTypeWrapper for GlobalVariable<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> { + if unsafe { LLVMIsAGlobalVariable(value.as_ptr()).is_null() } { + return Err(LLVMTypeError::InvalidPointerType("GlobalVariable")); + } + Ok(Self { + value, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +impl GlobalValue for GlobalVariable<'_> {} + +impl GlobalVariable<'_> { + pub fn name<'a>(&self) -> Cow<'a, str> { + symbol_name(self.value.as_ptr()) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Instruction<'ctx> { + value: NonNull<LLVMValue>, + _marker: PhantomData<&'ctx ()>, +} + +impl LLVMTypeWrapper for Instruction<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> + where + Self: Sized, + { + Ok(Self { + value, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() } } diff --git a/src/llvm/types/mod.rs b/src/llvm/types/mod.rs index ac874bda..207c9a70 100644 --- a/src/llvm/types/mod.rs +++ b/src/llvm/types/mod.rs @@ -1,2 +1,24 @@ +use std::ptr::NonNull; + +use thiserror::Error; + pub mod di; pub mod ir; + +#[derive(Debug, Error)] +pub enum LLVMTypeError { + #[error("invalid pointer type, expected {0}")] + InvalidPointerType(&'static str), + #[error("null pointer")] + NullPointer, +} + +pub trait LLVMTypeWrapper { + type Target: Sized; + + /// Constructs a new [`Self`] from the given pointer `ptr`. + fn from_ptr(ptr: NonNull<Self::Target>) -> Result<Self, LLVMTypeError> + where + Self: Sized; + fn as_ptr(&self) -> *mut Self::Target; +}