From bee14ad42839e7abcf10ffdc6f313896d79f3499 Mon Sep 17 00:00:00 2001 From: Jaemin Hong Date: Mon, 22 Jan 2024 08:24:10 +0000 Subject: [PATCH] intra function call --- src/graph.rs | 28 +++++--- src/relational/analysis.rs | 133 ++++++++++++++++++++++-------------- src/relational/domains.rs | 2 +- src/relational/semantics.rs | 95 ++++++++++++++++++++++++-- src/relational/tests.rs | 72 +++++++++++++++++++ src/steensgaard/mod.rs | 49 +++++++------ 6 files changed, 287 insertions(+), 92 deletions(-) diff --git a/src/graph.rs b/src/graph.rs index ffb33fa..89780fa 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -4,15 +4,22 @@ use etrace::some_or; use rustc_data_structures::graph::{scc::Sccs, vec_graph::VecGraph}; use rustc_index::Idx; -pub fn transitive_closure( - graph: &BTreeMap>, -) -> BTreeMap> { - let len = graph.len(); +pub fn transitive_closure( + graph: &HashMap>, +) -> HashMap> { + let id_to_v: Vec<_> = graph.keys().cloned().collect(); + let v_to_id: HashMap<_, _> = id_to_v + .iter() + .cloned() + .enumerate() + .map(|(k, v)| (v, k)) + .collect(); + let len = id_to_v.len(); let mut reachability = vec![vec![false; len]; len]; for (v, succs) in graph.iter() { for succ in succs { - reachability[v.index()][succ.index()] = true; + reachability[v_to_id[v]][v_to_id[succ]] = true; } } @@ -25,20 +32,20 @@ pub fn transitive_closure( } } - let mut new_graph = BTreeMap::new(); + let mut new_graph = HashMap::new(); for (i, reachability) in reachability.iter().enumerate() { let neighbors = reachability .iter() .enumerate() .filter_map(|(to, is_reachable)| { if *is_reachable { - Some(T::new(to)) + Some(id_to_v[to].clone()) } else { None } }) .collect(); - new_graph.insert(T::new(i), neighbors); + new_graph.insert(id_to_v[i].clone(), neighbors); } new_graph } @@ -104,9 +111,12 @@ pub fn inverse( map: &HashMap>, ) -> HashMap> { let mut inv: HashMap<_, HashSet<_>> = HashMap::new(); + for node in map.keys() { + inv.insert(node.clone(), HashSet::new()); + } for (node, succs) in map { for succ in succs { - inv.entry(succ.clone()).or_default().insert(node.clone()); + inv.get_mut(succ).unwrap().insert(node.clone()); } } inv diff --git a/src/relational/analysis.rs b/src/relational/analysis.rs index b8e7801..5d3a17b 100644 --- a/src/relational/analysis.rs +++ b/src/relational/analysis.rs @@ -10,7 +10,7 @@ use rustc_index::bit_set::BitSet; use rustc_middle::{ mir::{ interpret::ConstValue, visit::Visitor, BasicBlock, Body, ConstantKind, Local, Location, - Operand, Terminator, TerminatorKind, + Operand, Place, Rvalue, Terminator, TerminatorKind, }, ty::{AdtKind, Ty, TyCtxt, TyKind, TypeAndMut}, }; @@ -37,7 +37,7 @@ fn analyze_input(input: Input) -> AnalysisResults { pub fn analyze(tcx: TyCtxt<'_>) -> AnalysisResults { let hir = tcx.hir(); let may_aliases = steensgaard::analyze(tcx); - let var_graph = may_aliases.get_alias_graph(); + let alias_graph = may_aliases.get_alias_graph(); let func_ids: Vec<_> = hir .items() @@ -50,13 +50,18 @@ pub fn analyze(tcx: TyCtxt<'_>) -> AnalysisResults { }) .collect(); - let _call_graph: HashMap<_, _> = func_ids + let (indirect_assigns, call_graph): (HashMap<_, _>, HashMap<_, _>) = func_ids .iter() .map(|def_id| { let body = tcx.optimized_mir(def_id.to_def_id()); - (*def_id, get_callees(*def_id, body, &var_graph, tcx)) + let (indirect_assigns, callees) = visit_body(*def_id, body, &alias_graph, tcx); + ((*def_id, indirect_assigns), (*def_id, callees)) }) - .collect(); + .unzip(); + let mut reachability = graph::transitive_closure(&call_graph); + for (caller, callees) in &mut reachability { + callees.insert(*caller); + } let mut functions = HashMap::new(); for local_def_id in func_ids.iter().copied() { @@ -104,7 +109,9 @@ pub fn analyze(tcx: TyCtxt<'_>) -> AnalysisResults { local_tys, local_ptr_tys, local_def_id, - may_aliases: &may_aliases, + indirect_assigns: &indirect_assigns, + reachability: &reachability, + alias_graph: &alias_graph, }; functions.insert(def_id, analyzer.analyze()); } @@ -119,14 +126,16 @@ pub struct AnalysisResults { #[allow(missing_debug_implementations)] pub struct Analyzer<'tcx, 'a> { - tcx: TyCtxt<'tcx>, + pub tcx: TyCtxt<'tcx>, body: &'tcx Body<'tcx>, rpo_map: HashMap, dead_locals: Vec>, local_tys: Vec, local_ptr_tys: HashMap, local_def_id: LocalDefId, - may_aliases: &'a steensgaard::AnalysisResults, + indirect_assigns: &'a HashMap>, + reachability: &'a HashMap>, + alias_graph: &'a steensgaard::AliasGraph, } impl<'tcx> Analyzer<'tcx, '_> { @@ -183,39 +192,47 @@ impl<'tcx> Analyzer<'tcx, '_> { operand.ty(&self.body.local_decls, self.tcx) } - pub fn find_may_aliases(&self, local: Local) -> HashSet<(Local, u32)> { - let id = self.may_aliases.local(self.local_def_id, local.as_u32()); - let ty = self.may_aliases.var_ty(id); - - let mut aliases = HashSet::new(); - let mut done = HashSet::new(); - let mut remainings = HashSet::new(); - remainings.insert((ty.var_ty, 0u32)); - - while !remainings.is_empty() { - let mut new_remainings = HashSet::new(); - for (id, depth) in remainings { - done.insert(id); - for (id1, id2) in &self.may_aliases.vars { - if id != *id2 { - continue; - } - let steensgaard::VarId::Local(f, l) = *id1 else { continue }; - if f == self.local_def_id { - aliases.insert((Local::from_u32(l), depth)); - } - } - for (v, t) in &self.may_aliases.var_tys { - let steensgaard::VarType::Ref(t) = t else { continue }; - if t.var_ty == id && !done.contains(v) { - new_remainings.insert((*v, depth + 1)); - } + pub fn resolve_indirect_calls(&self, local: Local) -> HashSet { + self.alias_graph + .find_fn_may_aliases(self.local_def_id, local) + } + + pub fn locals_invalidated_by_call(&self, callee: LocalDefId) -> HashSet<(Local, usize)> { + self.reachability[&callee] + .iter() + .flat_map(|func| { + self.indirect_assigns[func].iter().flat_map(|local| { + self.alias_graph + .find_may_aliases(*func, *local) + .into_iter() + .filter_map(|alias| { + if alias.function == self.local_def_id { + Some((alias.local, alias.depth)) + } else { + None + } + }) + }) + }) + .collect() + } + + pub fn find_may_aliases(&self, local: Local) -> HashSet<(Local, usize)> { + self.alias_graph + .find_may_aliases(self.local_def_id, local) + .into_iter() + .filter_map(|alias| { + if alias.function == self.local_def_id { + Some((alias.local, alias.depth)) + } else { + None } - } - remainings = new_remainings; - } + }) + .collect() + } - aliases + pub fn def_id_to_string(&self, def_id: DefId) -> String { + self.tcx.def_path(def_id).to_string_no_crate_verbose() } } @@ -439,22 +456,35 @@ fn get_dead_locals<'tcx>(body: &Body<'tcx>, tcx: TyCtxt<'tcx>) -> Vec { +struct BodyVisitor<'tcx, 'a> { tcx: TyCtxt<'tcx>, - callees: HashSet, current_fn: LocalDefId, var_graph: &'a steensgaard::AliasGraph, + + indirect_assigns: HashSet, + callees: HashSet, } -impl CallVisitor<'_, '_> { +impl BodyVisitor<'_, '_> { fn def_id_to_string(&self, def_id: DefId) -> String { self.tcx.def_path(def_id).to_string_no_crate_verbose() } } -impl<'tcx> Visitor<'tcx> for CallVisitor<'tcx, '_> { +impl<'tcx> Visitor<'tcx> for BodyVisitor<'tcx, '_> { + fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) { + if place.is_indirect_first_projection() { + self.indirect_assigns.insert(place.local); + } + self.super_assign(place, rvalue, location); + } + fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) { - if let TerminatorKind::Call { func, .. } = &terminator.kind { + if let TerminatorKind::Call { + func, destination, .. + } = &terminator.kind + { + assert!(destination.projection.is_empty()); match func { Operand::Copy(f) | Operand::Move(f) => { assert!(f.projection.is_empty()); @@ -479,18 +509,19 @@ impl<'tcx> Visitor<'tcx> for CallVisitor<'tcx, '_> { } } -fn get_callees<'tcx>( +fn visit_body<'tcx>( current_fn: LocalDefId, body: &Body<'tcx>, - var_graph: &steensgaard::AliasGraph, + alias_graph: &steensgaard::AliasGraph, tcx: TyCtxt<'tcx>, -) -> HashSet { - let mut visitor = CallVisitor { +) -> (HashSet, HashSet) { + let mut visitor = BodyVisitor { tcx, - callees: HashSet::new(), current_fn, - var_graph, + var_graph: alias_graph, + indirect_assigns: HashSet::new(), + callees: HashSet::new(), }; visitor.visit_body(body); - visitor.callees + (visitor.indirect_assigns, visitor.callees) } diff --git a/src/relational/domains.rs b/src/relational/domains.rs index 95280aa..01e10a3 100644 --- a/src/relational/domains.rs +++ b/src/relational/domains.rs @@ -588,7 +588,7 @@ impl Graph { self.nodes[loc.root].obj.project_mut(&loc.projection) } - pub fn invalidate_deref(&mut self, local: Local, mut depth: u32, opt_id: Option) { + pub fn invalidate_deref(&mut self, local: Local, mut depth: usize, opt_id: Option) { let id = *some_or!(self.locals.get(&local), return); let mut locs = vec![AbsLoc::new(id, vec![])]; while !locs.is_empty() { diff --git a/src/relational/semantics.rs b/src/relational/semantics.rs index 3ecd16a..c3793eb 100644 --- a/src/relational/semantics.rs +++ b/src/relational/semantics.rs @@ -9,6 +9,7 @@ use rustc_middle::{ }, ty::{adjustment::PointerCoercion, IntTy, Ty, TyKind, UintTy}, }; +use rustc_span::def_id::LocalDefId; use super::*; use crate::*; @@ -319,17 +320,101 @@ impl<'tcx> Analyzer<'tcx, '_> { .collect(), }, TerminatorKind::Call { - func: _func, - args: _args, - destination: _destination, - target: Some(_), + func, + args, + destination, + target: Some(target), .. } => { - todo!() + assert!(destination.projection.is_empty()); + let mut state = state.clone(); + state.gm().invalidate_symbolic(destination.local); + let location = Location { + block: *target, + statement_index: 0, + }; + match func { + Operand::Copy(func) | Operand::Move(func) => { + assert!(func.projection.is_empty()); + let callees = self.resolve_indirect_calls(func.local); + self.transfer_intra_call(&callees, &mut state); + } + Operand::Constant(box constant) => { + let ConstantKind::Val(value, ty) = constant.literal else { unreachable!() }; + assert!(matches!(value, ConstValue::ZeroSized)); + let TyKind::FnDef(def_id, _) = ty.kind() else { unreachable!() }; + let name = self.def_id_to_string(*def_id); + let mut segs: Vec<_> = name.split("::").collect(); + let seg0 = segs.pop().unwrap_or_default(); + let seg1 = segs.pop().unwrap_or_default(); + let seg2 = segs.pop().unwrap_or_default(); + let seg3 = segs.pop().unwrap_or_default(); + let sig = self.tcx.fn_sig(def_id).skip_binder(); + let inputs = sig.inputs().skip_binder(); + if let Some(local_def_id) = def_id.as_local() { + if let Some(impl_def_id) = self.tcx.impl_of_method(*def_id) { + let span = self.tcx.span_of_impl(impl_def_id).unwrap(); + let code = + self.tcx.sess.source_map().span_to_snippet(span).unwrap(); + assert_eq!(code, "BitfieldStruct"); + } else if seg1.contains("{extern#") { + self.transfer_c_call(seg0, inputs, args, &mut state); + } else { + self.transfer_intra_call( + &HashSet::from_iter([local_def_id]), + &mut state, + ); + } + } else { + self.transfer_rust_call( + (seg3, seg2, seg1, seg0), + inputs, + args, + &mut state, + ); + } + } + } + vec![(location, state)] } _ => unreachable!(), } } + + fn transfer_intra_call(&self, callees: &HashSet, state: &mut AbsMem) { + let graph = state.gm(); + for callee in callees { + for (local, depth) in self.locals_invalidated_by_call(*callee) { + graph.invalidate_deref(local, depth, None); + } + } + } + + fn transfer_c_call( + &self, + _name: &str, + _inputs: &[Ty<'_>], + _args: &[Operand<'_>], + _state: &mut AbsMem, + ) { + todo!() + } + + fn transfer_rust_call( + &self, + name: (&str, &str, &str, &str), + inputs: &[Ty<'_>], + _args: &[Operand<'_>], + _state: &mut AbsMem, + ) { + if inputs.iter().all(|t| t.is_primitive()) { + return; + } + match name { + ("", "option", _, "unwrap") => {} + _ => todo!("{:?}", name), + } + } } #[derive(Debug, Clone)] diff --git a/src/relational/tests.rs b/src/relational/tests.rs index fd4a856..039c695 100644 --- a/src/relational/tests.rs +++ b/src/relational/tests.rs @@ -817,3 +817,75 @@ fn test_deref_eq_invalidate() { }, ); } + +#[test] +fn test_call_invalidate() { + // _2 = const 0_i32 + // (*_1) = move _2 + // + // _1 = const 0_i32 + // _3 = &mut _1 + // _2 = &raw mut (*_3) + // _4 = foo::f(_2) -> [return: bb1, unwind continue] + analyze_fn( + " + unsafe fn f(mut x: *mut libc::c_int) { + *x = 0 as libc::c_int; + } + let mut x: libc::c_int = 0 as libc::c_int; + let mut y: *mut libc::c_int = &mut x; + f(y); + ", + |g, _, _| { + let n = get_nodes(&g, 1..=3); + let i = get_ids(&g, 1..=3); + assert_eq!(n[&2].as_ptr(), &AbsLoc::new_root(i[&1])); + assert_eq!(n[&3].as_ptr(), &AbsLoc::new_root(i[&1])); + + assert_eq!(g.get_local_as_int(1), None); + }, + ); +} + +#[test] +fn test_indirect_call_invalidate() { + // switchInt(move _1) -> [0: bb2, otherwise: bb1] + // _3 = foo::f as unsafe fn(*mut i32) (PointerCoercion(ReifyFnPointer)) + // _2 = std::option::Option::::Some(move _3) + // goto -> bb3 + // _4 = foo::g as unsafe fn(*mut i32) (PointerCoercion(ReifyFnPointer)) + // _2 = std::option::Option::::Some(move _4) + // goto -> bb3 + // _5 = const 0_i32 + // _7 = &mut _5 + // _6 = &raw mut (*_7) + // _10 = _2 + // _9 = std::option::Option::::unwrap(move _10) + // _8 = move _9(_6) -> [return: bb5, unwind continue] + analyze_fn_with( + "", + "mut x: libc::c_int", + " + unsafe fn f(mut x: *mut libc::c_int) { + *x = 0 as libc::c_int; + } + unsafe fn g(mut x: *mut libc::c_int) {} + let mut h: Option:: ()> = if x != 0 { + Some(f as unsafe fn(*mut libc::c_int) -> ()) + } else { + Some(g as unsafe fn(*mut libc::c_int) -> ()) + }; + let mut y: libc::c_int = 0 as libc::c_int; + let mut z: *mut libc::c_int = &mut y; + h.unwrap()(z); + ", + |g, _, _| { + let n = get_nodes(&g, 5..=7); + let i = get_ids(&g, 5..=7); + assert_eq!(n[&6].as_ptr(), &AbsLoc::new_root(i[&5])); + assert_eq!(n[&7].as_ptr(), &AbsLoc::new_root(i[&5])); + + assert_eq!(g.get_local_as_int(5), None); + }, + ); +} diff --git a/src/steensgaard/mod.rs b/src/steensgaard/mod.rs index fd49e2b..1c34ba8 100644 --- a/src/steensgaard/mod.rs +++ b/src/steensgaard/mod.rs @@ -403,8 +403,8 @@ struct Analyzer<'tcx, 'a> { } pub struct AnalysisResults { - pub vars: HashMap, - pub var_tys: HashMap, + vars: HashMap, + var_tys: HashMap, fns: HashMap, fn_tys: HashMap, } @@ -452,28 +452,17 @@ impl AnalysisResults { pub fn get_alias_graph(&self) -> AliasGraph { let id_to_node = self.vars.clone(); - let node_to_ids = graph::inverse(&to_graph(&id_to_node)); - let points_to: HashMap<_, _> = self + let node_to_ids = inv(&id_to_node); + let (points_to, points_to_fn): (HashMap<_, _>, HashMap<_, _>) = self .var_tys .iter() .filter_map(|(k, v)| { let VarType::Ref(ty) = v else { return None }; - Some((*k, ty.var_ty)) + Some(((*k, ty.var_ty), (*k, ty.fn_ty))) }) - .collect(); - let pointed_by = graph::inverse(&to_graph(&points_to)); - let points_to_fn: HashMap<_, _> = self - .var_tys - .iter() - .filter_map(|(k, v)| { - let VarType::Ref(ty) = v else { return None }; - Some((*k, ty.fn_ty)) - }) - .collect(); - let mut pointed_by_fn: HashMap<_, HashSet<_>> = HashMap::new(); - for (var_id, fn_id) in &points_to_fn { - pointed_by_fn.entry(*fn_id).or_default().insert(*var_id); - } + .unzip(); + let pointed_by = inv(&points_to); + let pointed_by_fn = inv(&points_to_fn); AliasGraph { id_to_node, node_to_ids, @@ -485,10 +474,16 @@ impl AnalysisResults { } } -fn to_graph(map: &HashMap) -> HashMap> { - map.iter() - .map(|(k, v)| (*k, HashSet::from_iter([*v]))) - .collect() +fn inv(map: &HashMap) -> HashMap> +where + T: Eq + std::hash::Hash + Copy, + S: Eq + std::hash::Hash + Copy, +{ + let mut new_map: HashMap<_, HashSet<_>> = HashMap::new(); + for (k, v) in map { + new_map.entry(*v).or_default().insert(*k); + } + new_map } #[derive(Debug)] @@ -532,9 +527,11 @@ impl AliasGraph { }; aliases.insert(alias); } - for node in &self.pointed_by[&node] { - if !done.contains(node) { - new_remainings.insert((*node, depth + 1)); + if let Some(nodes) = self.pointed_by.get(&node) { + for node in nodes { + if !done.contains(node) { + new_remainings.insert((*node, depth + 1)); + } } } }