Skip to content

Commit

Permalink
transformation: enum & helper methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Medowhill committed May 13, 2024
1 parent ccc110d commit 95ad219
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 15 deletions.
208 changes: 195 additions & 13 deletions src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,17 @@ use std::{
use etrace::some_or;
use relational::Obj;
use rustc_abi::{FieldIdx, VariantIdx};
use rustc_hir::{definitions::DefPathDataName, ItemKind};
use rustc_hir::{
def::Res,
definitions::DefPathDataName,
intravisit::{self, Visitor as HVisitor},
Expr, ExprKind, ItemKind, Node, QPath, VariantData,
};
use rustc_index::IndexVec;
use rustc_middle::{
hir::nested_filter,
mir::{
visit::{MutatingUseContext, PlaceContext, Visitor},
visit::{MutatingUseContext, PlaceContext, Visitor as MVisitor},
AggregateKind, ConstantKind, HasLocalDecls, Local, LocalDecl, Location, Place, PlaceElem,
ProjectionElem, Rvalue, Terminator, TerminatorKind,
},
Expand Down Expand Up @@ -58,6 +64,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {

let mut structs = vec![];
let mut unions = vec![];
let mut union_to_struct = HashMap::new();
let mut ty_graph: HashMap<_, Vec<_>> = HashMap::new();
for item_id in hir.items() {
let item = hir.item(item_id);
Expand Down Expand Up @@ -94,7 +101,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
continue;
}
let mut has_union = false;
for f in &variant.fields {
for (i, f) in variant.fields.iter_enumerated() {
let TyKind::Adt(adt_def, _) = f.ty(tcx, List::empty()).kind() else { continue };
if !adt_def.is_union() {
continue;
Expand All @@ -112,6 +119,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
.contains(&tcx.def_path(u.to_def_id()).to_string_no_crate_verbose())
{
unions.push(u);
union_to_struct.insert(u, (i, local_def_id));
has_union = true;
}
}
Expand All @@ -129,6 +137,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
})
.collect();

let mut fields: HashMap<_, BTreeSet<_>> = HashMap::new();
let mut tag_values: HashMap<_, BTreeMap<_, BTreeSet<u128>>> = HashMap::new();
let mut variant_tag_values: HashMap<_, BTreeMap<_, BTreeMap<_, BTreeMap<_, BTreeSet<_>>>>> =
HashMap::new();
Expand All @@ -140,20 +149,29 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
ItemKind::Static(_, _, _) => tcx.mir_for_ctfe(local_def_id),
_ => continue,
};
let mut visitor = BodyVisitor::new(tcx, &body.local_decls, &structs, &unions);
let mut visitor = MBodyVisitor::new(tcx, &body.local_decls, &structs, &unions);
visitor.visit_body(body);
if visitor.accesses.is_empty()
&& visitor.struct_accesses.is_empty()
&& visitor.aggregates.is_empty()
{
continue;
}
for a in &visitor.accesses {
fields.entry(a.ty).or_default().insert(a.field);
}
let locals: HashSet<_> = visitor
.accesses
.iter()
.map(|a| a.local)
.chain(visitor.struct_accesses.iter().copied())
.chain(visitor.aggregates.iter().copied())
.chain(
visitor
.aggregates
.values()
.flatten()
.map(|(local, _)| *local),
)
.collect();
let mut local_to_unions: HashMap<_, Vec<_>> = HashMap::new();
for i in locals {
Expand Down Expand Up @@ -259,10 +277,10 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
}
}

let mut tag_fields = HashMap::new();
for (u, vs) in &variant_tag_values {
println!("Union {:?}", u);
let adt_def = tcx.adt_def(*u);
println!("{}", adt_def.variant(VariantIdx::from_u32(0)).fields.len());
println!("Fields {:?}", fields[u]);
let tag_field = vs
.iter()
.filter_map(|(f, vs)| {
Expand All @@ -281,6 +299,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
})
.max_by_key(|(_, n)| *n);
if let Some((tag_field, _)) = tag_field {
tag_fields.insert(*u, tag_field);
println!("Field: {:?}", tag_field);
let mut all_tags = tag_values[&u][&tag_field].clone();
for (variant, vs) in &vs[&tag_field] {
Expand Down Expand Up @@ -309,6 +328,125 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
}
}
}

let mut rewrite_structs: HashMap<_, Vec<_>> = HashMap::new();
for (u, tag_field) in tag_fields {
let (i, s) = union_to_struct[&u];
rewrite_structs
.entry(s)
.or_default()
.push((u, tag_field, i));
}

let source_map = tcx.sess.source_map();

let mut struct_tag_fields = HashMap::new();
for (s, us) in rewrite_structs {
let fs = us.iter().map(|(_, f, _)| *f).collect::<HashSet<_>>();
assert_eq!(fs.len(), 1);
let field_idx = fs.into_iter().next().unwrap();
struct_tag_fields.insert(s, field_idx);
let Node::Item(item) = hir.get_if_local(s.to_def_id()).unwrap() else { unreachable!() };
let struct_name = item.ident.name.to_ident_string();
let ItemKind::Struct(VariantData::Struct(sfs, _), _) = item.kind else { unreachable!() };
let field = &sfs[field_idx as usize];
let span = source_map.span_extend_to_line(field.span);
println!("{}", source_map.span_to_snippet(span).unwrap());
let field_name = field.ident.name.to_ident_string();
let field_ty = source_map.span_to_snippet(field.ty.span).unwrap();

for (u, _, i) in us {
let struct_field_name = sfs[i.as_usize()].ident.name.to_ident_string();
let Node::Item(item) = hir.get_if_local(u.to_def_id()).unwrap() else { unreachable!() };
let union_name = item.ident.name.to_ident_string();
println!("{}", source_map.span_to_snippet(item.span).unwrap());
let ItemKind::Union(VariantData::Struct(ufs, _), _) = item.kind else { unreachable!() };
let tys: Vec<_> = ufs
.iter()
.map(|f| source_map.span_to_snippet(f.ty.span).unwrap().to_string())
.collect();
let mut all_fields = fields[&u].clone();
let mut all_tags = tag_values[&u][&field_idx].clone();
let mut enum_variants = vec![];
let mut field_methods = String::new();
for (field, tags) in &variant_tag_values[&u][&field_idx] {
all_fields.remove(&FieldIdx::from_u32(*field));
for t in tags.keys() {
all_tags.remove(t);
}
for tag in tags.keys() {
enum_variants.push((Some(*field), *tag));
}
let field_name = ufs[*field as usize].ident.name.to_ident_string();
let ty = &tys[*field as usize];
let pattern: String = tags
.keys()
.map(|tag| format!("Self::{}{}(v)", field_name, tag))
.intersperse("|".to_string())
.collect();
let pattern = if tags.len() == 1 {
pattern
} else {
format!("({})", pattern)
};
let method = format!(
"impl {}{{pub fn get_{}(&mut self)->*mut {}{{let {}=self else{{panic!()}};v as _}}}}",
union_name, field_name, ty, pattern,
);
field_methods.push_str(&method);
field_methods.push('\n');
}
if all_fields.is_empty() {
for tag in all_tags {
enum_variants.push((None, tag));
}
} else {
assert_eq!(all_fields.len(), 1);
assert_eq!(all_tags.len(), 1);
let field = all_fields.into_iter().next().unwrap().as_u32();
let tag = all_tags.into_iter().next().unwrap();
enum_variants.push((Some(field), tag));
}
let mut enum_str = format!("pub enum {}{{", union_name);
let mut get_tag_method = format!(
"impl {}{{pub fn get_{}(self)->{}{{match self.{}{{",
struct_name, field_name, field_ty, struct_field_name
);
for (f, t) in enum_variants {
if let Some(f) = f {
let field_name = ufs[f as usize].ident.name.to_ident_string();
let ty = &tys[f as usize];
let variant_name = format!("{}{}", field_name, t);
enum_str.push_str(&format!("{}({})", variant_name, ty));
get_tag_method.push_str(&format!("{}::{}(_)=>{}", union_name, variant_name, t));
} else {
let variant_name = format!("Empty{}", t);
enum_str.push_str(&variant_name);
get_tag_method.push_str(&format!("{}::{}=>{}", union_name, variant_name, t));
};
enum_str.push(',');
get_tag_method.push(',');
}
enum_str.push('}');
get_tag_method.push_str("}}}");
println!("{}", enum_str);
println!("{}", get_tag_method);
println!("{}", field_methods);
}
}

for item_id in hir.items() {
let item = hir.item(item_id);
let (ItemKind::Fn(_, _, body_id) | ItemKind::Static(_, _, body_id)) = item.kind else {
continue;
};
let body = hir.body(body_id);
let mut visitor = HBodyVisitor {
tcx,
struct_tag_fields: &struct_tag_fields,
};
visitor.visit_body(body);
}
}

