Skip to content

Commit

Permalink
avoid computing whole reachability
Browse files Browse the repository at this point in the history
  • Loading branch information
Medowhill committed May 6, 2024
1 parent 08ddad2 commit 59c5468
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 114 deletions.
95 changes: 14 additions & 81 deletions src/graph.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::collections::{HashMap, HashSet};

use etrace::some_or;
use rustc_data_structures::graph::{scc::Sccs, vec_graph::VecGraph};
Expand Down Expand Up @@ -100,18 +100,6 @@ pub fn reachable_vertices<T: Idx + std::hash::Hash>(
.collect()
}

fn symmetric_closure<T: Clone + Eq + PartialOrd + Ord>(
map: &BTreeMap<T, BTreeSet<T>>,
) -> BTreeMap<T, BTreeSet<T>> {
let mut clo = map.clone();
for (node, succs) in map {
for succ in succs {
clo.get_mut(succ).unwrap().insert(node.clone());
}
}
clo
}

pub fn inverse<T: Clone + Eq + std::hash::Hash>(
map: &HashMap<T, HashSet<T>>,
) -> HashMap<T, HashSet<T>> {
Expand All @@ -127,95 +115,40 @@ pub fn inverse<T: Clone + Eq + std::hash::Hash>(
inv
}

/// `map` must not have a cycle.
pub fn post_order<T: Clone + Eq + PartialOrd + Ord>(
map: &BTreeMap<T, BTreeSet<T>>,
inv_map: &BTreeMap<T, BTreeSet<T>>,
) -> Vec<Vec<T>> {
let mut res = vec![];
let clo = symmetric_closure(map);
let (_, components) = compute_sccs(&clo);

for (_, component) in components {
let mut v = vec![];
let mut reached = BTreeSet::new();
for node in component {
if inv_map.get(&node).unwrap().is_empty() {
dfs_walk(&node, &mut v, &mut reached, map);
}
}
res.push(v);
}

res
}

fn dfs_walk<T: Clone + Eq + PartialOrd + Ord>(
node: &T,
v: &mut Vec<T>,
reached: &mut BTreeSet<T>,
map: &BTreeMap<T, BTreeSet<T>>,
) {
reached.insert(node.clone());
for succ in map.get(node).unwrap() {
if !reached.contains(succ) {
dfs_walk(succ, v, reached, map);
}
}
v.push(node.clone());
}

pub fn compute_sccs<T: Clone + Eq + PartialOrd + Ord>(
map: &BTreeMap<T, BTreeSet<T>>,
) -> (BTreeMap<Id, BTreeSet<Id>>, BTreeMap<Id, BTreeSet<T>>) {
let id_map: BTreeMap<_, _> = map
pub fn compute_sccs<T: Clone + Eq + std::hash::Hash>(
map: &HashMap<T, HashSet<T>>,
) -> (HashMap<usize, HashSet<usize>>, HashMap<usize, HashSet<T>>) {
let id_map: HashMap<_, _> = map
.keys()
.enumerate()
.map(|(i, f)| (i, f.clone()))
.collect();
let inv_id_map: BTreeMap<_, _> = id_map.iter().map(|(i, f)| (f.clone(), *i)).collect();
let inv_id_map: HashMap<_, _> = id_map.iter().map(|(i, f)| (f.clone(), *i)).collect();
let edges = map
.iter()
.flat_map(|(node, succs)| {
succs.iter().map(|succ| {
(
Id::new(*inv_id_map.get(node).unwrap()),
Id::new(*inv_id_map.get(succ).unwrap()),
)
})
succs
.iter()
.map(|succ| (inv_id_map[node], inv_id_map[succ]))
})
.collect();
let sccs: Sccs<Id, Id> = Sccs::new(&VecGraph::new(map.len(), edges));
let sccs: Sccs<usize, usize> = Sccs::new(&VecGraph::new(map.len(), edges));

let component_graph: BTreeMap<_, _> = sccs
let component_graph: HashMap<_, _> = sccs
.all_sccs()
.map(|node| (node, sccs.successors(node).iter().cloned().collect()))
.collect();

let mut component_elems: BTreeMap<_, BTreeSet<_>> = BTreeMap::new();
let mut component_elems: HashMap<_, HashSet<_>> = HashMap::new();
for i in 0..(map.len()) {
let scc = sccs.scc(Id::new(i));
let node = id_map.get(&i).unwrap().clone();
let scc = sccs.scc(i);
let node = id_map[&i].clone();
component_elems.entry(scc).or_default().insert(node);
}

(component_graph, component_elems)
}

#[derive(Clone, Copy, Eq, PartialEq, Debug, Hash, PartialOrd, Ord)]
#[repr(transparent)]
pub struct Id(usize);

impl Idx for Id {
fn new(idx: usize) -> Self {
Self(idx)
}

fn index(self) -> usize {
self.0
}
}

struct Node<V, K> {
v: V,
k: K,
Expand Down
69 changes: 53 additions & 16 deletions src/points_to/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
cell::RefCell,
collections::{hash_map::Entry, HashMap, HashSet},
path::Path,
};
Expand Down Expand Up @@ -83,12 +84,54 @@ pub struct AnalysisResults {

pub solutions: Solutions,

pub call_graph: HashMap<LocalDefId, HashSet<LocalDefId>>,
pub indirect_calls: HashMap<LocalDefId, HashMap<BasicBlock, Vec<LocalDefId>>>,
pub call_graph_sccs: HashMap<usize, HashSet<usize>>,
pub scc_elems: HashMap<usize, HashSet<LocalDefId>>,
pub fn_sccs: HashMap<LocalDefId, usize>,
pub reachables: RefCell<HashMap<usize, HashSet<usize>>>,

pub writes: HashMap<LocalDefId, HashMap<Location, HybridBitSet<usize>>>,
pub bitfield_writes: HashMap<LocalDefId, HashMap<Location, HybridBitSet<usize>>>,
pub call_writes: HashMap<LocalDefId, HybridBitSet<usize>>,
pub fn_writes: HashMap<LocalDefId, HybridBitSet<usize>>,
}

impl AnalysisResults {
pub fn call_writes(&self, def_id: LocalDefId) -> HybridBitSet<usize> {
self.with_reachables(self.fn_sccs[&def_id], |sccs| {
let mut writes = HybridBitSet::new_empty(self.ends.len());
for scc in sccs {
for f in &self.scc_elems[scc] {
writes.union(&self.fn_writes[f]);
}
}
writes
})
}

#[inline]
fn with_reachables<R, F: FnOnce(&HashSet<usize>) -> R>(&self, scc: usize, f: F) -> R {
if let Some(rs) = self.reachables.borrow().get(&scc) {
return f(rs);
}
let mut reachables = HashSet::new();
self.reachables_from_scc(scc, &mut reachables);
let r = f(&reachables);
self.reachables.borrow_mut().insert(scc, reachables.clone());
r
}

fn reachables_from_scc(&self, scc: usize, reachables: &mut HashSet<usize>) {
if let Some(rs) = self.reachables.borrow().get(&scc) {
reachables.extend(rs);
return;
}
let mut this_reachables: HashSet<_> = [scc].into_iter().collect();
for succ in &self.call_graph_sccs[&scc] {
self.reachables_from_scc(*succ, &mut this_reachables);
}
reachables.extend(this_reachables.iter());
self.reachables.borrow_mut().insert(scc, this_reachables);
}
}

pub fn pre_analyze<'a, 'tcx>(
Expand Down Expand Up @@ -441,19 +484,10 @@ pub fn post_analyze<'a, 'tcx>(
let callees = pre.call_graph.get_mut(caller).unwrap();
callees.extend(calls.values().flatten());
}
let mut reachability = graph::transitive_closure(&pre.call_graph);
for (func, reachables) in &mut reachability {
reachables.insert(*func);
}
let call_writes: HashMap<_, _> = reachability
let (call_graph_sccs, scc_elems) = graph::compute_sccs(&pre.call_graph);
let fn_sccs = scc_elems
.iter()
.map(|(func, reachables)| {
let mut writes = HybridBitSet::new_empty(pre.ends.len());
for reachable in reachables {
writes.union(&fn_writes[reachable]);
}
(*func, writes)
})
.flat_map(|(id, fs)| fs.iter().map(|f| (*f, *id)))
.collect();

AnalysisResults {
Expand All @@ -462,11 +496,14 @@ pub fn post_analyze<'a, 'tcx>(
graph: pre.graph,
var_nodes: pre.var_nodes,
solutions,
call_graph: pre.call_graph,
indirect_calls,
call_graph_sccs,
scc_elems,
fn_sccs,
reachables: RefCell::new(HashMap::new()),
writes,
bitfield_writes,
call_writes,
fn_writes,
}
}

Expand Down
5 changes: 0 additions & 5 deletions src/points_to/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2688,11 +2688,6 @@ fn test_call_graph() {
",
|res, tcx| {
let def_id = find("f", tcx);
let callees = &res.call_graph[&def_id];
assert_eq!(callees.len(), 3);
assert!(callees.contains(&find("g", tcx)));
assert!(callees.contains(&find("h", tcx)));
assert!(callees.contains(&find("i", tcx)));
let indirect_calls = &res.indirect_calls[&def_id];
assert_eq!(indirect_calls.len(), 1);
let callees = &indirect_calls[&BasicBlock::new(5)];
Expand Down
30 changes: 18 additions & 12 deletions src/relational/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
};

use rustc_data_structures::graph::WithSuccessors;
use rustc_hir::def_id::DefId;
use rustc_hir::{def_id::DefId, ItemKind};
use rustc_index::bit_set::{BitSet, HybridBitSet};
use rustc_middle::{
mir::{BasicBlock, BinOp, Body, Local, Location, Rvalue, StatementKind, TerminatorKind},
Expand Down Expand Up @@ -38,15 +38,21 @@ pub fn analyze(tcx: TyCtxt<'_>, gc: bool) -> AnalysisResults {
let pre = points_to::pre_analyze(&tss, tcx);
let solutions = points_to::analyze(&pre, &tss, tcx);
let may_points_to = points_to::post_analyze(pre, solutions, &tss, tcx);
let functions = may_points_to
.call_graph
.keys()
.map(|def_id| {
let body = tcx.optimized_mir(*def_id);
(
*def_id,
analyze_fn(*def_id, body, &tss, &may_points_to, gc, tcx).in_states,
)
let hir = tcx.hir();
let functions = hir
.items()
.filter_map(|item_id| {
let item = hir.item(item_id);
let local_def_id = item.owner_id.def_id;
let body = match item.kind {
ItemKind::Fn(_, _, _) if item.ident.name.as_str() != "main" => {
tcx.optimized_mir(local_def_id)
}
ItemKind::Static(_, _, _) => tcx.mir_for_ctfe(local_def_id),
_ => return None,
};
let res = analyze_fn(local_def_id, body, &tss, &may_points_to, gc, tcx).in_states;
Some((local_def_id, res))
})
.collect();
AnalysisResults { functions }
Expand Down Expand Up @@ -189,9 +195,9 @@ impl Analyzer<'_, '_> {

pub fn get_call_writes(&self, callees: &[LocalDefId]) -> Option<HybridBitSet<usize>> {
let c0 = callees.get(0)?;
let mut writes = self.may_points_to.call_writes[c0].clone();
let mut writes = self.may_points_to.call_writes(*c0);
for c in &callees[1..] {
writes.union(&self.may_points_to.call_writes[c]);
writes.union(&self.may_points_to.call_writes(*c));
}
if writes.is_empty() {
None
Expand Down

0 comments on commit 59c5468

Please sign in to comment.