Skip to content

Commit

Permalink
transformation: bitfield struct init
Browse files Browse the repository at this point in the history
  • Loading branch information
Medowhill committed May 21, 2024
1 parent 4c40935 commit a4920ea
Showing 1 changed file with 149 additions and 30 deletions.
179 changes: 149 additions & 30 deletions src/tag_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ use rustc_hir::{
def::Res,
definitions::DefPathDataName,
intravisit::{self, Visitor as HVisitor},
Expr, ExprKind, ItemKind, Node, QPath, VariantData,
Expr, ExprKind, ItemKind, Node, PatKind, QPath, StmtKind, VariantData,
};
use rustc_index::{bit_set::BitSet, IndexVec};
use rustc_middle::{
hir::nested_filter,
mir::{
visit::{MutatingUseContext, PlaceContext, Visitor as MVisitor},
AggregateKind, ConstantKind, HasLocalDecls, Local, LocalDecl, Location, Place, PlaceElem,
ProjectionElem, Rvalue, Terminator, TerminatorKind,
AggregateKind, ConstantKind, HasLocalDecls, Local, LocalDecl, Location, Operand, Place,
PlaceElem, ProjectionElem, Rvalue, Terminator, TerminatorKind,
},
ty::{List, Ty, TyCtxt, TyKind, TypeAndMut, TypeckResults},
};
Expand Down Expand Up @@ -154,26 +154,39 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
for item_id in hir.items() {
let item = hir.item(item_id);
let local_def_id = item_id.owner_id.def_id;
let body = match item.kind {
ItemKind::Fn(_, _, _) => tcx.optimized_mir(local_def_id),
ItemKind::Static(_, _, _) => tcx.mir_for_ctfe(local_def_id),
let (body_id, body) = match item.kind {
ItemKind::Fn(_, _, body_id) => (body_id, tcx.optimized_mir(local_def_id)),
ItemKind::Static(_, _, body_id) => (body_id, tcx.mir_for_ctfe(local_def_id)),
_ => continue,
};
let hbody = hir.body(body_id);
let mut visitor = MBodyVisitor::new(tcx, &body.local_decls, &structs, &unions);
visitor.visit_body(body);
let mut hvisitor = BitFieldInitVisitor {
tcx,
inits: HashMap::new(),
};
hvisitor.visit_body(hbody);
if !hvisitor.inits.is_empty() {
for (local, location) in &visitor.inits {
let span = body
.stmt_at(*location)
.either(|stmt| stmt.source_info.span, |term| term.source_info.span);
let aggregate_span = some_or!(hvisitor.inits.get(&span), continue);
aggregates.insert(*aggregate_span, (*local, *location));
}
}
if visitor.accesses.is_empty()
&& visitor.struct_accesses.is_empty()
&& visitor.aggregates.is_empty()
{
continue;
}
for vs in visitor.aggregates.values() {
for (local, location) in vs {
let span = body
.stmt_at(*location)
.either(|stmt| stmt.source_info.span, |term| term.source_info.span);
aggregates.insert(span, (*local, *location));
}
for (local, location) in visitor.aggregates.values().flatten() {
let span = body
.stmt_at(*location)
.either(|stmt| stmt.source_info.span, |term| term.source_info.span);
aggregates.entry(span).or_insert((*local, *location));
}
for a in &visitor.accesses {
fields.entry(a.ty).or_default().insert(a.field);
Expand Down Expand Up @@ -284,6 +297,26 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
if body.basic_blocks[block].statements.len() > statement_index {
continue;
}
if let TerminatorKind::Call {
func: Operand::Constant(box constant),
..
} = body.basic_blocks[block].terminator().kind
{
let ConstantKind::Val(_, ty) = constant.literal else { unreachable!() };
let TyKind::FnDef(def_id, _) = ty.kind() else { unreachable!() };
if def_id.as_local().is_some()
&& tcx.impl_of_method(*def_id).is_some()
&& tcx
.def_path(*def_id)
.data
.last()
.unwrap()
.to_string()
.starts_with("set_")
{
continue;
}
}
let uv = fs
.get(&f)
.map(|obj| {
Expand Down Expand Up @@ -370,15 +403,29 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
}

let mut rewrite_structs: HashMap<_, Vec<_>> = HashMap::new();
for (u, tag_field) in tag_fields {
for (u, tag_field) in &tag_fields {
let (i, s) = union_to_struct[&u];
rewrite_structs
.entry(s)
.or_default()
.push((u, tag_field, i));
.push((*u, *tag_field, i));
}

let struct_fields: HashMap<_, IndexVec<FieldIdx, _>> = rewrite_structs
let union_field_names: HashMap<_, HashMap<_, _>> = tag_fields
.keys()
.map(|u| {
let ItemKind::Union(VariantData::Struct(fs, _), _) = hir.expect_item(*u).kind else {
unreachable!()
};
let fs = fs
.iter()
.enumerate()
.map(|(i, f)| (f.ident.name.to_ident_string(), FieldIdx::from_usize(i)))
.collect();
(*u, fs)
})
.collect();
let struct_field_names: HashMap<_, IndexVec<FieldIdx, _>> = rewrite_structs
.keys()
.map(|s| {
let ItemKind::Struct(VariantData::Struct(fs, _), _) = hir.expect_item(*s).kind else {
Expand Down Expand Up @@ -676,12 +723,13 @@ impl {} {{
func: local_def_id,
structs: &structs,
struct_tag_fields: &struct_tag_fields,
struct_field_names: &struct_field_names,
aggregates: &aggregates,
field_values: &field_values,
unions: &unions,
union_variant_tags: &union_variant_tags,
union_to_struct: &union_to_struct,
struct_fields: &struct_fields,
union_field_names: &union_field_names,
suggestions: &mut suggestions,
};
visitor.visit_body(hir_body);
Expand Down Expand Up @@ -784,6 +832,7 @@ struct MBodyVisitor<'tcx, 'a> {
accesses: Vec<FieldAccess<'tcx>>,
struct_accesses: HashSet<Local>,
aggregates: HashMap<LocalDefId, Vec<(Local, Location)>>,
inits: Vec<(Local, Location)>,
}

impl<'tcx, 'a> MBodyVisitor<'tcx, 'a> {
Expand All @@ -801,6 +850,7 @@ impl<'tcx, 'a> MBodyVisitor<'tcx, 'a> {
accesses: vec![],
struct_accesses: HashSet::new(),
aggregates: HashMap::new(),
inits: vec![],
}
}
}
Expand Down Expand Up @@ -840,15 +890,29 @@ impl<'tcx> MVisitor<'tcx> for MBodyVisitor<'tcx, '_> {
}

fn visit_assign(&mut self, place: &Place<'tcx>, rvalue: &Rvalue<'tcx>, location: Location) {
if let Rvalue::Aggregate(box AggregateKind::Adt(def_id, _, _, _, _), _) = rvalue {
if let Some(def_id) = def_id.as_local() {
if self.structs.contains_key(&def_id) || self.unions.contains(&def_id) {
self.aggregates
.entry(def_id)
.or_default()
.push((place.local, location));
match rvalue {
Rvalue::Aggregate(box AggregateKind::Adt(def_id, _, _, _, _), _) => {
if let Some(def_id) = def_id.as_local() {
if self.structs.contains_key(&def_id) || self.unions.contains(&def_id) {
self.aggregates
.entry(def_id)
.or_default()
.push((place.local, location));
}
}
}
Rvalue::Use(Operand::Copy(_) | Operand::Move(_)) => {
let ty = Place::ty(place, self.local_decls, self.tcx).ty;
if let TyKind::Adt(adt_def, _) = ty.kind() {
let def_id = adt_def.did();
if let Some(def_id) = def_id.as_local() {
if place.projection.is_empty() && self.structs.contains_key(&def_id) {
self.inits.push((place.local, location));
}
}
}
}
_ => {}
}
self.super_assign(place, rvalue, location);
}
Expand Down Expand Up @@ -916,12 +980,13 @@ struct HBodyVisitor<'a, 'tcx> {
func: LocalDefId,
structs: &'a HashMap<LocalDefId, Vec<(FieldIdx, LocalDefId)>>,
struct_tag_fields: &'a HashMap<LocalDefId, FieldIdx>,
struct_field_names: &'a HashMap<LocalDefId, IndexVec<FieldIdx, String>>,
aggregates: &'a HashMap<Span, (Local, Location)>,
field_values: &'a HashMap<FieldAt, BTreeSet<u128>>,
unions: &'a Vec<LocalDefId>,
union_variant_tags: &'a HashMap<(LocalDefId, FieldIdx), Vec<u128>>,
union_to_struct: &'a HashMap<LocalDefId, (FieldIdx, LocalDefId)>,
struct_fields: &'a HashMap<LocalDefId, IndexVec<FieldIdx, String>>,
union_field_names: &'a HashMap<LocalDefId, HashMap<String, FieldIdx>>,
suggestions: &'a mut HashMap<PathBuf, Vec<Suggestion>>,
}

Expand Down Expand Up @@ -953,12 +1018,18 @@ impl<'tcx> HBodyVisitor<'_, 'tcx> {
local,
field: tag_field_idx,
};
let values = &self.field_values[&field_at];
assert_eq!(values.len(), 1, "{:?}", values);
let tag = values.iter().next().unwrap();
let tag = self.field_values.get(&field_at).and_then(|values| {
assert!(values.len() <= 1, "{:?}", values);
values.iter().next()
});

for (i, u) in unions {
let expr = fs[i.as_usize()].expr;
let field_name = &self.struct_field_names[&def_id][*i];
let expr = fs
.iter()
.find(|f| &f.ident.name.to_ident_string() == field_name)
.unwrap()
.expr;
let snippet = compile_util::span_to_snippet(expr.span, source_map);
let ExprKind::Struct(path, ufs, _) = expr.kind else { unreachable!() };
let union_name = source_map.span_to_snippet(path.span()).unwrap();
Expand All @@ -969,6 +1040,13 @@ impl<'tcx> HBodyVisitor<'_, 'tcx> {
let name = ufs[0].ident.name.to_ident_string();
let span = ufs[0].expr.span;
let init = source_map.span_to_snippet(span).unwrap();
let tag = if let Some(tag) = tag {
*tag
} else {
let i = self.union_field_names[u][&name];
let tags = &self.union_variant_tags[&(*u, i)];
tags[0]
};
let v = format!("{}::{}{}({})", union_name, name, tag, init);
let suggestion = compile_util::make_suggestion(snippet, v);
suggestions.push(suggestion);
Expand Down Expand Up @@ -1049,7 +1127,7 @@ impl<'tcx> HBodyVisitor<'_, 'tcx> {
} else {
let (_, s_did) = self.union_to_struct[&did];
let field_idx = self.struct_tag_fields[&s_did];
let field_name = &self.struct_fields[&s_did][field_idx];
let field_name = &self.struct_field_names[&s_did][field_idx];
let ExprKind::Field(e2, _) = e.kind else { unreachable!() };
let tag = format!(
"{}.{}",
Expand Down Expand Up @@ -1083,6 +1161,47 @@ impl<'tcx> HVisitor<'tcx> for HBodyVisitor<'_, 'tcx> {
}
}

struct BitFieldInitVisitor<'tcx> {
tcx: TyCtxt<'tcx>,
inits: HashMap<Span, Span>,
}

impl<'tcx> BitFieldInitVisitor<'tcx> {
fn handle_expr(&mut self, expr: &'tcx Expr<'tcx>) {
let ExprKind::Block(block, _) = expr.kind else { return };
if block.stmts.len() <= 1 {
return;
}
let StmtKind::Local(local) = block.stmts[0].kind else { return };
let PatKind::Binding(_, hir_id, ident, _) = local.pat.kind else { return };
if ident.name.to_ident_string() != "init" {
return;
}
let init = some_or!(local.init, return);
let ExprKind::Struct(_, _, _) = init.kind else { return };
let e = some_or!(block.expr, return);
let ExprKind::Path(QPath::Resolved(_, path)) = e.kind else { return };
let Res::Local(id) = path.res else { return };
if hir_id != id {
return;
}
self.inits.insert(e.span, init.span);
}
}

impl<'tcx> HVisitor<'tcx> for BitFieldInitVisitor<'tcx> {
type NestedFilter = nested_filter::OnlyBodies;

fn nested_visit_map(&mut self) -> Self::Map {
self.tcx.hir()
}

fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
self.handle_expr(expr);
intravisit::walk_expr(self, expr);
}
}

#[derive(Debug, Clone, Copy)]
enum ExprContext {
Value,
Expand Down

0 comments on commit a4920ea

Please sign in to comment.