#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -390,17 +528,17 @@ fn compute_tags<'tcx, D: HasLocalDecls<'tcx> + ?Sized>(
v
}

struct BodyVisitor<'tcx, 'a> {
struct MBodyVisitor<'tcx, 'a> {
tcx: TyCtxt<'tcx>,
local_decls: &'a IndexVec<Local, LocalDecl<'tcx>>,
structs: &'a Vec<LocalDefId>,
unions: &'a Vec<LocalDefId>,
accesses: Vec<FieldAccess<'tcx>>,
struct_accesses: HashSet<Local>,
aggregates: HashSet<Local>,
aggregates: HashMap<LocalDefId, Vec<(Local, Location)>>,
}

impl<'tcx, 'a> BodyVisitor<'tcx, 'a> {
impl<'tcx, 'a> MBodyVisitor<'tcx, 'a> {
fn new(
tcx: TyCtxt<'tcx>,
local_decls: &'a IndexVec<Local, LocalDecl<'tcx>>,
Expand All @@ -414,12 +552,12 @@ impl<'tcx, 'a> BodyVisitor<'tcx, 'a> {
unions,
accesses: vec![],
struct_accesses: HashSet::new(),
aggregates: HashSet::new(),
aggregates: HashMap::new(),
}
}
}

impl<'tcx> Visitor<'tcx> for BodyVisitor<'tcx, '_> {
impl<'tcx> MVisitor<'tcx> for MBodyVisitor<'tcx, '_> {
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
if place.projection.len() > 0 {
for i in 0..(place.projection.len() - 1) {
Expand Down Expand Up @@ -457,7 +595,10 @@ impl<'tcx> Visitor<'tcx> for BodyVisitor<'tcx, '_> {
if let Rvalue::Aggregate(box AggregateKind::Adt(def_id, _, _, _, _), _) = rvalue {
if let Some(def_id) = def_id.as_local() {
if self.structs.contains(&def_id) || self.unions.contains(&def_id) {
self.aggregates.insert(place.local);
self.aggregates
.entry(def_id)
.or_default()
.push((place.local, location));
}
}
}
Expand Down Expand Up @@ -511,3 +652,44 @@ impl<'tcx> FieldAccess<'tcx> {
)
}
}

/// State for the HIR-level body walk that looks for struct literal
/// expressions of structs that have a recorded tag field.
struct HBodyVisitor<'a, 'tcx> {
/// Compiler type context; used here to reach the HIR map and the session's
/// source map.
tcx: TyCtxt<'tcx>,
/// Maps a local struct definition to a `u32` field index. The caller in
/// `analyze` fills it with the index of the struct's tag field.
struct_tag_fields: &'a HashMap<LocalDefId, u32>,
}

impl<'tcx> HBodyVisitor<'_, 'tcx> {
    /// Inspects a single HIR expression.
    ///
    /// When `expr` is a struct literal whose path resolves to a local
    /// definition present in `struct_tag_fields`, the source text of the
    /// whole expression is printed to stdout. Every other expression kind
    /// (field accesses included) is currently ignored.
    ///
    /// # Panics
    ///
    /// Panics if the source map cannot produce a snippet for `expr.span`.
    fn handle_expr(&mut self, expr: &'tcx Expr<'tcx>) {
        // Guard clauses: bail out as soon as the expression stops looking
        // like a literal of one of the tracked structs.
        let ExprKind::Struct(qpath, _, _) = expr.kind else { return };
        let QPath::Resolved(_, resolved) = qpath else { return };
        let Res::Def(_, target) = resolved.res else { return };
        let Some(local_id) = target.as_local() else { return };
        if !self.struct_tag_fields.contains_key(&local_id) {
            return;
        }

        let source_map = self.tcx.sess.source_map();
        let snippet = source_map.span_to_snippet(expr.span).unwrap();
        println!("{}", snippet);
    }
}

/// HIR visitor implementation: hands every visited expression to
/// `handle_expr` and then continues into its sub-expressions.
impl<'tcx> HVisitor<'tcx> for HBodyVisitor<'_, 'tcx> {
// NOTE(review): per the rustc docs, `OnlyBodies` descends into nested bodies
// but not into nested items — confirm against the pinned toolchain.
type NestedFilter = nested_filter::OnlyBodies;

// Supplies the HIR map the walker uses to look up nested bodies.
fn nested_visit_map(&mut self) -> Self::Map {
self.tcx.hir()
}

// Handle the expression itself first, then recurse so that nested
// expressions are visited as well.
fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
self.handle_expr(expr);
intravisit::walk_expr(self, expr);
}
}
Loading

0 comments on commit 95ad219

Please sign in to comment.