Skip to content

Commit

Permalink
call graph construction
Browse files Browse the repository at this point in the history
  • Loading branch information
Medowhill committed Jan 19, 2024
1 parent 0613409 commit 19f5ba6
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 23 deletions.
5 changes: 1 addition & 4 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,9 @@ pub fn inverse<T: Clone + Eq + std::hash::Hash>(
map: &HashMap<T, HashSet<T>>,
) -> HashMap<T, HashSet<T>> {
let mut inv: HashMap<_, HashSet<_>> = HashMap::new();
for node in map.keys() {
inv.insert(node.clone(), HashSet::new());
}
for (node, succs) in map {
for succ in succs {
inv.get_mut(succ).unwrap().insert(node.clone());
inv.entry(succ.clone()).or_default().insert(node.clone());
}
}
inv
Expand Down
98 changes: 82 additions & 16 deletions src/relational/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ use rustc_data_structures::graph::WithSuccessors;
use rustc_hir::{def_id::DefId, ItemKind};
use rustc_index::bit_set::BitSet;
use rustc_middle::{
mir::{BasicBlock, Body, Local, Location, Operand, TerminatorKind},
mir::{
interpret::ConstValue, visit::Visitor, BasicBlock, Body, ConstantKind, Local, Location,
Operand, Terminator, TerminatorKind,
},
ty::{AdtKind, Ty, TyCtxt, TyKind, TypeAndMut},
};
use rustc_mir_dataflow::Analysis;
Expand All @@ -34,18 +37,29 @@ fn analyze_input(input: Input) -> AnalysisResults {
pub fn analyze(tcx: TyCtxt<'_>) -> AnalysisResults {
let hir = tcx.hir();
let may_aliases = steensgaard::analyze(tcx);
let var_graph = may_aliases.get_alias_graph();

let func_ids: Vec<_> = hir
.items()
.filter_map(|item_id| {
let item = hir.item(item_id);
if item.ident.name.as_str() == "main" || !matches!(item.kind, ItemKind::Fn(_, _, _)) {
return None;
}
Some(item_id.owner_id.def_id)
})
.collect();

let mut functions = HashMap::new();
for item_id in hir.items() {
let item = hir.item(item_id);
if item.ident.name.as_str() == "main" {
continue;
}
if !matches!(item.kind, ItemKind::Fn(_, _, _)) {
continue;
}
let _call_graph: HashMap<_, _> = func_ids
.iter()
.map(|def_id| {
let body = tcx.optimized_mir(def_id.to_def_id());
(*def_id, get_callees(*def_id, body, &var_graph, tcx))
})
.collect();

let local_def_id = item_id.owner_id.def_id;
let mut functions = HashMap::new();
for local_def_id in func_ids.iter().copied() {
let def_id = local_def_id.to_def_id();
let body = tcx.optimized_mir(def_id);

Expand Down Expand Up @@ -203,10 +217,6 @@ impl<'tcx> Analyzer<'tcx, '_> {

aliases
}

// fn def_id_to_string(&self, def_id: DefId) -> String {
// self.tcx.def_path(def_id).to_string_no_crate_verbose()
// }
}

#[derive(Debug)]
Expand Down Expand Up @@ -249,7 +259,7 @@ impl TyStructure {
if adt_def.adt_kind() == AdtKind::Enum {
let def_id = adt_def.did();
let name = tcx.def_path(def_id).to_string_no_crate_verbose();
assert!(name.contains("::Option") && def_id.is_local(), "{name}");
assert!(name.contains("::Option") && !def_id.is_local(), "{name}");
Self::Adt(vec![Self::Leaf])
} else {
let variant = &adt_def.variants()[VariantIdx::from_usize(0)];
Expand Down Expand Up @@ -428,3 +438,59 @@ fn get_dead_locals<'tcx>(body: &Body<'tcx>, tcx: TyCtxt<'tcx>) -> Vec<BitSet<Loc
})
.collect()
}

struct CallVisitor<'tcx, 'a> {
tcx: TyCtxt<'tcx>,
callees: HashSet<LocalDefId>,
current_fn: LocalDefId,
var_graph: &'a steensgaard::AliasGraph,
}

impl CallVisitor<'_, '_> {
fn def_id_to_string(&self, def_id: DefId) -> String {
self.tcx.def_path(def_id).to_string_no_crate_verbose()
}
}

impl<'tcx> Visitor<'tcx> for CallVisitor<'tcx, '_> {
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
if let TerminatorKind::Call { func, .. } = &terminator.kind {
match func {
Operand::Copy(f) | Operand::Move(f) => {
assert!(f.projection.is_empty());
let callees = self.var_graph.find_fn_may_aliases(self.current_fn, f.local);
self.callees.extend(callees);
}
Operand::Constant(box constant) => {
if let ConstantKind::Val(value, ty) = constant.literal {
assert_eq!(value, ConstValue::ZeroSized);
let TyKind::FnDef(def_id, _) = ty.kind() else { unreachable!() };
let name = self.def_id_to_string(*def_id);
if let Some(def_id) = def_id.as_local() {
if !name.contains("{extern#") {
self.callees.insert(def_id);
}
}
}
}
}
}
self.super_terminator(terminator, location);
}
}

fn get_callees<'tcx>(
current_fn: LocalDefId,
body: &Body<'tcx>,
var_graph: &steensgaard::AliasGraph,
tcx: TyCtxt<'tcx>,
) -> HashSet<LocalDefId> {
let mut visitor = CallVisitor {
tcx,
callees: HashSet::new(),
current_fn,
var_graph,
};
visitor.visit_body(body);
visitor.callees
}
115 changes: 112 additions & 3 deletions src/steensgaard/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use rustc_hir::{ForeignItemKind, ItemKind};
use rustc_middle::{
mir::{
interpret::{ConstValue, GlobalAlloc, Scalar},
BinOp, ConstantKind, Operand, Rvalue, Statement, StatementKind, Terminator, TerminatorKind,
BinOp, ConstantKind, Local, Operand, Rvalue, Statement, StatementKind, Terminator,
TerminatorKind,
},
ty::{Ty, TyCtxt, TyKind},
};
Expand Down Expand Up @@ -404,8 +405,8 @@ struct Analyzer<'tcx, 'a> {
pub struct AnalysisResults {
pub vars: HashMap<VarId, VarId>,
pub var_tys: HashMap<VarId, VarType>,
pub fns: HashMap<FnId, FnId>,
pub fn_tys: HashMap<FnId, FnType>,
fns: HashMap<FnId, FnId>,
fn_tys: HashMap<FnId, FnType>,
}

impl std::fmt::Debug for AnalysisResults {
Expand Down Expand Up @@ -448,6 +449,114 @@ impl AnalysisResults {
let VarType::Ref(ty) = self.var_tys[&id] else { panic!() };
ty
}

pub fn get_alias_graph(&self) -> AliasGraph {
let id_to_node = self.vars.clone();
let node_to_ids = graph::inverse(&to_graph(&id_to_node));
let points_to: HashMap<_, _> = self
.var_tys
.iter()
.filter_map(|(k, v)| {
let VarType::Ref(ty) = v else { return None };
Some((*k, ty.var_ty))
})
.collect();
let pointed_by = graph::inverse(&to_graph(&points_to));
let points_to_fn: HashMap<_, _> = self
.var_tys
.iter()
.filter_map(|(k, v)| {
let VarType::Ref(ty) = v else { return None };
Some((*k, ty.fn_ty))
})
.collect();
let mut pointed_by_fn: HashMap<_, HashSet<_>> = HashMap::new();
for (var_id, fn_id) in &points_to_fn {
pointed_by_fn.entry(*fn_id).or_default().insert(*var_id);
}
AliasGraph {
id_to_node,
node_to_ids,
points_to,
pointed_by,
points_to_fn,
fn_pointed_by: pointed_by_fn,
}
}
}

fn to_graph<T: Copy + Eq + std::hash::Hash>(map: &HashMap<T, T>) -> HashMap<T, HashSet<T>> {
map.iter()
.map(|(k, v)| (*k, HashSet::from_iter([*v])))
.collect()
}

#[derive(Debug)]
pub struct AliasGraph {
id_to_node: HashMap<VarId, VarId>,
node_to_ids: HashMap<VarId, HashSet<VarId>>,
points_to: HashMap<VarId, VarId>,
pointed_by: HashMap<VarId, HashSet<VarId>>,
points_to_fn: HashMap<VarId, FnId>,
fn_pointed_by: HashMap<FnId, HashSet<VarId>>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MayAlias {
pub function: LocalDefId,
pub local: Local,
pub depth: usize,
}

impl AliasGraph {
pub fn find_may_aliases(&self, f: LocalDefId, l: Local) -> HashSet<MayAlias> {
let id = VarId::Local(f, l.as_u32());
let node = self.id_to_node[&id];
let pointed_node = self.points_to[&node];

let mut aliases = HashSet::new();
let mut done = HashSet::new();
let mut remainings = HashSet::new();
remainings.insert((pointed_node, 0));

while !remainings.is_empty() {
let mut new_remainings = HashSet::new();
for (node, depth) in remainings {
done.insert(node);
for node in &self.node_to_ids[&node] {
let VarId::Local(f, l) = *node else { continue };
let alias = MayAlias {
function: f,
local: Local::from_u32(l),
depth,
};
aliases.insert(alias);
}
for node in &self.pointed_by[&node] {
if !done.contains(node) {
new_remainings.insert((*node, depth + 1));
}
}
}
remainings = new_remainings;
}

aliases
}

pub fn find_fn_may_aliases(&self, f: LocalDefId, l: Local) -> HashSet<LocalDefId> {
let id = VarId::Local(f, l.as_u32());
let node = self.id_to_node[&id];
let pointed_fn = self.points_to_fn[&node];
let nodes = &self.fn_pointed_by[&pointed_fn];
nodes
.iter()
.filter_map(|node| {
let VarId::Global(f) = node else { return None };
Some(*f)
})
.collect()
}
}

impl<'tcx, 'a> Analyzer<'tcx, 'a> {
Expand Down

0 comments on commit 19f5ba6

Please sign in to comment.