From 95ad2199e49ff24d81411899589c1cdf786999a1 Mon Sep 17 00:00:00 2001 From: Jaemin Hong Date: Mon, 13 May 2024 14:58:54 +0000 Subject: [PATCH] transformation: enum & helper methods --- src/analysis.rs | 208 +++++++++++++++++++++++++++++++++++++++++++--- src/bin/urcrat.rs | 52 +++++++++++- 2 files changed, 245 insertions(+), 15 deletions(-) diff --git a/src/analysis.rs b/src/analysis.rs index 5935bbb..7f9bff1 100644 --- a/src/analysis.rs +++ b/src/analysis.rs @@ -6,11 +6,17 @@ use std::{ use etrace::some_or; use relational::Obj; use rustc_abi::{FieldIdx, VariantIdx}; -use rustc_hir::{definitions::DefPathDataName, ItemKind}; +use rustc_hir::{ + def::Res, + definitions::DefPathDataName, + intravisit::{self, Visitor as HVisitor}, + Expr, ExprKind, ItemKind, Node, QPath, VariantData, +}; use rustc_index::IndexVec; use rustc_middle::{ + hir::nested_filter, mir::{ - visit::{MutatingUseContext, PlaceContext, Visitor}, + visit::{MutatingUseContext, PlaceContext, Visitor as MVisitor}, AggregateKind, ConstantKind, HasLocalDecls, Local, LocalDecl, Location, Place, PlaceElem, ProjectionElem, Rvalue, Terminator, TerminatorKind, }, @@ -58,6 +64,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { let mut structs = vec![]; let mut unions = vec![]; + let mut union_to_struct = HashMap::new(); let mut ty_graph: HashMap<_, Vec<_>> = HashMap::new(); for item_id in hir.items() { let item = hir.item(item_id); @@ -94,7 +101,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { continue; } let mut has_union = false; - for f in &variant.fields { + for (i, f) in variant.fields.iter_enumerated() { let TyKind::Adt(adt_def, _) = f.ty(tcx, List::empty()).kind() else { continue }; if !adt_def.is_union() { continue; @@ -112,6 +119,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { .contains(&tcx.def_path(u.to_def_id()).to_string_no_crate_verbose()) { unions.push(u); + union_to_struct.insert(u, (i, local_def_id)); has_union = true; } } @@ -129,6 +137,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { }) .collect(); + let mut fields: HashMap<_, BTreeSet<_>> = HashMap::new(); let mut tag_values: HashMap<_, BTreeMap<_, BTreeSet>> = HashMap::new(); let mut variant_tag_values: HashMap<_, BTreeMap<_, BTreeMap<_, BTreeMap<_, BTreeSet<_>>>>> = HashMap::new(); @@ -140,7 +149,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { ItemKind::Static(_, _, _) => tcx.mir_for_ctfe(local_def_id), _ => continue, }; - let mut visitor = BodyVisitor::new(tcx, &body.local_decls, &structs, &unions); + let mut visitor = MBodyVisitor::new(tcx, &body.local_decls, &structs, &unions); visitor.visit_body(body); if visitor.accesses.is_empty() && visitor.struct_accesses.is_empty() @@ -148,12 +157,21 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { { continue; } + for a in &visitor.accesses { + fields.entry(a.ty).or_default().insert(a.field); + } let locals: HashSet<_> = visitor .accesses .iter() .map(|a| a.local) .chain(visitor.struct_accesses.iter().copied()) - .chain(visitor.aggregates.iter().copied()) + .chain( + visitor + .aggregates + .values() + .flatten() + .map(|(local, _)| *local), + ) .collect(); let mut local_to_unions: HashMap<_, Vec<_>> = HashMap::new(); for i in locals { @@ -259,10 +277,10 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { } } + let mut tag_fields = HashMap::new(); for (u, vs) in &variant_tag_values { println!("Union {:?}", u); - let adt_def = tcx.adt_def(*u); - println!("{}", adt_def.variant(VariantIdx::from_u32(0)).fields.len()); + println!("Fields {:?}", fields[u]); let tag_field = vs .iter() .filter_map(|(f, vs)| { @@ -281,6 +299,7 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { }) .max_by_key(|(_, n)| *n); if let Some((tag_field, _)) = tag_field { + tag_fields.insert(*u, tag_field); println!("Field: {:?}", tag_field); let mut all_tags = tag_values[&u][&tag_field].clone(); for (variant, vs) in &vs[&tag_field] { @@ -309,6 +328,125 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { } } } + + let mut rewrite_structs: HashMap<_, Vec<_>> = HashMap::new(); + for (u, tag_field) in tag_fields { + let (i, s) = union_to_struct[&u]; + rewrite_structs + .entry(s) + .or_default() + .push((u, tag_field, i)); + } + + let source_map = tcx.sess.source_map(); + + let mut struct_tag_fields = HashMap::new(); + for (s, us) in rewrite_structs { + let fs = us.iter().map(|(_, f, _)| *f).collect::>(); + assert_eq!(fs.len(), 1); + let field_idx = fs.into_iter().next().unwrap(); + struct_tag_fields.insert(s, field_idx); + let Node::Item(item) = hir.get_if_local(s.to_def_id()).unwrap() else { unreachable!() }; + let struct_name = item.ident.name.to_ident_string(); + let ItemKind::Struct(VariantData::Struct(sfs, _), _) = item.kind else { unreachable!() }; + let field = &sfs[field_idx as usize]; + let span = source_map.span_extend_to_line(field.span); + println!("{}", source_map.span_to_snippet(span).unwrap()); + let field_name = field.ident.name.to_ident_string(); + let field_ty = source_map.span_to_snippet(field.ty.span).unwrap(); + + for (u, _, i) in us { + let struct_field_name = sfs[i.as_usize()].ident.name.to_ident_string(); + let Node::Item(item) = hir.get_if_local(u.to_def_id()).unwrap() else { unreachable!() }; + let union_name = item.ident.name.to_ident_string(); + println!("{}", source_map.span_to_snippet(item.span).unwrap()); + let ItemKind::Union(VariantData::Struct(ufs, _), _) = item.kind else { unreachable!() }; + let tys: Vec<_> = ufs + .iter() + .map(|f| source_map.span_to_snippet(f.ty.span).unwrap().to_string()) + .collect(); + let mut all_fields = fields[&u].clone(); + let mut all_tags = tag_values[&u][&field_idx].clone(); + let mut enum_variants = vec![]; + let mut field_methods = String::new(); + for (field, tags) in &variant_tag_values[&u][&field_idx] { + all_fields.remove(&FieldIdx::from_u32(*field)); + for t in tags.keys() { + all_tags.remove(t); + } + for tag in tags.keys() { + enum_variants.push((Some(*field), *tag)); + } + let field_name = ufs[*field as usize].ident.name.to_ident_string(); + let ty = &tys[*field as usize]; + let pattern: String = tags + .keys() + .map(|tag| format!("Self::{}{}(v)", field_name, tag)) + .intersperse("|".to_string()) + .collect(); + let pattern = if tags.len() == 1 { + pattern + } else { + format!("({})", pattern) + }; + let method = format!( + "impl {}{{pub fn get_{}(&mut self)->*mut {}{{let {}=self else{{panic!()}};v as _}}}}", + union_name, field_name, ty, pattern, + ); + field_methods.push_str(&method); + field_methods.push('\n'); + } + if all_fields.is_empty() { + for tag in all_tags { + enum_variants.push((None, tag)); + } + } else { + assert_eq!(all_fields.len(), 1); + assert_eq!(all_tags.len(), 1); + let field = all_fields.into_iter().next().unwrap().as_u32(); + let tag = all_tags.into_iter().next().unwrap(); + enum_variants.push((Some(field), tag)); + } + let mut enum_str = format!("pub enum {}{{", union_name); + let mut get_tag_method = format!( + "impl {}{{pub fn get_{}(self)->{}{{match self.{}{{", + struct_name, field_name, field_ty, struct_field_name + ); + for (f, t) in enum_variants { + if let Some(f) = f { + let field_name = ufs[f as usize].ident.name.to_ident_string(); + let ty = &tys[f as usize]; + let variant_name = format!("{}{}", field_name, t); + enum_str.push_str(&format!("{}({})", variant_name, ty)); + get_tag_method.push_str(&format!("{}::{}(_)=>{}", union_name, variant_name, t)); + } else { + let variant_name = format!("Empty{}", t); + enum_str.push_str(&variant_name); + get_tag_method.push_str(&format!("{}::{}=>{}", union_name, variant_name, t)); + }; + enum_str.push(','); + get_tag_method.push(','); + } + enum_str.push('}'); + get_tag_method.push_str("}}}"); + println!("{}", enum_str); + println!("{}", get_tag_method); + println!("{}", field_methods); + } + } + + for item_id in hir.items() { + let item = hir.item(item_id); + let (ItemKind::Fn(_, _, body_id) | ItemKind::Static(_, _, body_id)) = item.kind else { + continue; + }; + let body = hir.body(body_id); + let mut visitor = HBodyVisitor { + tcx, + struct_tag_fields: &struct_tag_fields, + }; + visitor.visit_body(body); + } } #[derive(Debug, Clone, Copy)] @@ -390,17 +528,17 @@ fn compute_tags<'tcx, D: HasLocalDecls<'tcx> + ?Sized>( v } -struct BodyVisitor<'tcx, 'a> { +struct MBodyVisitor<'tcx, 'a> { tcx: TyCtxt<'tcx>, local_decls: &'a IndexVec>, structs: &'a Vec, unions: &'a Vec, accesses: Vec>, struct_accesses: HashSet, - aggregates: HashSet, + aggregates: HashMap>, } -impl<'tcx, 'a> BodyVisitor<'tcx, 'a> { +impl<'tcx, 'a> MBodyVisitor<'tcx, 'a> { fn new( tcx: TyCtxt<'tcx>, local_decls: &'a IndexVec>, @@ -414,12 +552,12 @@ impl<'tcx, 'a> BodyVisitor<'tcx, 'a> { unions, accesses: vec![], struct_accesses: HashSet::new(), - aggregates: HashSet::new(), + aggregates: HashMap::new(), } } } -impl<'tcx> Visitor<'tcx> for BodyVisitor<'tcx, '_> { +impl<'tcx> MVisitor<'tcx> for MBodyVisitor<'tcx, '_> { fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) { if place.projection.len() > 0 { for i in 0..(place.projection.len() - 1) { @@ -457,7 +595,10 @@ impl<'tcx> Visitor<'tcx> for BodyVisitor<'tcx, '_> { if let Rvalue::Aggregate(box AggregateKind::Adt(def_id, _, _, _, _), _) = rvalue { if let Some(def_id) = def_id.as_local() { if self.structs.contains(&def_id) || self.unions.contains(&def_id) { - self.aggregates.insert(place.local); + self.aggregates + .entry(def_id) + .or_default() + .push((place.local, location)); } } } @@ -511,3 +652,44 @@ impl<'tcx> FieldAccess<'tcx> { ) } } + +struct HBodyVisitor<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + struct_tag_fields: &'a HashMap, +} + +impl<'tcx> HBodyVisitor<'_, 'tcx> { + fn handle_expr(&mut self, expr: &'tcx Expr<'tcx>) { + match expr.kind { + ExprKind::Struct(path, _fs, _) => { + let QPath::Resolved(_, path) = path else { return }; + let Res::Def(_, def_id) = path.res else { return }; + let def_id = some_or!(def_id.as_local(), return); + let _tag_field_idx = *some_or!(self.struct_tag_fields.get(&def_id), return); + println!( + "{}", + self.tcx + .sess + .source_map() + .span_to_snippet(expr.span) + .unwrap() + ); + } + ExprKind::Field(_, _) => {} + _ => {} + } + } +} + +impl<'tcx> HVisitor<'tcx> for HBodyVisitor<'_, 'tcx> { + type NestedFilter = nested_filter::OnlyBodies; + + fn nested_visit_map(&mut self) -> Self::Map { + self.tcx.hir() + } + + fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { + self.handle_expr(expr); + intravisit::walk_expr(self, expr); + } +} diff --git a/src/bin/urcrat.rs b/src/bin/urcrat.rs index 034940a..9541714 100644 --- a/src/bin/urcrat.rs +++ b/src/bin/urcrat.rs @@ -1,4 +1,7 @@ -use std::{fs::File, path::PathBuf}; +use std::{ + fs::{self, File}, + path::{Path, PathBuf}, +}; use clap::{Parser, Subcommand}; use urcrat::*; @@ -14,6 +17,8 @@ enum Command { may: Option, #[arg(short, long)] r#union: Vec, + #[arg(short, long)] + output: Option, }, } @@ -52,7 +57,11 @@ fn main() { } elapsed } - Command::Must { may, r#union } => { + Command::Must { + may, + r#union, + output, + } => { let solutions = may.map(|file| { let arr = std::fs::read(file).unwrap(); points_to::deserialize_solutions(&arr) @@ -63,8 +72,47 @@ fn main() { }; let start = std::time::Instant::now(); analysis::analyze_path(&file, &conf); + + if let Some(mut output) = output { + output.push(args.input.file_name().unwrap()); + if output.exists() { + assert!(output.is_dir()); + clear_dir(&output); + } else { + fs::create_dir(&output).unwrap(); + } + copy_dir(&args.input, &output, true); + } start.elapsed() } }; println!("{}", elapsed.as_millis()); } + +fn clear_dir(path: &Path) { + for entry in fs::read_dir(path).unwrap() { + let entry_path = entry.unwrap().path(); + if entry_path.is_dir() { + let name = entry_path.file_name().unwrap(); + if name != "target" { + fs::remove_dir_all(entry_path).unwrap(); + } + } else { + fs::remove_file(entry_path).unwrap(); + } + } +} + +fn copy_dir(src: &Path, dst: &Path, root: bool) { + for entry in fs::read_dir(src).unwrap() { + let src_path = entry.unwrap().path(); + let name = src_path.file_name().unwrap(); + let dst_path = dst.join(name); + if src_path.is_file() { + fs::copy(src_path, dst_path).unwrap(); + } else if src_path.is_dir() && (!root || name != "target") { + fs::create_dir(&dst_path).unwrap(); + copy_dir(&src_path, &dst_path, false); + } + } +}