diff --git a/src/ai/analysis.rs b/src/ai/analysis.rs index 44639fc..cef3f44 100644 --- a/src/ai/analysis.rs +++ b/src/ai/analysis.rs @@ -48,9 +48,14 @@ impl Default for AnalysisConfig { } pub type OutputParams = BTreeMap>; + +// Write Before RETurn s pub type Wbrets = BTreeMap>; -pub type AnalysisResult = (OutputParams, BTreeMap); +// Removable checks for Write s +pub type Rcfws = BTreeMap>; + +pub type AnalysisResult = (OutputParams, BTreeMap); pub fn analyze_path(path: &Path, conf: &AnalysisConfig) -> AnalysisResult { analyze_input(compile_util::path_to_input(path), conf) @@ -65,18 +70,18 @@ pub fn analyze_input(input: Input, conf: &AnalysisConfig) -> AnalysisResult { compile_util::run_compiler(config, |tcx| { analyze(tcx, conf) .into_iter() - .filter_map(|(def_id, (_, params, writes))| { + .filter_map(|(def_id, (_, params, wbret, rcfw))| { if params.is_empty() { None } else { - Some((tcx.def_path_str(def_id), (params, writes))) + Some((tcx.def_path_str(def_id), (params, wbret, rcfw))) } }) .collect::>() }) .unwrap() .into_iter() - .map(|(k, (v1, v2))| ((k.clone(), v1), (k, v2))) + .map(|(k, (v1, v2, v3))| ((k.clone(), v1), (k.clone(), (v2, v3)))) .unzip() } @@ -90,14 +95,7 @@ enum Write { pub fn analyze( tcx: TyCtxt<'_>, conf: &AnalysisConfig, -) -> BTreeMap< - DefId, - ( - FunctionSummary, - Vec, - Wbrets, - ), -> { +) -> BTreeMap, Wbrets, Rcfws)> { let hir = tcx.hir(); let mut call_graph = BTreeMap::new(); @@ -164,6 +162,9 @@ pub fn analyze( let mut call_args_map = BTreeMap::new(); let mut analysis_times: BTreeMap<_, u128> = BTreeMap::new(); let mut wbrets: BTreeMap>> = BTreeMap::new(); + let mut wbbbrets: BTreeMap>> = BTreeMap::new(); + + let mut rcfws = BTreeMap::new(); for id in &po { let def_ids = &elems[id]; let recursive = if def_ids.len() == 1 { @@ -216,11 +217,18 @@ pub fn analyze( let ret_location = return_location(body); let mut wbret = BTreeMap::new(); + let mut wbbbret = BTreeMap::new(); if let Some(ret_location) = ret_location { - if let Some(ret_loc_assign0) = exists_assign0(body, ret_location.block) { + if let Some((ret_loc_assign0, index)) = exists_assign0(body, ret_location.block) + { + let loc = Location { + block: ret_location.block, + statement_index: index, + }; + let writes: BTreeSet<_> = states - .get(&ret_location) + .get(&loc) .cloned() .unwrap_or_default() .values() @@ -230,17 +238,18 @@ pub fn analyze( wbret.insert( unsafe { std::mem::transmute(ret_loc_assign0.data()) }, - writes, + writes.clone(), ); + wbbbret.insert(ret_location.block, writes); } else { let preds = body.basic_blocks.predecessors().get(ret_location.block); if let Some(v) = preds { for i in v { - if let Some(sp) = exists_assign0(body, *i) { + if let Some((sp, index)) = exists_assign0(body, *i) { let loc = Location { block: *i, - statement_index: body.basic_blocks[*i].statements.len(), + statement_index: index, }; let writes: BTreeSet<_> = states @@ -252,7 +261,11 @@ pub fn analyze( .map(|p| p.base() - 1) .collect(); - wbret.insert(unsafe { std::mem::transmute(sp.data()) }, writes); + wbret.insert( + unsafe { std::mem::transmute(sp.data()) }, + writes.clone(), + ); + wbbbret.insert(*i, writes); } } } @@ -280,7 +293,30 @@ pub fn analyze( } else { wbret }; + + let wbbbret = if let Some(old) = wbbbrets.get(def_id) { + let keys: BTreeSet<_> = wbbbret.keys().chain(old.keys()).cloned().collect(); + + keys.into_iter() + .map(|bb| { + ( + bb, + match (wbbbret.get(&bb), old.get(&bb)) { + (Some(v1), Some(v2)) => { + v1.intersection(v2).cloned().collect::>() + } + (Some(v), None) | (None, Some(v)) => (*v).clone(), + _ => unreachable!(), + }, + ) + }) + .collect::>() + } else { + wbbbret + }; + wbrets.insert(*def_id, wbret); + wbbbrets.insert(*def_id, wbbbret); let mut return_states = ret_location .and_then(|ret| states.get(&ret)) @@ -323,6 +359,69 @@ pub fn analyze( for p in &mut output_params { analyzer.find_complete_write(p, &result, &writes_map, &call_args, *def_id); } + + let body = tcx.optimized_mir(*def_id); + let wbbbret = &wbbbrets[def_id]; + let mut rcfw: Rcfws = BTreeMap::new(); + for p in output_params.iter() { + let OutputParam { + index, + must: _, + return_values: _, + complete_writes, + } = p; + for complete_write in complete_writes.iter() { + let CompleteWrite { + block, + statement_index, + write_arg: _, + } = complete_write; + + let mut stack = vec![BasicBlock::from_usize(*block)]; + + let success = loop { + if let Some(block) = stack.pop() { + match wbbbret.get(&block) { + Some(ws) => { + if !ws.contains(index) { + break false; + } + } + None => (), + } + + let bbd = &body.basic_blocks[block]; + let term = bbd.terminator(); + + match term.kind { + TerminatorKind::Return => (), + _ => { + for bb in term.successors() { + stack.push(bb); + } + } + } + } else { + break true; + } + }; + + if success { + let location = Location { + block: BasicBlock::from_usize(*block), + statement_index: *statement_index, + }; + let span = unsafe { + std::mem::transmute(body.source_info(location).span.data()) + }; + + let entry = rcfw.entry(*index); + entry.or_default().insert(span); + } + } + } + + rcfws.insert(*def_id, rcfw); output_params_map.insert(*def_id, output_params); } break; @@ -351,6 +450,7 @@ pub fn analyze( summary, output_params, wbrets.get(&def_id).cloned().unwrap_or_default(), + rcfws.get(&def_id).cloned().unwrap_or_default(), ), ) }) @@ -1152,11 +1252,11 @@ fn return_location(body: &Body<'_>) -> Option { None } -fn exists_assign0(body: &Body<'_>, bb: BasicBlock) -> Option { - for stmt in body.basic_blocks[bb].statements.iter() { +fn exists_assign0(body: &Body<'_>, bb: BasicBlock) -> Option<(Span, usize)> { + for (i, stmt) in body.basic_blocks[bb].statements.iter().enumerate() { if let StatementKind::Assign(rb) = &stmt.kind { if (**rb).0.local.as_u32() == 0u32 { - return Some(stmt.source_info.span); + return Some((stmt.source_info.span, i)); } } } diff --git a/src/transform.rs b/src/transform.rs index fc57681..fc1b36f 100644 --- a/src/transform.rs +++ b/src/transform.rs @@ -22,40 +22,23 @@ use crate::{ai::analysis::*, compile_util}; pub fn transform_path( path: &Path, params: &OutputParams, - writes: &BTreeMap, + extra_info: &BTreeMap, ) { let input = compile_util::path_to_input(path); let config = compile_util::make_config(input); let suggestions = - compile_util::run_compiler(config, |tcx| transform(tcx, params, writes)).unwrap(); + compile_util::run_compiler(config, |tcx| transform(tcx, params, extra_info)).unwrap(); compile_util::apply_suggestions(&suggestions); } fn transform( tcx: TyCtxt<'_>, param_map: &OutputParams, - writes: &BTreeMap, + extra_info: &BTreeMap, ) -> BTreeMap> { let hir = tcx.hir(); let source_map = tcx.sess.source_map(); - let writes = writes - .iter() - .map(|(k, v)| { - ( - k, - v.iter() - .map(|(k, v)| { - ( - unsafe { std::mem::transmute::(*k) }.span(), - v, - ) - }) - .collect::>(), - ) - }) - .collect::>(); - let mut def_id_ty_map = BTreeMap::new(); for id in hir.items() { let item = hir.item(id); @@ -72,6 +55,7 @@ fn transform( let mut funcs = BTreeMap::new(); let mut wbrets = BTreeMap::new(); + let mut rcfws = BTreeMap::new(); for id in hir.items() { let item = hir.item(id); let ItemKind::Fn(sig, _, body_id) = item.kind else { @@ -159,24 +143,46 @@ fn transform( (*index, param) }) .collect(); - if let Some(write) = writes.get(&name) { + + if let Some((wbret, _)) = extra_info.get(&name) { wbrets.insert( def_id, - write + wbret .clone() .into_iter() .map(|(sp, params)| { ( - sp, + unsafe { std::mem::transmute::(sp) }.span(), params .iter() .filter_map(|i| index_map.get(i).map(|p| p.name.clone())) .collect::>(), ) }) - .collect::>(), + .collect::>(), + ); + } + + if let Some((_, rcfw)) = extra_info.get(&name) { + rcfws.insert( + def_id, + rcfw.clone() + .into_iter() + .map(|(index, spans)| { + ( + index_map.get(&index).cloned().unwrap().name, + spans + .iter() + .map(|sp| { + unsafe { std::mem::transmute::(*sp) }.span() + }) + .collect(), + ) + }) + .collect::>>(), ); } + let hir_id_map: BTreeMap<_, _> = index_map .values() .cloned() @@ -298,7 +304,11 @@ fn transform( } let assign_map = curr.map(|c| c.assign_map(span)).unwrap_or_default(); - let mut mtch = func.call_match(&args, &assign_map); + + let mut mtch = func.first_return.and_then(|(_, first)| { + let set_flag = generate_set_flag(&span, &first, &rcfws[&def_id], &assign_map); + func.call_match(&args, set_flag) + }); if let Some(call) = get_if_cmp_call(hir_id, span, tcx) { if let Some(then) = func.cmp(call.op, call.target) { @@ -310,11 +320,7 @@ fn transform( let fail = "Err(_) => "; let (_, i) = func.first_return.as_ref().unwrap(); let arg = &args[*i]; - let set_flag = if let Some(arg) = assign_map.get(i) { - format!("{}___s = true;", arg) - } else { - "".to_string() - }; + let set_flag = generate_set_flag(&span, i, &rcfws[&def_id], &assign_map); let assign = if arg.code.contains("&mut ") { format!(" *({}) = v___; {}", arg.code, set_flag) } else { @@ -382,7 +388,7 @@ fn transform( let rv = format!("{}rv___{}", pre_s, post_s); let (_, wbret) = wbrets[&def_id] .iter() - .find(|(sp, _)| span.contains(**sp)) + .find(|(sp, _)| span.contains(*sp)) .unwrap(); let rv = func.return_value(Some(rv), wbret); fix( @@ -398,7 +404,18 @@ fn transform( } fix(span.shrink_to_lo(), binding); - let mut assign = func.call_assign(&args, &assign_map); + let set_flags = func + .remaining_return + .iter() + .map(|i| { + ( + *i, + generate_set_flag(&span, i, &rcfws[&def_id], &assign_map), + ) + }) + .collect(); + + let mut assign = func.call_assign(&args, &set_flags); if let Some(m) = &mtch { assign += m; assign += ")"; @@ -464,10 +481,18 @@ fn transform( fix(span, local_vars); for param in func.params() { + let rcfw = &rcfws[&def_id].get(¶m.name); + for span in ¶m.writes { if call_spans.contains(span) { continue; } + if let Some(rcfw) = rcfw { + if rcfw.iter().any(|sp| span.contains(*sp)) { + continue; + } + } + let pos = span.hi() + BytePos(1); let span = span.with_hi(pos).with_lo(pos); let assign = format!("{0}___s = true;", param.name); @@ -522,7 +547,7 @@ fn transform( let orig = value.map(|value| source_map.span_to_snippet(value).unwrap()); let (_, wbret) = wbrets[&def_id] .iter() - .find(|(sp, _)| span.contains(**sp)) + .find(|(sp, _)| span.contains(*sp)) .unwrap(); let ret_v = func.return_value(orig, wbret); fix(span, format!("return {}", ret_v)); @@ -616,34 +641,29 @@ impl Func { map } - fn call_assign(&self, args: &[Arg], assign_map: &BTreeMap) -> String { + fn call_assign(&self, args: &[Arg], set_flags: &BTreeMap) -> String { let mut assigns = vec![]; for i in &self.remaining_return { let arg = &args[*i]; let param = &self.index_map[i]; - let set_flag = if let Some(arg) = assign_map.get(i) { - format!("{}___s = true;", arg) - } else { - "".to_string() - }; let assign = if param.must { if arg.code.contains("&mut ") { - format!("*({}) = rv___{}; {}", arg.code, i, set_flag) + format!("*({}) = rv___{}; {}", arg.code, i, set_flags[i]) } else { format!( "if !({0}).is_null() {{ *({0}) = rv___{1}; {2} }}", - arg.code, i, set_flag + arg.code, i, set_flags[i] ) } } else if arg.code.contains("&mut ") { format!( "if let Some(v___) = rv___{} {{ *({}) = v___; {} }}", - i, arg.code, set_flag + i, arg.code, set_flags[i] ) } else { format!( "if !({0}).is_null() {{ if let Some(v___) = rv___{1} {{ *({0}) = v___; {2} }} }}", - arg.code, i, set_flag + arg.code, i, set_flags[i] ) }; assigns.push(assign); @@ -652,14 +672,9 @@ impl Func { mk_string(assigns.iter(), "; ", " ", end) } - fn call_match(&self, args: &[Arg], assign_map: &BTreeMap) -> Option { + fn call_match(&self, args: &[Arg], set_flag: String) -> Option { let (succ_value, first) = &self.first_return?; let arg = &args[*first]; - let set_flag = if let Some(arg) = assign_map.get(first) { - format!("{}___s = true;", arg) - } else { - "".to_string() - }; let assign = if arg.code.contains("&mut ") { format!("*({}) = v___; {}", arg.code, set_flag) } else { @@ -1070,3 +1085,21 @@ fn mk_string, I: Iterator>( s.push_str(end); s } + +fn generate_set_flag( + span: &Span, + i: &usize, + rcfws: &BTreeMap>, + assign_map: &BTreeMap, +) -> String { + if let Some(arg) = assign_map.get(i) { + let rcfw = &rcfws.get(arg); + if let Some(rcfw) = rcfw { + if rcfw.iter().any(|sp| span.contains(*sp)) { + return "".to_string(); + } + } + return format!("{}___s = true;", arg); + } + "".to_string() +}