From ff12197f75f54b15a6f10ba13784dd44b66d6e73 Mon Sep 17 00:00:00 2001 From: Michal Rostecki Date: Fri, 20 Dec 2024 17:47:25 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=92=F0=9F=8C=AF=20Provide=20`LLVMTypeW?= =?UTF-8?q?rapper`=20trait,=20add=20more=20wrappers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of keeping raw pointer fields (of types like `LLVMValueRef`, `LLVMMetadataRef`) public and defining constructors with different names, provide the `LLVMTypeWrapper` trait with `from_ptr()` and `as_ptr()` methods. This makes it possible to convert all safe wrappers from and to raw pointers with one method, which is going to be helpful for building macros for them. On top of that, add more wrappers: * `Module` * `BasicBlock` * `GlobalAlias` * `GlobalVariable` * `Argument` Use these wrappers in iterators and make sure they don't expose the raw pointers to the callers. --- src/linker.rs | 69 ++++--- src/llvm/di.rs | 94 +++++---- src/llvm/iter.rs | 137 +++++++++---- src/llvm/mod.rs | 80 +++----- src/llvm/types/di.rs | 311 ++++++++++++++++++----------- src/llvm/types/ir.rs | 449 ++++++++++++++++++++++++++++++++---------- src/llvm/types/mod.rs | 22 +++ 7 files changed, 785 insertions(+), 377 deletions(-) diff --git a/src/linker.rs b/src/linker.rs index 8c183f68..a93459e9 100644 --- a/src/linker.rs +++ b/src/linker.rs @@ -14,18 +14,15 @@ use std::{ use ar::Archive; use llvm_sys::{ bit_writer::LLVMWriteBitcodeToFile, - core::{ - LLVMContextCreate, LLVMContextDispose, LLVMContextSetDiagnosticHandler, LLVMDisposeModule, - LLVMGetTarget, - }, + core::{LLVMContextCreate, LLVMContextDispose, LLVMContextSetDiagnosticHandler, LLVMGetTarget}, error_handling::{LLVMEnablePrettyStackTrace, LLVMInstallFatalErrorHandler}, - prelude::{LLVMContextRef, LLVMModuleRef}, + prelude::LLVMContextRef, target_machine::{LLVMCodeGenFileType, LLVMDisposeTargetMachine, LLVMTargetMachineRef}, }; use thiserror::Error; use tracing::{debug, error, info, warn}; 
-use crate::llvm; +use crate::llvm::{self, LLVMTypeError, LLVMTypeWrapper, Module}; /// Linker error #[derive(Debug, Error)] @@ -77,6 +74,10 @@ pub enum LinkerError { /// The input object file does not have embedded bitcode. #[error("no bitcode section found in {0}")] MissingBitcodeSection(PathBuf), + + /// Instantiating of an LLVM type failed. + #[error(transparent)] + LLVMType(#[from] LLVMTypeError), } /// BPF Cpu type @@ -224,21 +225,26 @@ pub struct LinkerOptions { } /// BPF Linker -pub struct Linker { +pub struct Linker<'ctx> { options: LinkerOptions, context: LLVMContextRef, - module: LLVMModuleRef, + module: Module<'ctx>, target_machine: LLVMTargetMachineRef, diagnostic_handler: DiagnosticHandler, } -impl Linker { +impl Linker<'_> { /// Create a new linker instance with the given options. pub fn new(options: LinkerOptions) -> Self { + let context = unsafe { LLVMContextCreate() }; + let module = Module::new( + options.output.file_stem().unwrap().to_str().unwrap(), + context, + ); Linker { options, - context: ptr::null_mut(), - module: ptr::null_mut(), + context, + module, target_machine: ptr::null_mut(), diagnostic_handler: DiagnosticHandler::new(), } @@ -365,7 +371,7 @@ impl Linker { Archive => panic!("nested archives not supported duh"), }; - if unsafe { !llvm::link_bitcode_buffer(self.context, self.module, &bitcode) } { + if unsafe { !llvm::link_bitcode_buffer(self.context, self.module.as_ptr(), &bitcode) } { return Err(LinkerError::LinkModuleError(path.to_owned())); } @@ -407,11 +413,11 @@ impl Linker { }) } None => { - let c_triple = unsafe { LLVMGetTarget(*module) }; + let c_triple = unsafe { LLVMGetTarget(module.as_ptr()) }; let triple = unsafe { CStr::from_ptr(c_triple) }.to_str().unwrap(); if triple.starts_with("bpf") { // case 2 - (triple, unsafe { llvm::target_from_module(*module) }) + (triple, unsafe { llvm::target_from_module(module.as_ptr()) }) } else { // case 3. 
info!("detected non-bpf input target {} and no explicit output --target specified, selecting `bpf'", triple); @@ -452,17 +458,18 @@ impl Linker { if self.options.btf { // if we want to emit BTF, we need to sanitize the debug information - llvm::DISanitizer::new(self.context, self.module).run(&self.options.export_symbols); + llvm::DISanitizer::new(self.context, &self.module) + .run(&mut self.module, &self.options.export_symbols)?; } else { // if we don't need BTF emission, we can strip DI - let ok = unsafe { llvm::strip_debug_info(self.module) }; + let ok = unsafe { llvm::strip_debug_info(self.module.as_ptr()) }; debug!("Stripping DI, changed={}", ok); } unsafe { llvm::optimize( self.target_machine, - self.module, + &mut self.module, self.options.optimize, self.options.ignore_inline_never, &self.options.export_symbols, @@ -486,7 +493,7 @@ impl Linker { fn write_bitcode(&mut self, output: &CStr) -> Result<(), LinkerError> { info!("writing bitcode to {:?}", output); - if unsafe { LLVMWriteBitcodeToFile(self.module, output.as_ptr()) } == 1 { + if unsafe { LLVMWriteBitcodeToFile(self.module.as_ptr(), output.as_ptr()) } == 1 { return Err(LinkerError::WriteBitcodeError); } @@ -496,14 +503,21 @@ impl Linker { fn write_ir(&mut self, output: &CStr) -> Result<(), LinkerError> { info!("writing IR to {:?}", output); - unsafe { llvm::write_ir(self.module, output) }.map_err(LinkerError::WriteIRError) + unsafe { llvm::write_ir(self.module.as_ptr(), output) }.map_err(LinkerError::WriteIRError) } fn emit(&mut self, output: &CStr, output_type: LLVMCodeGenFileType) -> Result<(), LinkerError> { info!("emitting {:?} to {:?}", output_type, output); - unsafe { llvm::codegen(self.target_machine, self.module, output, output_type) } - .map_err(LinkerError::EmitCodeError) + unsafe { + llvm::codegen( + self.target_machine, + self.module.as_ptr(), + output, + output_type, + ) + } + .map_err(LinkerError::EmitCodeError) } fn llvm_init(&mut self) { @@ -542,24 +556,23 @@ impl Linker { ); 
LLVMInstallFatalErrorHandler(Some(llvm::fatal_error)); LLVMEnablePrettyStackTrace(); - self.module = llvm::create_module( + self.module = Module::new( self.options.output.file_stem().unwrap().to_str().unwrap(), self.context, - ) - .unwrap(); + ); } } } -impl Drop for Linker { +impl Drop for Linker<'_> { fn drop(&mut self) { unsafe { if !self.target_machine.is_null() { LLVMDisposeTargetMachine(self.target_machine); } - if !self.module.is_null() { - LLVMDisposeModule(self.module); - } + // if !self.module.is_null() { + // LLVMDisposeModule(self.module); + // } if !self.context.is_null() { LLVMContextDispose(self.context); } diff --git a/src/llvm/di.rs b/src/llvm/di.rs index c4a939d4..2767b625 100644 --- a/src/llvm/di.rs +++ b/src/llvm/di.rs @@ -3,18 +3,23 @@ use std::{ collections::{hash_map::DefaultHasher, HashMap, HashSet}, ffi::c_char, hash::Hasher, - ptr, + ptr::{self, NonNull}, }; use gimli::{DW_TAG_pointer_type, DW_TAG_structure_type, DW_TAG_variant_part}; use llvm_sys::{core::*, debuginfo::*, prelude::*}; use tracing::{span, trace, warn, Level}; -use super::types::{ - di::DIType, - ir::{Function, MDNode, Metadata, Value}, +use crate::llvm::{ + iter::*, + types::{ + di::{DISubprogram, DIType}, + ir::{Function, Instruction, MDNode, Metadata, Module, Value}, + LLVMTypeError, LLVMTypeWrapper, + }, }; -use crate::llvm::{iter::*, types::di::DISubprogram}; + +use super::types::ir::{Argument, GlobalAlias, GlobalVariable}; // KSYM_NAME_LEN from linux kernel intentionally set // to lower value found accross kernel versions to ensure @@ -23,7 +28,6 @@ const MAX_KSYM_NAME_LEN: usize = 128; pub struct DISanitizer { context: LLVMContextRef, - module: LLVMModuleRef, builder: LLVMDIBuilderRef, visited_nodes: HashSet, replace_operands: HashMap, @@ -59,11 +63,11 @@ fn sanitize_type_name>(name: T) -> String { } impl DISanitizer { - pub fn new(context: LLVMContextRef, module: LLVMModuleRef) -> DISanitizer { + pub fn new(context: LLVMContextRef, module: &Module<'_>) -> Self { + 
let builder = unsafe { LLVMCreateDIBuilder(module.as_ptr()) }; DISanitizer { context, - module, - builder: unsafe { LLVMCreateDIBuilder(module) }, + builder, visited_nodes: HashSet::new(), replace_operands: HashMap::new(), skipped_types: Vec::new(), @@ -227,7 +231,7 @@ impl DISanitizer { // An operand with no value is valid and means that the operand is // not set (v, Item::Operand { .. }) if v.is_null() => return, - (v, _) if !v.is_null() => Value::new(v), + (v, _) if !v.is_null() => Value::from_ptr(NonNull::new(v).unwrap()).unwrap(), // All other items should have values (_, item) => panic!("{item:?} has no value"), }; @@ -283,16 +287,18 @@ impl DISanitizer { } } - pub fn run(mut self, exported_symbols: &HashSet>) { - let module = self.module; - - self.replace_operands = self.fix_subprogram_linkage(exported_symbols); + pub fn run( + mut self, + module: &mut Module<'_>, + exported_symbols: &HashSet>, + ) -> Result<(), LLVMTypeError> { + self.replace_operands = self.fix_subprogram_linkage(module, exported_symbols)?; - for value in module.globals_iter() { - self.visit_item(Item::GlobalVariable(value)); + for global in module.globals_iter() { + self.visit_item(Item::GlobalVariable(global)); } - for value in module.global_aliases_iter() { - self.visit_item(Item::GlobalAlias(value)); + for alias in module.global_aliases_iter() { + self.visit_item(Item::GlobalAlias(alias)); } for function in module.functions_iter() { @@ -307,6 +313,8 @@ impl DISanitizer { } unsafe { LLVMDisposeDIBuilder(self.builder) }; + + Ok(()) } // Make it so that only exported symbols (programs marked as #[no_mangle]) get BTF @@ -324,16 +332,13 @@ impl DISanitizer { // See tests/btf/assembly/exported-symbols.rs . 
fn fix_subprogram_linkage( &mut self, + module: &mut Module<'_>, export_symbols: &HashSet>, - ) -> HashMap { + ) -> Result, LLVMTypeError> { let mut replace = HashMap::new(); - for mut function in self - .module - .functions_iter() - .map(|value| unsafe { Function::from_value_ref(value) }) - { - if export_symbols.contains(function.name()) { + for mut function in module.functions_iter() { + if export_symbols.contains(&function.name()) { continue; } @@ -370,7 +375,10 @@ impl DISanitizer { // replace retained nodes manually below. LLVMDIBuilderFinalizeSubprogram(self.builder, new_program); - DISubprogram::from_value_ref(LLVMMetadataAsValue(self.context, new_program)) + let new_program = LLVMMetadataAsValue(self.context, new_program); + let new_program = + NonNull::new(new_program).expect("new program should not be null"); + DISubprogram::from_ptr(new_program)? }; // Point the function to the new subprogram. @@ -396,23 +404,23 @@ impl DISanitizer { unsafe { LLVMMDNodeInContext2(self.context, core::ptr::null_mut(), 0) }; subprogram.set_retained_nodes(empty_node); - let ret = replace.insert(subprogram.value_ref as u64, unsafe { - LLVMValueAsMetadata(new_program.value_ref) + let ret = replace.insert(subprogram.as_ptr() as u64, unsafe { + LLVMValueAsMetadata(new_program.as_ptr()) }); assert!(ret.is_none()); } - replace + Ok(replace) } } #[derive(Clone, Debug, Eq, PartialEq)] -enum Item { - GlobalVariable(LLVMValueRef), - GlobalAlias(LLVMValueRef), - Function(LLVMValueRef), - FunctionParam(LLVMValueRef), - Instruction(LLVMValueRef), +enum Item<'ctx> { + GlobalVariable(GlobalVariable<'ctx>), + GlobalAlias(GlobalAlias<'ctx>), + Function(Function<'ctx>), + FunctionParam(Argument<'ctx>), + Instruction(Instruction<'ctx>), Operand(Operand), MetadataEntry(LLVMValueRef, u32, usize), } @@ -437,16 +445,16 @@ impl Operand { } } -impl Item { +impl Item<'_> { fn value_ref(&self) -> LLVMValueRef { match self { - Item::GlobalVariable(value) - | Item::GlobalAlias(value) - | 
Item::Function(value) - | Item::FunctionParam(value) - | Item::Instruction(value) - | Item::Operand(Operand { value, .. }) - | Item::MetadataEntry(value, _, _) => *value, + Item::GlobalVariable(global) => global.as_ptr(), + Item::GlobalAlias(global) => global.as_ptr(), + Item::Function(function) => function.as_ptr(), + Item::FunctionParam(function_param) => function_param.as_ptr(), + Item::Instruction(instruction) => instruction.as_ptr(), + Item::Operand(Operand { value, .. }) => *value, + Item::MetadataEntry(value, _, _) => *value, } } diff --git a/src/llvm/iter.rs b/src/llvm/iter.rs index e1d77198..940b2da1 100644 --- a/src/llvm/iter.rs +++ b/src/llvm/iter.rs @@ -3,55 +3,98 @@ use std::marker::PhantomData; use llvm_sys::{ core::{ LLVMGetFirstBasicBlock, LLVMGetFirstFunction, LLVMGetFirstGlobal, LLVMGetFirstGlobalAlias, - LLVMGetFirstInstruction, LLVMGetLastBasicBlock, LLVMGetLastFunction, LLVMGetLastGlobal, - LLVMGetLastGlobalAlias, LLVMGetLastInstruction, LLVMGetNextBasicBlock, LLVMGetNextFunction, - LLVMGetNextGlobal, LLVMGetNextGlobalAlias, LLVMGetNextInstruction, + LLVMGetFirstInstruction, LLVMGetNextBasicBlock, LLVMGetNextFunction, LLVMGetNextGlobal, + LLVMGetNextGlobalAlias, LLVMGetNextInstruction, LLVMGetPreviousBasicBlock, + LLVMGetPreviousFunction, LLVMGetPreviousGlobal, LLVMGetPreviousGlobalAlias, + LLVMGetPreviousInstruction, }, - prelude::{LLVMBasicBlockRef, LLVMModuleRef, LLVMValueRef}, + LLVMBasicBlock, LLVMValue, +}; + +use crate::llvm::types::ir::{ + BasicBlock, Function, GlobalAlias, GlobalVariable, Instruction, Module, }; macro_rules! llvm_iterator { - ($trait_name:ident, $iterator_name:ident, $iterable:ty, $method_name:ident, $item_ty:ty, $first:expr, $last:expr, $next:expr $(,)?) => { + ( + $trait_name:ident, + $iterator_name:ident, + $iterable:ident, + $method_name:ident, + $ptr_ty:ty, + $item_ty:ident, + $first:expr, + $last:expr, + $next:expr, + $prev:expr $(,)? 
+ ) => { pub trait $trait_name { fn $method_name(&self) -> $iterator_name; } - pub struct $iterator_name<'a> { - lifetime: PhantomData<&'a $iterable>, - next: $item_ty, - last: $item_ty, + pub struct $iterator_name<'ctx> { + lifetime: PhantomData<&'ctx $iterable<'ctx>>, + current: Option<::std::ptr::NonNull<$ptr_ty>>, } - impl $trait_name for $iterable { + impl $trait_name for $iterable<'_> { fn $method_name(&self) -> $iterator_name { - let first = unsafe { $first(*self) }; - let last = unsafe { $last(*self) }; - assert_eq!(first.is_null(), last.is_null()); + #[allow(unused_imports)] + use $crate::llvm::types::LLVMTypeWrapper as _; + + let first = unsafe { $first(self.as_ptr()) }; + let first = ::std::ptr::NonNull::new(first); + + // let last = unsafe { $last(*self) }; + // let last = std::ptr::NonNull::new(last).map(|ptr| $item_ty::from_ptr(ptr).unwrap()); + + // assert_eq!(first.is_none(), last.is_none()); + $iterator_name { lifetime: PhantomData, - next: first, - last, + current: first, + // last, } } } - impl<'a> Iterator for $iterator_name<'a> { - type Item = $item_ty; + impl<'ctx> Iterator for $iterator_name<'ctx> { + type Item = $item_ty<'ctx>; fn next(&mut self) -> Option { - let Self { - lifetime: _, - next, - last, - } = self; - if next.is_null() { - return None; + #[allow(unused_imports)] + use $crate::llvm::types::LLVMTypeWrapper as _; + + if let Some(item) = self.current { + let next = unsafe { $next(item.as_ptr()) }; + let next = std::ptr::NonNull::new(next); + self.current = next; + + let item = $item_ty::from_ptr(item).unwrap(); + + Some(item) + } else { + None + } + } + } + + impl<'ctx> DoubleEndedIterator for $iterator_name<'ctx> { + fn next_back(&mut self) -> Option { + #[allow(unused_imports)] + use $crate::llvm::types::LLVMTypeWrapper as _; + + if let Some(item) = self.current { + let prev = unsafe { $prev(item.as_ptr()) }; + let prev = std::ptr::NonNull::new(prev); + self.current = prev; + + let item = $item_ty::from_ptr(item).unwrap(); + + 
Some(item) + } else { + None } - let last = *next == *last; - let item = *next; - *next = unsafe { $next(*next) }; - assert_eq!(next.is_null(), last); - Some(item) } } }; @@ -60,54 +103,64 @@ macro_rules! llvm_iterator { llvm_iterator! { IterModuleGlobals, GlobalsIter, - LLVMModuleRef, + Module, globals_iter, - LLVMValueRef, + LLVMValue, + GlobalVariable, LLVMGetFirstGlobal, LLVMGetLastGlobal, LLVMGetNextGlobal, + LLVMGetPreviousGlobal, } llvm_iterator! { IterModuleGlobalAliases, GlobalAliasesIter, - LLVMModuleRef, + Module, global_aliases_iter, - LLVMValueRef, + LLVMValue, + GlobalAlias, LLVMGetFirstGlobalAlias, LLVMGetLastGlobalAlias, LLVMGetNextGlobalAlias, + LLVMGetPreviousGlobalAlias, } llvm_iterator! { IterModuleFunctions, FunctionsIter, - LLVMModuleRef, + Module, functions_iter, - LLVMValueRef, + LLVMValue, + Function, LLVMGetFirstFunction, LLVMGetLastFunction, LLVMGetNextFunction, + LLVMGetPreviousFunction, } llvm_iterator!( IterBasicBlocks, BasicBlockIter, - LLVMValueRef, - basic_blocks_iter, - LLVMBasicBlockRef, + Function, + basic_blocks, + LLVMBasicBlock, + BasicBlock, LLVMGetFirstBasicBlock, LLVMGetLastBasicBlock, - LLVMGetNextBasicBlock + LLVMGetNextBasicBlock, + LLVMGetPreviousBasicBlock, ); llvm_iterator!( IterInstructions, InstructionsIter, - LLVMBasicBlockRef, + BasicBlock, instructions_iter, - LLVMValueRef, + LLVMValue, + Instruction, LLVMGetFirstInstruction, LLVMGetLastInstruction, - LLVMGetNextInstruction + LLVMGetNextInstruction, + LLVMGetPreviousInstruction ); diff --git a/src/llvm/mod.rs b/src/llvm/mod.rs index 86de1b04..c1c13a49 100644 --- a/src/llvm/mod.rs +++ b/src/llvm/mod.rs @@ -1,7 +1,3 @@ -mod di; -mod iter; -mod types; - use std::{ borrow::Cow, collections::HashSet, @@ -18,9 +14,7 @@ use llvm_sys::{ core::{ LLVMCreateMemoryBufferWithMemoryRange, LLVMDisposeMemoryBuffer, LLVMDisposeMessage, LLVMGetDiagInfoDescription, LLVMGetDiagInfoSeverity, LLVMGetEnumAttributeKindForName, - LLVMGetMDString, LLVMGetModuleInlineAsm, LLVMGetTarget, 
LLVMGetValueName2, - LLVMModuleCreateWithNameInContext, LLVMPrintModuleToFile, LLVMRemoveEnumAttributeAtIndex, - LLVMSetLinkage, LLVMSetModuleInlineAsm2, LLVMSetVisibility, + LLVMGetMDString, LLVMGetTarget, LLVMPrintModuleToFile, LLVMRemoveEnumAttributeAtIndex, }, debuginfo::LLVMStripModuleDebugInfo, error::{ @@ -49,9 +43,15 @@ use llvm_sys::{ LLVMAttributeFunctionIndex, LLVMLinkage, LLVMVisibility, }; use tracing::{debug, error}; +use types::ir::{Function, GlobalValue}; use crate::OptLevel; +mod di; +mod iter; +mod types; +pub use types::{ir::Module, LLVMTypeError, LLVMTypeWrapper}; + pub unsafe fn init>(args: &[T], overview: &str) { LLVMInitializeBPFTarget(); LLVMInitializeBPFTargetMC(); @@ -73,17 +73,6 @@ unsafe fn parse_command_line_options>(args: &[T], overview: &str) LLVMParseCommandLineOptions(c_ptrs.len() as i32, c_ptrs.as_ptr(), overview.as_ptr()); } -pub unsafe fn create_module(name: &str, context: LLVMContextRef) -> Option { - let c_name = CString::new(name).unwrap(); - let module = LLVMModuleCreateWithNameInContext(c_name.as_ptr(), context); - - if module.is_null() { - return None; - } - - Some(module) -} - pub unsafe fn find_embedded_bitcode( context: LLVMContextRef, data: &[u8], @@ -192,29 +181,32 @@ pub unsafe fn create_target_machine( pub unsafe fn optimize( tm: LLVMTargetMachineRef, - module: LLVMModuleRef, + module: &mut Module<'_>, opt_level: OptLevel, ignore_inline_never: bool, export_symbols: &HashSet>, ) -> Result<(), String> { if module_asm_is_probestack(module) { - LLVMSetModuleInlineAsm2(module, ptr::null_mut(), 0); + module.set_inline_asm(""); } - for sym in module.globals_iter() { - internalize(sym, symbol_name(sym), export_symbols); + for mut sym in module.globals_iter() { + let name = sym.name(); + internalize(&mut sym, &name, export_symbols); } - for sym in module.global_aliases_iter() { - internalize(sym, symbol_name(sym), export_symbols); + for mut sym in module.global_aliases_iter() { + let name = sym.name(); + internalize(&mut sym, 
&name, export_symbols); } - for function in module.functions_iter() { - let name = symbol_name(function); - if !name.starts_with("llvm.") { + for mut function in module.functions_iter() { + let name = function.name().into_owned(); + let to_internalize = !name.starts_with("llvm."); + if to_internalize { if ignore_inline_never { - remove_attribute(function, "noinline"); + remove_attribute(&mut function, "noinline"); } - internalize(function, name, export_symbols); + internalize(&mut function, &name, export_symbols); } } @@ -239,7 +231,7 @@ pub unsafe fn optimize( debug!("running passes: {passes}"); let passes = CString::new(passes).unwrap(); let options = LLVMCreatePassBuilderOptions(); - let error = LLVMRunPasses(module, passes.as_ptr(), tm, options); + let error = LLVMRunPasses(module.as_ptr(), passes.as_ptr(), tm, options); LLVMDisposePassBuilderOptions(options); // Handle the error and print it to stderr. if !error.is_null() { @@ -260,26 +252,14 @@ pub unsafe fn strip_debug_info(module: LLVMModuleRef) -> bool { LLVMStripModuleDebugInfo(module) != 0 } -unsafe fn module_asm_is_probestack(module: LLVMModuleRef) -> bool { - let mut len = 0; - let ptr = LLVMGetModuleInlineAsm(module, &mut len); - if ptr.is_null() { - return false; - } - - let asm = String::from_utf8_lossy(slice::from_raw_parts(ptr as *const c_uchar, len)); +fn module_asm_is_probestack(module: &Module) -> bool { + let asm = module.inline_asm(); asm.contains("__rust_probestack") } -fn symbol_name<'a>(value: *mut llvm_sys::LLVMValue) -> &'a str { - let mut name_len = 0; - let ptr = unsafe { LLVMGetValueName2(value, &mut name_len) }; - unsafe { str::from_utf8(slice::from_raw_parts(ptr as *const c_uchar, name_len)).unwrap() } -} - -unsafe fn remove_attribute(function: *mut llvm_sys::LLVMValue, name: &str) { +unsafe fn remove_attribute(function: &mut Function, name: &str) { let attr_kind = LLVMGetEnumAttributeKindForName(name.as_ptr() as *const c_char, name.len()); - LLVMRemoveEnumAttributeAtIndex(function, 
LLVMAttributeFunctionIndex, attr_kind); + LLVMRemoveEnumAttributeAtIndex(function.as_ptr(), LLVMAttributeFunctionIndex, attr_kind); } pub unsafe fn write_ir(module: LLVMModuleRef, output: &CStr) -> Result<(), String> { @@ -308,14 +288,14 @@ pub unsafe fn codegen( } } -pub unsafe fn internalize( - value: LLVMValueRef, +pub unsafe fn internalize( + value: &mut T, name: &str, export_symbols: &HashSet>, ) { if !name.starts_with("llvm.") && !export_symbols.contains(name) { - LLVMSetLinkage(value, LLVMLinkage::LLVMInternalLinkage); - LLVMSetVisibility(value, LLVMVisibility::LLVMDefaultVisibility); + value.set_linkage(LLVMLinkage::LLVMInternalLinkage); + value.set_visibility(LLVMVisibility::LLVMDefaultVisibility); } } diff --git a/src/llvm/types/di.rs b/src/llvm/types/di.rs index 0edc4c61..8c804211 100644 --- a/src/llvm/types/di.rs +++ b/src/llvm/types/di.rs @@ -11,14 +11,18 @@ use llvm_sys::{ debuginfo::{ LLVMDIFileGetFilename, LLVMDIFlags, LLVMDIScopeGetFile, LLVMDISubprogramGetLine, LLVMDITypeGetFlags, LLVMDITypeGetLine, LLVMDITypeGetName, LLVMDITypeGetOffsetInBits, - LLVMGetDINodeTag, + LLVMGetDINodeTag, LLVMGetMetadataKind, LLVMMetadataKind, }, - prelude::{LLVMContextRef, LLVMMetadataRef, LLVMValueRef}, + prelude::{LLVMContextRef, LLVMMetadataRef}, + LLVMOpaqueMetadata, LLVMValue, }; use crate::llvm::{ mdstring_to_str, - types::ir::{MDNode, Metadata}, + types::{ + ir::{MDNode, Metadata}, + LLVMTypeError, LLVMTypeWrapper, + }, }; /// Returns a DWARF tag for the given debug info node. @@ -41,26 +45,30 @@ unsafe fn di_node_tag(metadata_ref: LLVMMetadataRef) -> DwTag { /// A `DIFile` debug info node, which represents a given file, is referenced by /// other debug info nodes which belong to the file. pub struct DIFile<'ctx> { - pub(super) metadata_ref: LLVMMetadataRef, + metadata: NonNull, _marker: PhantomData<&'ctx ()>, } -impl DIFile<'_> { - /// Constructs a new [`DIFile`] from the given `metadata`. 
- /// - /// # Safety - /// - /// This method assumes that the given `metadata` corresponds to a valid - /// instance of [LLVM `DIFile`](https://llvm.org/doxygen/classllvm_1_1DIFile.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any validation checks. - pub(crate) unsafe fn from_metadata_ref(metadata_ref: LLVMMetadataRef) -> Self { - Self { - metadata_ref, - _marker: PhantomData, +impl LLVMTypeWrapper for DIFile<'_> { + type Target = LLVMOpaqueMetadata; + + fn from_ptr(metadata: NonNull) -> Result { + let metadata_kind = unsafe { LLVMGetMetadataKind(metadata.as_ptr()) }; + if !matches!(metadata_kind, LLVMMetadataKind::LLVMDIFileMetadataKind) { + return Err(LLVMTypeError::InvalidPointerType("DIFile")); } + Ok(Self { + metadata, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.metadata.as_ptr() } +} +impl DIFile<'_> { pub fn filename(&self) -> Option<&CStr> { let mut len = 0; // `LLVMDIFileGetName` doesn't allocate any memory, it just returns @@ -69,7 +77,7 @@ impl DIFile<'_> { // // Therefore, we don't need to call `LLVMDisposeMessage`. The memory // gets freed when calling `LLVMDisposeDIBuilder`. - let ptr = unsafe { LLVMDIFileGetFilename(self.metadata_ref, &mut len) }; + let ptr = unsafe { LLVMDIFileGetFilename(self.metadata.as_ptr(), &mut len) }; NonNull::new(ptr as *mut _).map(|ptr| unsafe { CStr::from_ptr(ptr.as_ptr()) }) } } @@ -109,39 +117,57 @@ unsafe fn di_type_name<'a>(metadata_ref: LLVMMetadataRef) -> Option<&'a CStr> { /// Represents the debug information for a primitive type in LLVM IR. pub struct DIType<'ctx> { - pub(super) metadata_ref: LLVMMetadataRef, - pub(super) value_ref: LLVMValueRef, + metadata: NonNull, + value: NonNull, _marker: PhantomData<&'ctx ()>, } -impl DIType<'_> { - /// Constructs a new [`DIType`] from the given `value`. 
- /// - /// # Safety - /// - /// This method assumes that the given `value` corresponds to a valid - /// instance of [LLVM `DIType`](https://llvm.org/doxygen/classllvm_1_1DIType.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any validation checks. - pub unsafe fn from_value_ref(value_ref: LLVMValueRef) -> Self { - let metadata_ref = unsafe { LLVMValueAsMetadata(value_ref) }; - Self { - metadata_ref, - value_ref, - _marker: PhantomData, +impl LLVMTypeWrapper for DIType<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull) -> Result { + let metadata = unsafe { LLVMValueAsMetadata(value.as_ptr()) }; + let metadata = NonNull::new(metadata).ok_or(LLVMTypeError::NullPointer)?; + let metadata_kind = unsafe { LLVMGetMetadataKind(metadata.as_ptr()) }; + // The children of `DIType` are: + // + // - `DIBasicType` + // - `DICompositeType` + // - `DIDerivedType` + // - `DIStringType` + // - `DISubroutineType` + // + // https://llvm.org/doxygen/classllvm_1_1DIType.html + match metadata_kind { + LLVMMetadataKind::LLVMDIBasicTypeMetadataKind + | LLVMMetadataKind::LLVMDICompositeTypeMetadataKind + | LLVMMetadataKind::LLVMDIDerivedTypeMetadataKind + | LLVMMetadataKind::LLVMDIStringTypeMetadataKind + | LLVMMetadataKind::LLVMDISubroutineTypeMetadataKind => Ok(Self { + metadata, + value, + _marker: PhantomData, + }), + _ => Err(LLVMTypeError::InvalidPointerType("DIType")), } } + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +impl DIType<'_> { /// Returns the offset of the type in bits. This offset is used in case the /// type is a member of a composite type. 
pub fn offset_in_bits(&self) -> usize { - unsafe { LLVMDITypeGetOffsetInBits(self.metadata_ref) as usize } + unsafe { LLVMDITypeGetOffsetInBits(self.metadata.as_ptr()) as usize } } } impl<'ctx> From> for DIType<'ctx> { fn from(di_derived_type: DIDerivedType) -> Self { - unsafe { Self::from_value_ref(di_derived_type.value_ref) } + Self::from_ptr(di_derived_type.value).unwrap() } } @@ -160,34 +186,45 @@ enum DIDerivedTypeOperand { /// alternative name. The examples of derived types are pointers, references, /// typedefs, etc. pub struct DIDerivedType<'ctx> { - metadata_ref: LLVMMetadataRef, - value_ref: LLVMValueRef, + metadata: NonNull, + value: NonNull, _marker: PhantomData<&'ctx ()>, } -impl DIDerivedType<'_> { - /// Constructs a new [`DIDerivedType`] from the given `value`. - /// - /// # Safety - /// - /// This method assumes that the provided `value` corresponds to a valid - /// instance of [LLVM `DIDerivedType`](https://llvm.org/doxygen/classllvm_1_1DIDerivedType.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any validation checks. - pub unsafe fn from_value_ref(value_ref: LLVMValueRef) -> Self { - let metadata_ref = LLVMValueAsMetadata(value_ref); - Self { - metadata_ref, - value_ref, - _marker: PhantomData, +impl LLVMTypeWrapper for DIDerivedType<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull) -> Result { + let metadata = unsafe { LLVMValueAsMetadata(value.as_ptr()) }; + let metadata = NonNull::new(metadata).ok_or(LLVMTypeError::NullPointer)?; + let metadata_kind = unsafe { LLVMGetMetadataKind(metadata.as_ptr()) }; + if !matches!( + metadata_kind, + LLVMMetadataKind::LLVMDIDerivedTypeMetadataKind, + ) { + return Err(LLVMTypeError::InvalidPointerType("DIDerivedType")); } + Ok(Self { + metadata, + value, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() } +} +impl DIDerivedType<'_> { /// Returns the base type of this derived type. 
pub fn base_type(&self) -> Metadata { unsafe { - let value = LLVMGetOperand(self.value_ref, DIDerivedTypeOperand::BaseType as u32); - Metadata::from_value_ref(value) + let value = LLVMGetOperand(self.value.as_ptr(), DIDerivedTypeOperand::BaseType as u32); + let value = NonNull::new(value).expect("base type operand should not be null"); + // PANICS: We are sure that the pointer type is correct. There is + // no need to leak the error. + Metadata::from_value(value).unwrap() } } @@ -198,12 +235,17 @@ impl DIDerivedType<'_> { /// Returns a `NulError` if the new name contains a NUL byte, as it cannot /// be converted into a `CString`. pub fn replace_name(&mut self, context: LLVMContextRef, name: &str) -> Result<(), NulError> { - super::ir::replace_name(self.value_ref, context, DITypeOperand::Name as u32, name) + super::ir::replace_name( + self.value.as_ptr(), + context, + DITypeOperand::Name as u32, + name, + ) } /// Returns a DWARF tag of the given derived type. pub fn tag(&self) -> DwTag { - unsafe { di_node_tag(self.metadata_ref) } + unsafe { di_node_tag(self.metadata.as_ptr()) } } } @@ -221,63 +263,82 @@ enum DICompositeTypeOperand { /// Composite type is a kind of type that can include other types, such as /// structures, enums, unions, etc. pub struct DICompositeType<'ctx> { - metadata_ref: LLVMMetadataRef, - value_ref: LLVMValueRef, + metadata: NonNull, + value: NonNull, _marker: PhantomData<&'ctx ()>, } -impl DICompositeType<'_> { - /// Constructs a new [`DICompositeType`] from the given `value`. - /// - /// # Safety - /// - /// This method assumes that the provided `value` corresponds to a valid - /// instance of [LLVM `DICompositeType`](https://llvm.org/doxygen/classllvm_1_1DICompositeType.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any validation checks. 
- pub unsafe fn from_value_ref(value_ref: LLVMValueRef) -> Self { - let metadata_ref = LLVMValueAsMetadata(value_ref); - Self { - metadata_ref, - value_ref, - _marker: PhantomData, +impl LLVMTypeWrapper for DICompositeType<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull) -> Result { + let metadata = unsafe { LLVMValueAsMetadata(value.as_ptr()) }; + // PANICS: value is NonNull. Hence, the result of LLVMValueAsMetadata + // should not be null, unless LLVM is severely broken. + let metadata = NonNull::new(metadata).expect("metadata pointer should not be null"); + let metadata_kind = unsafe { LLVMGetMetadataKind(metadata.as_ptr()) }; + if !matches!( + metadata_kind, + LLVMMetadataKind::LLVMDICompositeTypeMetadataKind, + ) { + return Err(LLVMTypeError::InvalidPointerType("DICompositeType")); } + Ok(Self { + metadata, + value, + _marker: PhantomData, + }) } + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +impl DICompositeType<'_> { /// Returns an iterator over elements (struct fields, enum variants, etc.) /// of the composite type. pub fn elements(&self) -> impl Iterator { let elements = - unsafe { LLVMGetOperand(self.value_ref, DICompositeTypeOperand::Elements as u32) }; + unsafe { LLVMGetOperand(self.value.as_ptr(), DICompositeTypeOperand::Elements as u32) }; let operands = NonNull::new(elements) .map(|elements| unsafe { LLVMGetNumOperands(elements.as_ptr()) }) .unwrap_or(0); - (0..operands) - .map(move |i| unsafe { Metadata::from_value_ref(LLVMGetOperand(elements, i as u32)) }) + (0..operands).map(move |i| { + let operand = unsafe { LLVMGetOperand(elements, i as u32) }; + let operand = NonNull::new(operand).expect("element operand should not be null"); + // PANICS: We are sure that the pointer type is correct. There is + // no need to leak the error. + Metadata::from_value(operand).unwrap() + }) } /// Returns the name of the composite type. 
pub fn name(&self) -> Option<&CStr> { - unsafe { di_type_name(self.metadata_ref) } + unsafe { di_type_name(self.metadata.as_ptr()) } } /// Returns the file that the composite type belongs to. pub fn file(&self) -> DIFile { unsafe { - let metadata = LLVMDIScopeGetFile(self.metadata_ref); - DIFile::from_metadata_ref(metadata) + let metadata = LLVMDIScopeGetFile(self.metadata.as_ptr()); + // PANICS: + let metadata = NonNull::new(metadata).expect("metadata pointer should not be null"); + // PANICS: We are sure that the pointer type is correct. There is + // no need to leak the error. + DIFile::from_ptr(metadata).unwrap() } } /// Returns the flags associated with the composity type. pub fn flags(&self) -> LLVMDIFlags { - unsafe { LLVMDITypeGetFlags(self.metadata_ref) } + unsafe { LLVMDITypeGetFlags(self.metadata.as_ptr()) } } /// Returns the line number in the source code where the type is defined. pub fn line(&self) -> u32 { - unsafe { LLVMDITypeGetLine(self.metadata_ref) } + unsafe { LLVMDITypeGetLine(self.metadata.as_ptr()) } } /// Replaces the elements of the composite type with a new metadata node. @@ -287,9 +348,9 @@ impl DICompositeType<'_> { pub fn replace_elements(&mut self, mdnode: MDNode) { unsafe { LLVMReplaceMDNodeOperandWith( - self.value_ref, + self.value.as_ptr(), DICompositeTypeOperand::Elements as u32, - LLVMValueAsMetadata(mdnode.value_ref), + LLVMValueAsMetadata(mdnode.as_ptr()), ) } } @@ -301,12 +362,17 @@ impl DICompositeType<'_> { /// Returns a `NulError` if the new name contains a NUL byte, as it cannot /// be converted into a `CString`. pub fn replace_name(&mut self, context: LLVMContextRef, name: &str) -> Result<(), NulError> { - super::ir::replace_name(self.value_ref, context, DITypeOperand::Name as u32, name) + super::ir::replace_name( + self.value.as_ptr(), + context, + DITypeOperand::Name as u32, + name, + ) } /// Returns a DWARF tag of the given composite type. 
pub fn tag(&self) -> DwTag { - unsafe { di_node_tag(self.metadata_ref) } + unsafe { di_node_tag(self.metadata.as_ptr()) } } } @@ -324,58 +390,70 @@ enum DISubprogramOperand { /// Represents the debug information for a subprogram (function) in LLVM IR. pub struct DISubprogram<'ctx> { - pub value_ref: LLVMValueRef, + value: NonNull, _marker: PhantomData<&'ctx ()>, } -impl DISubprogram<'_> { - /// Constructs a new [`DISubprogram`] from the given `value`. - /// - /// # Safety - /// - /// This method assumes that the provided `value` corresponds to a valid - /// instance of [LLVM `DISubprogram`](https://llvm.org/doxygen/classllvm_1_1DISubprogram.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any validation checks. - pub(crate) unsafe fn from_value_ref(value_ref: LLVMValueRef) -> Self { - DISubprogram { - value_ref, - _marker: PhantomData, +impl LLVMTypeWrapper for DISubprogram<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull) -> Result { + let metadata_ref = unsafe { LLVMValueAsMetadata(value.as_ptr()) }; + if metadata_ref.is_null() { + return Err(LLVMTypeError::NullPointer); + } + let metadata_kind = unsafe { LLVMGetMetadataKind(metadata_ref) }; + if !matches!( + metadata_kind, + LLVMMetadataKind::LLVMDISubprogramMetadataKind, + ) { + return Err(LLVMTypeError::InvalidPointerType("DISubprogram")); } + Ok(DISubprogram { + value, + _marker: PhantomData, + }) } + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +impl DISubprogram<'_> { /// Returns the name of the subprogram. pub fn name(&self) -> Option<&str> { - let operand = unsafe { LLVMGetOperand(self.value_ref, DISubprogramOperand::Name as u32) }; + let operand = + unsafe { LLVMGetOperand(self.value.as_ptr(), DISubprogramOperand::Name as u32) }; NonNull::new(operand).map(|_| mdstring_to_str(operand)) } /// Returns the linkage name of the subprogram. 
pub fn linkage_name(&self) -> Option<&str> { let operand = - unsafe { LLVMGetOperand(self.value_ref, DISubprogramOperand::LinkageName as u32) }; + unsafe { LLVMGetOperand(self.value.as_ptr(), DISubprogramOperand::LinkageName as u32) }; NonNull::new(operand).map(|_| mdstring_to_str(operand)) } pub fn ty(&self) -> LLVMMetadataRef { unsafe { LLVMValueAsMetadata(LLVMGetOperand( - self.value_ref, + self.value.as_ptr(), DISubprogramOperand::Ty as u32, )) } } pub fn file(&self) -> LLVMMetadataRef { - unsafe { LLVMDIScopeGetFile(LLVMValueAsMetadata(self.value_ref)) } + unsafe { LLVMDIScopeGetFile(LLVMValueAsMetadata(self.value.as_ptr())) } } pub fn line(&self) -> u32 { - unsafe { LLVMDISubprogramGetLine(LLVMValueAsMetadata(self.value_ref)) } + unsafe { LLVMDISubprogramGetLine(LLVMValueAsMetadata(self.value.as_ptr())) } } pub fn type_flags(&self) -> i32 { - unsafe { LLVMDITypeGetFlags(LLVMValueAsMetadata(self.value_ref)) } + unsafe { LLVMDITypeGetFlags(LLVMValueAsMetadata(self.value.as_ptr())) } } /// Replaces the name of the subprogram with a new name. @@ -386,7 +464,7 @@ impl DISubprogram<'_> { /// be converted into a `CString`. 
pub fn replace_name(&mut self, context: LLVMContextRef, name: &str) -> Result<(), NulError> { super::ir::replace_name( - self.value_ref, + self.value.as_ptr(), context, DISubprogramOperand::Name as u32, name, @@ -395,27 +473,34 @@ impl DISubprogram<'_> { pub fn scope(&self) -> Option { unsafe { - let operand = LLVMGetOperand(self.value_ref, DISubprogramOperand::Scope as u32); + let operand = LLVMGetOperand(self.value.as_ptr(), DISubprogramOperand::Scope as u32); NonNull::new(operand).map(|_| LLVMValueAsMetadata(operand)) } } pub fn unit(&self) -> Option { unsafe { - let operand = LLVMGetOperand(self.value_ref, DISubprogramOperand::Unit as u32); + let operand = LLVMGetOperand(self.value.as_ptr(), DISubprogramOperand::Unit as u32); NonNull::new(operand).map(|_| LLVMValueAsMetadata(operand)) } } pub fn set_unit(&mut self, unit: LLVMMetadataRef) { unsafe { - LLVMReplaceMDNodeOperandWith(self.value_ref, DISubprogramOperand::Unit as u32, unit) + LLVMReplaceMDNodeOperandWith( + self.value.as_ptr(), + DISubprogramOperand::Unit as u32, + unit, + ) }; } pub fn retained_nodes(&self) -> Option { unsafe { - let nodes = LLVMGetOperand(self.value_ref, DISubprogramOperand::RetainedNodes as u32); + let nodes = LLVMGetOperand( + self.value.as_ptr(), + DISubprogramOperand::RetainedNodes as u32, + ); NonNull::new(nodes).map(|_| LLVMValueAsMetadata(nodes)) } } @@ -423,7 +508,7 @@ impl DISubprogram<'_> { pub fn set_retained_nodes(&mut self, nodes: LLVMMetadataRef) { unsafe { LLVMReplaceMDNodeOperandWith( - self.value_ref, + self.value.as_ptr(), DISubprogramOperand::RetainedNodes as u32, nodes, ) diff --git a/src/llvm/types/ir.rs b/src/llvm/types/ir.rs index e68cd43a..145ae829 100644 --- a/src/llvm/types/ir.rs +++ b/src/llvm/types/ir.rs @@ -1,31 +1,93 @@ use std::{ - ffi::{CString, NulError}, + borrow::Cow, + ffi::{c_uchar, CString, NulError}, marker::PhantomData, ptr::NonNull, + slice, }; use llvm_sys::{ core::{ - LLVMCountParams, LLVMDisposeValueMetadataEntries, LLVMGetNumOperands, 
LLVMGetOperand, - LLVMGetParam, LLVMGlobalCopyAllMetadata, LLVMIsAFunction, LLVMIsAGlobalObject, - LLVMIsAInstruction, LLVMIsAMDNode, LLVMIsAUser, LLVMMDNodeInContext2, - LLVMMDStringInContext2, LLVMMetadataAsValue, LLVMPrintValueToString, - LLVMReplaceMDNodeOperandWith, LLVMValueAsMetadata, LLVMValueMetadataEntriesGetKind, - LLVMValueMetadataEntriesGetMetadata, + LLVMCountParams, LLVMDisposeValueMetadataEntries, LLVMGetModuleInlineAsm, + LLVMGetNumOperands, LLVMGetOperand, LLVMGetParam, LLVMGetValueName2, + LLVMGlobalCopyAllMetadata, LLVMIsAArgument, LLVMIsAFunction, LLVMIsAGlobalAlias, + LLVMIsAGlobalObject, LLVMIsAGlobalVariable, LLVMIsAInstruction, LLVMIsAMDNode, LLVMIsAUser, + LLVMMDNodeInContext2, LLVMMDStringInContext2, LLVMMetadataAsValue, + LLVMModuleCreateWithNameInContext, LLVMPrintValueToString, LLVMReplaceMDNodeOperandWith, + LLVMSetLinkage, LLVMSetModuleInlineAsm2, LLVMSetVisibility, LLVMValueAsMetadata, + LLVMValueMetadataEntriesGetKind, LLVMValueMetadataEntriesGetMetadata, }, debuginfo::{LLVMGetMetadataKind, LLVMGetSubprogram, LLVMMetadataKind, LLVMSetSubprogram}, - prelude::{ - LLVMBasicBlockRef, LLVMContextRef, LLVMMetadataRef, LLVMValueMetadataEntry, LLVMValueRef, - }, + prelude::{LLVMContextRef, LLVMMetadataRef, LLVMValueMetadataEntry, LLVMValueRef}, + LLVMBasicBlock, LLVMLinkage, LLVMModule, LLVMValue, LLVMVisibility, }; use crate::llvm::{ - iter::IterBasicBlocks as _, - symbol_name, - types::di::{DICompositeType, DIDerivedType, DISubprogram, DIType}, + types::{ + di::{DICompositeType, DIDerivedType, DISubprogram, DIType}, + LLVMTypeError, LLVMTypeWrapper, + }, Message, }; +pub struct Module<'ctx> { + module: NonNull, + _marker: PhantomData<&'ctx ()>, +} + +impl LLVMTypeWrapper for Module<'_> { + type Target = LLVMModule; + + fn from_ptr(module: NonNull) -> Result + where + Self: Sized, + { + Ok(Self { + module, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.module.as_ptr() + } +} + +impl Module<'_> { + pub 
fn new(name: &str, context: LLVMContextRef) -> Self { + let name = CString::new(name).unwrap(); + let module = unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), context) }; + // PANICS: + let module = NonNull::new(module).unwrap(); + Self { + module, + _marker: PhantomData, + } + } + + pub fn inline_asm(&self) -> Cow<'_, str> { + let mut len = 0; + let ptr = unsafe { LLVMGetModuleInlineAsm(self.module.as_ptr(), &mut len) }; + let asm = unsafe { slice::from_raw_parts(ptr as *const c_uchar, len) }; + String::from_utf8_lossy(asm) + } + + pub fn set_inline_asm(&mut self, asm: &str) { + let len = asm.len(); + let asm = CString::new(asm).unwrap(); + unsafe { + LLVMSetModuleInlineAsm2(self.module.as_ptr(), asm.as_ptr(), len); + } + } +} + +pub(crate) fn symbol_name<'a>(value: LLVMValueRef) -> Cow<'a, str> { + let mut len = 0; + let ptr = unsafe { LLVMGetValueName2(value, &mut len) }; + let symbol_name = unsafe { slice::from_raw_parts(ptr as *const c_uchar, len) }; + String::from_utf8_lossy(symbol_name) +} + pub(crate) fn replace_name( value_ref: LLVMValueRef, context: LLVMContextRef, @@ -42,7 +104,7 @@ pub(crate) fn replace_name( pub enum Value<'ctx> { MDNode(MDNode<'ctx>), Function(Function<'ctx>), - Other(LLVMValueRef), + Other(NonNull), } impl std::fmt::Debug for Value<'_> { @@ -60,45 +122,59 @@ impl std::fmt::Debug for Value<'_> { match self { Self::MDNode(node) => f .debug_struct("MDNode") - .field("value", &value_to_string(node.value_ref)) + .field("value", &value_to_string(node.value.as_ptr())) .finish(), Self::Function(fun) => f .debug_struct("Function") - .field("value", &value_to_string(fun.value_ref)) + .field("value", &value_to_string(fun.value.as_ptr())) .finish(), Self::Other(value) => f .debug_struct("Other") - .field("value", &value_to_string(*value)) + .field("value", &value_to_string(value.as_ptr())) .finish(), } } } -impl Value<'_> { - pub fn new(value: LLVMValueRef) -> Self { - if unsafe { !LLVMIsAMDNode(value).is_null() } { - let mdnode = 
unsafe { MDNode::from_value_ref(value) }; - return Value::MDNode(mdnode); - } else if unsafe { !LLVMIsAFunction(value).is_null() } { - return Value::Function(unsafe { Function::from_value_ref(value) }); +impl LLVMTypeWrapper for Value<'_> { + type Target = LLVMValue; + + fn from_ptr(value_ref: NonNull) -> Result { + if unsafe { !LLVMIsAMDNode(value_ref.as_ptr()).is_null() } { + let mdnode = MDNode::from_ptr(value_ref)?; + return Ok(Value::MDNode(mdnode)); + } else if unsafe { !LLVMIsAFunction(value_ref.as_ptr()).is_null() } { + return Ok(Value::Function(Function::from_ptr(value_ref)?)); + } + Ok(Value::Other(value_ref)) + } + + fn as_ptr(&self) -> *mut Self::Target { + match self { + Value::MDNode(mdnode) => mdnode.as_ptr(), + Value::Function(f) => f.as_ptr(), + Value::Other(value) => value.as_ptr(), } - Value::Other(value) } +} +impl Value<'_> { pub fn metadata_entries(&self) -> Option { let value = match self { - Value::MDNode(node) => node.value_ref, - Value::Function(f) => f.value_ref, - Value::Other(value) => *value, + Value::MDNode(node) => node.value.as_ptr(), + Value::Function(f) => f.value.as_ptr(), + Value::Other(value) => value.as_ptr(), }; MetadataEntries::new(value) } pub fn operands(&self) -> Option> { let value = match self { - Value::MDNode(node) => Some(node.value_ref), - Value::Function(f) => Some(f.value_ref), - Value::Other(value) if unsafe { !LLVMIsAUser(*value).is_null() } => Some(*value), + Value::MDNode(node) => Some(node.value.as_ptr()), + Value::Function(f) => Some(f.value.as_ptr()), + Value::Other(value) if unsafe { !LLVMIsAUser(value.as_ptr()).is_null() } => { + Some(value.as_ptr()) + } _ => None, }; @@ -112,7 +188,7 @@ pub enum Metadata<'ctx> { DICompositeType(DICompositeType<'ctx>), DIDerivedType(DIDerivedType<'ctx>), DISubprogram(DISubprogram<'ctx>), - Other(#[allow(dead_code)] LLVMValueRef), + Other(#[allow(dead_code)] NonNull), } impl Metadata<'_> { @@ -124,21 +200,21 @@ impl Metadata<'_> { /// instance of [LLVM 
`Metadata`](https://llvm.org/doxygen/classllvm_1_1Metadata.html). /// It's the caller's responsibility to ensure this invariant, as this /// method doesn't perform any valiation checks. - pub(crate) unsafe fn from_value_ref(value: LLVMValueRef) -> Self { - let metadata = LLVMValueAsMetadata(value); + pub(crate) fn from_value(value: NonNull) -> Result { + let metadata = unsafe { LLVMValueAsMetadata(value.as_ptr()) }; match unsafe { LLVMGetMetadataKind(metadata) } { LLVMMetadataKind::LLVMDICompositeTypeMetadataKind => { - let di_composite_type = unsafe { DICompositeType::from_value_ref(value) }; - Metadata::DICompositeType(di_composite_type) + let di_composite_type = DICompositeType::from_ptr(value)?; + Ok(Metadata::DICompositeType(di_composite_type)) } LLVMMetadataKind::LLVMDIDerivedTypeMetadataKind => { - let di_derived_type = unsafe { DIDerivedType::from_value_ref(value) }; - Metadata::DIDerivedType(di_derived_type) + let di_derived_type = DIDerivedType::from_ptr(value)?; + Ok(Metadata::DIDerivedType(di_derived_type)) } LLVMMetadataKind::LLVMDISubprogramMetadataKind => { - let di_subprogram = unsafe { DISubprogram::from_value_ref(value) }; - Metadata::DISubprogram(di_subprogram) + let di_subprogram = DISubprogram::from_ptr(value)?; + Ok(Metadata::DISubprogram(di_subprogram)) } LLVMMetadataKind::LLVMDIGlobalVariableMetadataKind | LLVMMetadataKind::LLVMDICommonBlockMetadataKind @@ -172,62 +248,64 @@ impl Metadata<'_> { | LLVMMetadataKind::LLVMDIStringTypeMetadataKind | LLVMMetadataKind::LLVMDIGenericSubrangeMetadataKind | LLVMMetadataKind::LLVMDIArgListMetadataKind - | LLVMMetadataKind::LLVMDIAssignIDMetadataKind => Metadata::Other(value), + | LLVMMetadataKind::LLVMDIAssignIDMetadataKind => Ok(Metadata::Other(value)), } } } impl<'ctx> TryFrom> for Metadata<'ctx> { - type Error = (); + type Error = LLVMTypeError; fn try_from(md_node: MDNode) -> Result { // FIXME: fail if md_node isn't a Metadata node - Ok(unsafe { Self::from_value_ref(md_node.value_ref) }) + 
Self::from_value(md_node.value) } } /// Represents a metadata node. #[derive(Clone)] pub struct MDNode<'ctx> { - pub(super) value_ref: LLVMValueRef, + value: NonNull, _marker: PhantomData<&'ctx ()>, } +impl LLVMTypeWrapper for MDNode<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull) -> Result { + if unsafe { LLVMIsAMDNode(value.as_ptr()).is_null() } { + return Err(LLVMTypeError::InvalidPointerType("MDNode")); + } + Ok(Self { + value, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + impl MDNode<'_> { /// Constructs a new [`MDNode`] from the given `metadata`. - /// - /// # Safety - /// - /// This method assumes that the given `metadata` corresponds to a valid - /// instance of [LLVM `MDNode`](https://llvm.org/doxygen/classllvm_1_1MDNode.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any validation checks. - pub(crate) unsafe fn from_metadata_ref( + #[inline] + pub(crate) fn from_metadata_ref( context: LLVMContextRef, metadata: LLVMMetadataRef, - ) -> Self { - MDNode::from_value_ref(LLVMMetadataAsValue(context, metadata)) - } - - /// Constructs a new [`MDNode`] from the given `value`. - /// - /// # Safety - /// - /// This method assumes that the provided `value` corresponds to a valid - /// instance of [LLVM `MDNode`](https://llvm.org/doxygen/classllvm_1_1MDNode.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any valiation checks. - pub(crate) unsafe fn from_value_ref(value_ref: LLVMValueRef) -> Self { - Self { - value_ref, - _marker: PhantomData, - } + ) -> Result { + let value_ref = unsafe { LLVMMetadataAsValue(context, metadata) }; + let value = NonNull::new(value_ref).ok_or(LLVMTypeError::NullPointer)?; + MDNode::from_ptr(value) } /// Constructs an empty metadata node. 
+ #[inline] pub fn empty(context: LLVMContextRef) -> Self { let metadata = unsafe { LLVMMDNodeInContext2(context, core::ptr::null_mut(), 0) }; - unsafe { Self::from_metadata_ref(context, metadata) } + // PANICS: We are sure that the pointer type is correct. There is + // no need to leak the error. + Self::from_metadata_ref(context, metadata).unwrap() } /// Constructs a new metadata node from an array of [`DIType`] elements. @@ -239,7 +317,7 @@ impl MDNode<'_> { let metadata = unsafe { let mut elements: Vec = elements .iter() - .map(|di_type| LLVMValueAsMetadata(di_type.value_ref)) + .map(|di_type| LLVMValueAsMetadata(di_type.as_ptr())) .collect(); LLVMMDNodeInContext2( context, @@ -247,7 +325,9 @@ impl MDNode<'_> { elements.len(), ) }; - unsafe { Self::from_metadata_ref(context, metadata) } + // PANICS: We are sure that the pointer type is correct. There is + // no need to leak the error. + Self::from_metadata_ref(context, metadata).unwrap() } } @@ -289,51 +369,218 @@ impl Drop for MetadataEntries { } } -/// Represents a metadata node. -#[derive(Clone)] -pub struct Function<'ctx> { - pub value_ref: LLVMValueRef, +pub struct BasicBlock<'ctx> { + value: NonNull, _marker: PhantomData<&'ctx ()>, } -impl<'ctx> Function<'ctx> { - /// Constructs a new [`Function`] from the given `value`. - /// - /// # Safety - /// - /// This method assumes that the provided `value` corresponds to a valid - /// instance of [LLVM `Function`](https://llvm.org/doxygen/classllvm_1_1Function.html). - /// It's the caller's responsibility to ensure this invariant, as this - /// method doesn't perform any valiation checks. 
- pub(crate) unsafe fn from_value_ref(value_ref: LLVMValueRef) -> Self { - Self { - value_ref, +impl LLVMTypeWrapper for BasicBlock<'_> { + type Target = LLVMBasicBlock; + + fn from_ptr(value: NonNull) -> Result { + Ok(Self { + value, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +pub trait GlobalValue: LLVMTypeWrapper { + fn set_linkage(&mut self, linkage: LLVMLinkage) { + unsafe { + LLVMSetLinkage(self.as_ptr(), linkage); + } + } + + fn set_visibility(&mut self, visibility: LLVMVisibility) { + unsafe { + LLVMSetVisibility(self.as_ptr(), visibility); + } + } +} + +/// Formal argument to a [`Function`]. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Argument<'ctx> { + value: NonNull, + _marker: PhantomData<&'ctx ()>, +} + +impl LLVMTypeWrapper for Argument<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull) -> Result + where + Self: Sized, + { + if unsafe { LLVMIsAArgument(value.as_ptr()).is_null() } { + return Err(LLVMTypeError::InvalidPointerType("Argument")); + } + Ok(Self { + value, _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +/// Represents a function. 
+#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Function<'ctx> { + value: NonNull, + _marker: PhantomData<&'ctx ()>, +} + +impl LLVMTypeWrapper for Function<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull) -> Result { + if unsafe { LLVMIsAFunction(value.as_ptr()).is_null() } { + return Err(LLVMTypeError::InvalidPointerType("Function")); } + Ok(Self { + value, + _marker: PhantomData, + }) } - pub(crate) fn name(&self) -> &str { - symbol_name(self.value_ref) + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() } +} - pub(crate) fn params(&self) -> impl Iterator { - let params_count = unsafe { LLVMCountParams(self.value_ref) }; - let value = self.value_ref; - (0..params_count).map(move |i| unsafe { LLVMGetParam(value, i) }) +impl GlobalValue for Function<'_> {} + +impl<'ctx> Function<'ctx> { + pub(crate) fn name(&self) -> Cow<'_, str> { + symbol_name(self.value.as_ptr()) } - pub(crate) fn basic_blocks(&self) -> impl Iterator + '_ { - self.value_ref.basic_blocks_iter() + pub(crate) fn params(&self) -> impl Iterator { + let params_count = unsafe { LLVMCountParams(self.value.as_ptr()) }; + let value = self.value.as_ptr(); + (0..params_count).map(move |i| { + let ptr = unsafe { LLVMGetParam(value, i) }; + // PANICS: We are sure that the pointer type is correct. There is + // no need to leak the error. 
+ Argument::from_ptr(NonNull::new(ptr).unwrap()).unwrap() + }) } + // pub(crate) fn basic_blocks(&self) -> impl Iterator + '_ { + // self.value.as_ptr().basic_blocks_iter() + // } + pub(crate) fn subprogram(&self, context: LLVMContextRef) -> Option> { - let subprogram = unsafe { LLVMGetSubprogram(self.value_ref) }; - NonNull::new(subprogram).map(|_| unsafe { - DISubprogram::from_value_ref(LLVMMetadataAsValue(context, subprogram)) - }) + let subprogram = unsafe { LLVMGetSubprogram(self.value.as_ptr()) }; + let subprogram = NonNull::new(subprogram)?; + let value = unsafe { LLVMMetadataAsValue(context, subprogram.as_ptr()) }; + let value = NonNull::new(value)?; + Some(DISubprogram::from_ptr(value).unwrap()) } pub(crate) fn set_subprogram(&mut self, subprogram: &DISubprogram) { - unsafe { LLVMSetSubprogram(self.value_ref, LLVMValueAsMetadata(subprogram.value_ref)) }; + unsafe { + LLVMSetSubprogram( + self.value.as_ptr(), + LLVMValueAsMetadata(subprogram.as_ptr()), + ) + }; + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct GlobalAlias<'ctx> { + value: NonNull, + _marker: PhantomData<&'ctx ()>, +} + +impl LLVMTypeWrapper for GlobalAlias<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull) -> Result { + if unsafe { LLVMIsAGlobalAlias(value.as_ptr()).is_null() } { + return Err(LLVMTypeError::InvalidPointerType("GlobalAlias")); + } + Ok(Self { + value, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +impl GlobalValue for GlobalAlias<'_> {} + +impl GlobalAlias<'_> { + pub fn name<'a>(&self) -> Cow<'a, str> { + symbol_name(self.value.as_ptr()) + } +} + +/// Represents a global variable. 
+#[derive(Clone, Debug, Eq, PartialEq)] +pub struct GlobalVariable<'ctx> { + value: NonNull, + _marker: PhantomData<&'ctx ()>, +} + +impl LLVMTypeWrapper for GlobalVariable<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull) -> Result { + if unsafe { LLVMIsAGlobalVariable(value.as_ptr()).is_null() } { + return Err(LLVMTypeError::InvalidPointerType("GlobalVariable")); + } + Ok(Self { + value, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() + } +} + +impl GlobalValue for GlobalVariable<'_> {} + +impl GlobalVariable<'_> { + pub fn name<'a>(&self) -> Cow<'a, str> { + symbol_name(self.value.as_ptr()) + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Instruction<'ctx> { + value: NonNull, + _marker: PhantomData<&'ctx ()>, +} + +impl LLVMTypeWrapper for Instruction<'_> { + type Target = LLVMValue; + + fn from_ptr(value: NonNull) -> Result + where + Self: Sized, + { + Ok(Self { + value, + _marker: PhantomData, + }) + } + + fn as_ptr(&self) -> *mut Self::Target { + self.value.as_ptr() } } diff --git a/src/llvm/types/mod.rs b/src/llvm/types/mod.rs index ac874bda..207c9a70 100644 --- a/src/llvm/types/mod.rs +++ b/src/llvm/types/mod.rs @@ -1,2 +1,24 @@ +use std::ptr::NonNull; + +use thiserror::Error; + pub mod di; pub mod ir; + +#[derive(Debug, Error)] +pub enum LLVMTypeError { + #[error("invalid pointer type, expected {0}")] + InvalidPointerType(&'static str), + #[error("null pointer")] + NullPointer, +} + +pub trait LLVMTypeWrapper { + type Target: Sized; + + /// Constructs a new [`Self`] from the given pointer `ptr`. + fn from_ptr(ptr: NonNull) -> Result + where + Self: Sized; + fn as_ptr(&self) -> *mut Self::Target; +}