From 372f4cf39a62783db5c806bb32657c56cc0397ab Mon Sep 17 00:00:00 2001 From: Jaemin Hong Date: Thu, 30 May 2024 08:19:47 +0000 Subject: [PATCH] multiple assigns to single assign --- src/tag_analysis.rs | 499 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 434 insertions(+), 65 deletions(-) diff --git a/src/tag_analysis.rs b/src/tag_analysis.rs index 379dc39..3ae9d09 100644 --- a/src/tag_analysis.rs +++ b/src/tag_analysis.rs @@ -13,16 +13,17 @@ use rustc_hir::{ def::Res, definitions::DefPathDataName, intravisit::{self, Visitor as HVisitor}, - BinOpKind, ByRef, Expr, ExprKind, HirId, ItemKind, MatchSource, Node, Pat, PatKind, QPath, - StmtKind, UnOp, VariantData, + BinOpKind, Block, ByRef, Expr, ExprKind, HirId, ItemKind, MatchSource, Node, Pat, PatKind, + QPath, StmtKind, UnOp, VariantData, }; use rustc_index::{bit_set::BitSet, IndexVec}; use rustc_middle::{ hir::nested_filter, mir::{ visit::{PlaceContext, Visitor as MVisitor}, - AggregateKind, BasicBlock, Body, ConstantKind, HasLocalDecls, Local, LocalDecl, Location, - Operand, Place, PlaceElem, ProjectionElem, Rvalue, Terminator, TerminatorKind, + AggregateKind, BasicBlock, BasicBlockData, Body, ConstantKind, HasLocalDecls, Local, + LocalDecl, Location, Operand, Place, PlaceElem, ProjectionElem, Rvalue, Terminator, + TerminatorKind, }, ty::{List, Ty, TyCtxt, TyKind, TypeAndMut, TypeckResults}, }; @@ -211,6 +212,8 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { let mut field_values: HashMap> = HashMap::new(); let mut access_in_matches: HashMap<_, Vec<_>> = HashMap::new(); let mut access_in_ifs: HashMap<_, Vec<_>> = HashMap::new(); + let mut basic_blocks = HashMap::new(); + let mut locals: HashMap<_, HashMap<_, _>> = HashMap::new(); println!("Start analysis"); for item_id in hir.items() { let item = hir.item(item_id); @@ -225,6 +228,18 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { visitor.visit_body(body); let mut hvisitor = HBodyVisitor::new(tcx); hvisitor.visit_body(hbody); + if visitor.accesses.is_empty() + && visitor.struct_accesses.is_empty() + && visitor.aggregates.is_empty() + { + continue; + } + basic_blocks.insert(local_def_id, visitor.basic_blocks); + let locals = locals.entry(local_def_id).or_default(); + for (local, local_def) in body.local_decls.iter_enumerated() { + let hir_id = some_or!(hvisitor.bindings.get(&local_def.source_info.span), continue); + locals.entry(*hir_id).or_insert(local); + } if !hvisitor.inits.is_empty() { for (local, location) in &visitor.inits { let span = body @@ -234,12 +249,6 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) { aggregates.insert(*aggregate_span, (*local, *location)); } } - if visitor.accesses.is_empty() - && visitor.struct_accesses.is_empty() - && visitor.aggregates.is_empty() - { - continue; - } for (local, location) in visitor.aggregates.values().flatten() { let span = body .stmt_at(*location) @@ -808,6 +817,7 @@ impl {} {{ }; let hir_body = hir.body(body_id); let local_def_id = item_id.owner_id.def_id; + let basic_blocks = some_or!(basic_blocks.get(&local_def_id), continue); let typeck = tcx.typeck(local_def_id); let mut visitor = SuggestingVisitor { tcx, @@ -819,11 +829,14 @@ impl {} {{ unions: &tagged_unions, access_in_matches: &access_in_matches, access_in_ifs: &access_in_ifs, + basic_blocks, + hir_id_to_locals: &locals[&local_def_id], suggestions: &mut suggestions, locals: HashMap::new(), match_targets: HashMap::new(), if_targets: HashMap::new(), + aggregate_spans: vec![], }; visitor.visit_body(hir_body); @@ -1087,6 +1100,7 @@ struct MBodyVisitor<'tcx, 'a> { inits: Vec<(Local, Location)>, switches: HashMap)>, ifs: Vec, + basic_blocks: Vec<(Span, Location)>, } impl<'tcx, 'a> MBodyVisitor<'tcx, 'a> { @@ -1107,6 +1121,7 @@ impl<'tcx, 'a> MBodyVisitor<'tcx, 'a> { inits: vec![], switches: HashMap::new(), ifs: vec![], + basic_blocks: vec![], } } } @@ -1220,6 +1235,26 @@ impl<'tcx> MVisitor<'tcx> for MBodyVisitor<'tcx, '_> { } self.super_terminator(terminator, location); } + + fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) { + if let Some(stmt) = data.statements.get(0) { + let span = stmt.source_info.span; + let mut lo = span.lo(); + let mut hi = span.hi(); + for stmt in &data.statements[1..] { + let span = stmt.source_info.span; + lo = lo.min(span.lo()); + hi = hi.max(span.hi()); + } + let span = span.with_lo(lo).with_hi(hi); + let location = Location { + block, + statement_index: data.statements.len(), + }; + self.basic_blocks.push((span, location)); + } + self.super_basic_block_data(block, data); + } } #[derive(Debug, Clone, Copy)] @@ -1297,11 +1332,14 @@ struct SuggestingVisitor<'a, 'tcx> { unions: &'a HashMap, access_in_matches: &'a HashMap>>, access_in_ifs: &'a HashMap>>, + basic_blocks: &'a [(Span, Location)], + hir_id_to_locals: &'a HashMap, suggestions: &'a mut Suggestions<'tcx>, locals: HashMap>, match_targets: HashMap, if_targets: HashMap, + aggregate_spans: Vec, } impl<'tcx> SuggestingVisitor<'_, 'tcx> { @@ -1488,15 +1526,21 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { ExprContext::Store(op) => { assert!(!op); - let span = field.span.shrink_to_lo(); - self.suggestions.add(span, "set_".to_string()); + if !self + .aggregate_spans + .iter() + .any(|span| span.contains(expr.span)) + { + let span = field.span.shrink_to_lo(); + self.suggestions.add(span, "set_".to_string()); - let span = field.span.shrink_to_hi(); - let span = span.with_hi(span.hi() + BytePos(2)); - self.suggestions.add(span, "(".to_string()); + let span = field.span.shrink_to_hi(); + let span = span.with_hi(span.hi() + BytePos(2)); + self.suggestions.add(span, "(".to_string()); - let span = e2.span.shrink_to_hi(); - self.suggestions.add(span, ")".to_string()); + let span = e2.span.shrink_to_hi(); + self.suggestions.add(span, ")".to_string()); + } } ExprContext::Address => panic!(), } @@ -1543,50 +1587,56 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { self.suggestions.add(field.span, call); } ExprContext::Store(_) | ExprContext::Address => { - let span = expr.span.shrink_to_lo(); - self.suggestions.add(span, "(*".to_string()); - - let ItemKind::Union(VariantData::Struct(fs, _), _) = - self.tcx.hir().expect_item(did).kind - else { - unreachable!() - }; - let (i, _) = fs + if !self + .aggregate_spans .iter() - .enumerate() - .find(|(_, f)| f.ident.name == field.name) - .unwrap(); - let tags = &tu.variant_tags[&FieldIdx::from(i)]; - - let call = if tags.len() == 1 { - format!("deref_{}_mut())", field.name) - } else { - let ts = &self.structs[&tu.struct_local_def_id]; - let field_name = &ts.field_names[ts.tag_index]; - let ExprKind::Field(e2, _) = e.kind else { unreachable!() }; - let tag = format!( - "{}.{}", - source_map.span_to_snippet(e2.span).unwrap(), - field_name, - ); - format!("deref_{}_mut({}()))", field.name, tag) - }; - self.suggestions.add(field.span, call); - - let root = get_root(expr); - if let ExprKind::Unary(UnOp::Deref, e) = root.kind { - let ty = self.typeck.expr_ty(e); - if let TyKind::RawPtr(TypeAndMut { - mutbl: Mutability::Not, - ty, - }) = ty.kind() - { - let span = e.span.shrink_to_lo(); - self.suggestions.add(span, "(".to_string()); - - let span = e.span.shrink_to_hi(); - let cast = format!(" as *mut crate::{:?})", ty); - self.suggestions.add(span, cast); + .any(|span| span.contains(expr.span)) + { + let span = expr.span.shrink_to_lo(); + self.suggestions.add(span, "(*".to_string()); + + let ItemKind::Union(VariantData::Struct(fs, _), _) = + self.tcx.hir().expect_item(did).kind + else { + unreachable!() + }; + let (i, _) = fs + .iter() + .enumerate() + .find(|(_, f)| f.ident.name == field.name) + .unwrap(); + let tags = &tu.variant_tags[&FieldIdx::from(i)]; + + let call = if tags.len() == 1 { + format!("deref_{}_mut())", field.name) + } else { + let ts = &self.structs[&tu.struct_local_def_id]; + let field_name = &ts.field_names[ts.tag_index]; + let ExprKind::Field(e2, _) = e.kind else { unreachable!() }; + let tag = format!( + "{}.{}", + source_map.span_to_snippet(e2.span).unwrap(), + field_name, + ); + format!("deref_{}_mut({}()))", field.name, tag) + }; + self.suggestions.add(field.span, call); + + let root = unwrap_projection(expr); + if let ExprKind::Unary(UnOp::Deref, e) = root.kind { + let ty = self.typeck.expr_ty(e); + if let TyKind::RawPtr(TypeAndMut { + mutbl: Mutability::Not, + ty, + }) = ty.kind() + { + let span = e.span.shrink_to_lo(); + self.suggestions.add(span, "(".to_string()); + + let span = e.span.shrink_to_hi(); + let cast = format!(" as *mut crate::{:?})", ty); + self.suggestions.add(span, cast); + } } } } @@ -1655,6 +1705,305 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> { _ => None, } } + + fn handle_block(&mut self, block: &'tcx Block<'tcx>) { + let mut blocks = AssignBlocks::default(); + for stmt in block.stmts { + if let StmtKind::Semi(expr) = stmt.kind { + match expr.kind { + ExprKind::Assign(lhs, rhs, _) => { + if let Some((struct_ty, struct_expr, fields)) = self.find_struct(lhs) { + let s = self + .tcx + .sess + .source_map() + .span_to_snippet(struct_expr.span) + .unwrap(); + let struct_expr_string = normalize_expr_str(&s); + let bs = AssignBlockStmt { + struct_ty, + struct_expr, + struct_expr_string, + fields, + rhs, + span: stmt.span, + }; + blocks.add_stmt(bs); + continue; + } + } + ExprKind::MethodCall(method, struct_expr, v, _) => { + if let Some(struct_ty) = self.as_tagged_struct(struct_expr) { + if v.len() == 1 { + let method = method.ident.name.to_ident_string(); + let field = method.strip_prefix("set_").unwrap(); + let s = self + .tcx + .sess + .source_map() + .span_to_snippet(struct_expr.span) + .unwrap(); + let struct_expr_string = normalize_expr_str(&s); + let bs = AssignBlockStmt { + struct_ty, + struct_expr, + struct_expr_string, + fields: vec![field.to_string()], + rhs: &v[0], + span: stmt.span, + }; + blocks.add_stmt(bs); + continue; + } + } + } + _ => {} + } + } + blocks.finish_block(); + } + blocks.finish_block(); + + for block in blocks.blocks { + let ty = block[0].struct_ty; + let expr = block[0].struct_expr; + let root = unwrap_deref(unwrap_projection(expr)); + let ts = self.structs.get(&ty).unwrap(); + let tag_field_name = &ts.field_names[ts.tag_index]; + + let mut assigns = HashMap::new(); + for bs in &block { + assigns.insert(bs.fields.clone(), bs); + } + + let tag_assign = some_or!(assigns.get(&vec![tag_field_name.to_string()]), continue); + let (_, location) = self + .basic_blocks + .iter() + .find(|(span, _)| span.contains(tag_assign.span)) + .copied() + .unwrap(); + + let mut removed = false; + for (u, field_idx) in &ts.unions { + let union_field_name = &ts.field_names[*field_idx]; + let mut fields = vec![union_field_name.clone()]; + let v = make_aggregate(Some(*u), &mut fields, &assigns, self.tcx); + let Some(AssignedValue::Compound(union_name, vs)) = v else { continue }; + if vs.len() != 1 { + continue; + } + let (variant, value) = vs.into_iter().next().unwrap(); + if let ExprKind::Path(QPath::Resolved(_, path)) = root.kind { + let Res::Local(hir_id) = path.res else { continue }; + let local = self.hir_id_to_locals[&hir_id]; + let field_at = FieldAt { + func: self.func, + location, + local, + field: ts.tag_index, + }; + let tags = some_or!(self.field_values.get(&field_at), continue); + if tags.len() != 1 { + continue; + } + let tag = *tags.iter().next().unwrap(); + + let mut bss: Vec<_> = block + .iter() + .filter(|bs| &bs.fields[0] == union_field_name) + .collect(); + let bs = bss.pop().unwrap(); + let code = format!( + "{}.{} = {}::{}{}({});", + self.tcx + .sess + .source_map() + .span_to_snippet(expr.span) + .unwrap(), + union_field_name, + union_name, + variant, + tag, + value, + ); + self.suggestions.add(bs.span, code); + self.aggregate_spans.push(bs.span); + for bs in bss { + self.suggestions.add(bs.span, "".to_string()); + self.aggregate_spans.push(bs.span); + } + removed = true; + } else { + println!( + "{}", + self.tcx + .sess + .source_map() + .span_to_snippet(root.span) + .unwrap() + ); + } + } + if removed { + self.suggestions.add(tag_assign.span, "".to_string()); + self.aggregate_spans.push(tag_assign.span); + } + } + } + + fn find_struct( + &self, + expr: &'tcx Expr<'tcx>, + ) -> Option<(LocalDefId, &'tcx Expr<'tcx>, Vec)> { + let ExprKind::Field(e, f) = expr.kind else { return None }; + let res = if let Some(did) = self.as_tagged_struct(e) { + (did, e, vec![f.name.to_ident_string()]) + } else { + let (did, sexpr, mut fs) = self.find_struct(e)?; + fs.push(f.name.to_ident_string()); + (did, sexpr, fs) + }; + Some(res) + } + + fn as_tagged_struct(&self, expr: &'tcx Expr<'tcx>) -> Option { + let ty = self.typeck.expr_ty(expr); + let TyKind::Adt(adt_def, _) = ty.kind() else { return None }; + let did = adt_def.did().as_local()?; + if self.structs.contains_key(&did) { + Some(did) + } else { + None + } + } +} + +fn make_aggregate<'tcx>( + ty: Option, + fields: &mut Vec, + assigns: &HashMap, &AssignBlockStmt<'tcx>>, + tcx: TyCtxt<'tcx>, +) -> Option { + if let Some(bs) = assigns.get(fields) { + tcx.sess + .source_map() + .span_to_snippet(bs.rhs.span) + .ok() + .map(AssignedValue::Primitive) + } else { + let ty = ty?; + let def_path = tcx.def_path(ty.to_def_id()); + let mut name = "crate".to_string(); + for data in def_path.data { + write!(name, "::{}", data).unwrap(); + } + let adt_def = tcx.adt_def(ty); + let (ItemKind::Struct(vd, _) | ItemKind::Union(vd, _)) = tcx.hir().expect_item(ty).kind + else { + unreachable!() + }; + let variant = adt_def.variant(VariantIdx::from_u32(0)); + let is_struct = adt_def.is_struct(); + + let mut field_values = HashMap::new(); + for (mfd, hfd) in variant.fields.iter().zip(vd.fields()) { + let field_name = hfd.ident.name.to_ident_string(); + let ty = mfd.ty(tcx, List::empty()); + let ty = if let TyKind::Adt(adt_def, _) = ty.kind() { + adt_def.did().as_local() + } else { + None + }; + fields.push(field_name.clone()); + let res = make_aggregate(ty, fields, assigns, tcx); + fields.pop(); + if let Some(v) = res { + field_values.insert(field_name, v); + if !is_struct { + break; + } + } else if is_struct { + return None; + } + } + if field_values.is_empty() { + None + } else { + Some(AssignedValue::Compound(name, field_values)) + } + } +} + +enum AssignedValue { + Compound(String, HashMap), + Primitive(String), +} + +impl std::fmt::Debug for AssignedValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AssignedValue::Compound(name, fields) => { + write!(f, "{} {{ ", name)?; + for (i, (field, value)) in fields.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", field)?; + let value = value.to_string(); + if field != &value { + write!(f, ": {}", value)?; + } + } + write!(f, " }}") + } + AssignedValue::Primitive(value) => write!(f, "{}", value), + } + } +} + +impl std::fmt::Display for AssignedValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +#[derive(Debug, Default)] +struct AssignBlocks<'tcx> { + blocks: Vec>, + block: AssignBlock<'tcx>, +} + +impl<'tcx> AssignBlocks<'tcx> { + fn add_stmt(&mut self, stmt: AssignBlockStmt<'tcx>) { + if let Some(curr) = self.block.get(0) { + if curr.struct_ty != stmt.struct_ty + || curr.struct_expr_string != stmt.struct_expr_string + { + self.finish_block(); + } + } + self.block.push(stmt); + } + + fn finish_block(&mut self) { + let block = std::mem::take(&mut self.block); + if block.len() > 1 { + self.blocks.push(block); + } + } +} + +type AssignBlock<'tcx> = Vec>; + +#[derive(Debug)] +struct AssignBlockStmt<'tcx> { + struct_ty: LocalDefId, + struct_expr: &'tcx Expr<'tcx>, + struct_expr_string: String, + fields: Vec, + rhs: &'tcx Expr<'tcx>, + span: Span, } impl<'tcx> HVisitor<'tcx> for SuggestingVisitor<'_, 'tcx> { @@ -1673,6 +2022,11 @@ impl<'tcx> HVisitor<'tcx> for SuggestingVisitor<'_, 'tcx> { self.handle_local(local); intravisit::walk_local(self, local); } + + fn visit_block(&mut self, block: &'tcx Block<'tcx>) { + self.handle_block(block); + intravisit::walk_block(self, block); + } } struct HIf { @@ -1687,6 +2041,7 @@ struct HBodyVisitor<'tcx> { set_exprs: HashSet, matches: Vec<(Span, Vec<(Span, ArmTags)>)>, ifs: Vec, + bindings: HashMap, } impl<'tcx> HBodyVisitor<'tcx> { @@ -1697,6 +2052,7 @@ impl<'tcx> HBodyVisitor<'tcx> { set_exprs: HashSet::new(), matches: vec![], ifs: vec![], + bindings: HashMap::new(), } } @@ -1743,6 +2099,11 @@ impl<'tcx> HBodyVisitor<'tcx> { _ => {} } } + + fn handle_pat(&mut self, pat: &'tcx Pat<'tcx>) { + let PatKind::Binding(_, hir_id, _, _) = pat.kind else { return }; + self.bindings.insert(pat.span, hir_id); + } } impl<'tcx> HVisitor<'tcx> for HBodyVisitor<'tcx> { @@ -1756,6 +2117,11 @@ impl<'tcx> HVisitor<'tcx> for HBodyVisitor<'tcx> { self.handle_expr(expr); intravisit::walk_expr(self, expr); } + + fn visit_pat(&mut self, pat: &'tcx Pat<'tcx>) { + self.handle_pat(pat); + intravisit::walk_pat(self, pat); + } } struct FieldVisitor<'a, 'tcx> { @@ -1851,11 +2217,6 @@ fn get_expr_context<'tcx>( } } -fn get_root<'tcx>(expr: &'tcx Expr<'tcx>) -> &'tcx Expr<'tcx> { - let ExprKind::Field(e, _) = expr.kind else { return expr }; - get_root(e) -} - fn tag_to_string(tag: Tag, ty: &str) -> String { if ty == "bool" { (tag.0 != 0).to_string() @@ -2265,3 +2626,11 @@ fn unwrap_cast_and_drop<'a, 'tcx>(e: &'a Expr<'tcx>) -> &'a Expr<'tcx> { e } } + +fn unwrap_deref<'a, 'tcx>(e: &'a Expr<'tcx>) -> &'a Expr<'tcx> { + if let ExprKind::Unary(UnOp::Deref, e) = e.kind { + unwrap_deref(e) + } else { + e + } +}