diff --git a/dag_in_context/src/add_context.rs b/dag_in_context/src/add_context.rs index f5051e16e..d7e258d04 100644 --- a/dag_in_context/src/add_context.rs +++ b/dag_in_context/src/add_context.rs @@ -327,9 +327,6 @@ impl Expr { body.add_ctx_with_cache(current_ctx, cache), )), Expr::Symbolic(s, ty) => RcExpr::new(Expr::Symbolic(s.clone(), ty.clone())), - Expr::DeadCode(subexpr) => RcExpr::new(Expr::DeadCode( - subexpr.add_ctx_with_cache(current_ctx, cache), - )), }; cache .with_ctx diff --git a/dag_in_context/src/from_egglog.rs b/dag_in_context/src/from_egglog.rs index afda043cd..cd9625796 100644 --- a/dag_in_context/src/from_egglog.rs +++ b/dag_in_context/src/from_egglog.rs @@ -343,9 +343,6 @@ impl<'a> FromEgglog<'a> { self.expr_from_egglog(expr.clone()), )) } - ("DeadCode", [subexpr]) => { - Rc::new(Expr::DeadCode(self.expr_from_egglog(self.termdag.get(*subexpr).clone()))) - } _ => panic!("Invalid expr: {:?}", expr), }); diff --git a/dag_in_context/src/greedy_dag_extractor.rs b/dag_in_context/src/greedy_dag_extractor.rs index 5ad9cec9e..f0f4de915 100644 --- a/dag_in_context/src/greedy_dag_extractor.rs +++ b/dag_in_context/src/greedy_dag_extractor.rs @@ -462,12 +462,15 @@ impl<'a> Extractor<'a> { /// violating the invariant that we only extract one term per eclass. /// This function would be called with `Neg(b)` as `term`, and would return `Neg(a)` as the new term. /// This restores the invariant that we only extract one term per eclass. + /// + /// When is_free is true, return 0 for the cost and don't add new nodes to the cost set. 
fn add_term_to_cost_set( &mut self, info: &EgraphInfo, current_costs: &mut HashTrieMap, term: Term, other_costs: &HashTrieMap, + is_free: bool, ) -> (Term, Cost) { match &term { Term::Lit(_) => { @@ -480,11 +483,6 @@ impl<'a> Extractor<'a> { return (term, NotNan::new(0.).unwrap()); } - if head.to_string() == "DeadCode" { - // no need to add to cost set - return (term, NotNan::new(0.).unwrap()); - } - let nodeid = &self.term_node(&term); let eclass = info.egraph.nid_to_cid(nodeid); if let Some((existing_term, _existing_cost)) = current_costs.get(eclass) { @@ -492,9 +490,14 @@ impl<'a> Extractor<'a> { } else { let unshared_cost = match other_costs.get(eclass) { Some((_, cost)) => *cost, + // no cost stored, so it's free None => NotNan::new(0.).unwrap(), }; let mut cost = unshared_cost; + if is_free { + cost = NotNan::new(0.).unwrap(); + } + let new_term = { let mut new_children = vec![]; for child in children { @@ -504,6 +507,7 @@ impl<'a> Extractor<'a> { current_costs, child.clone(), other_costs, + is_free, ); new_children.push(new_child); cost += child_cost; @@ -511,8 +515,11 @@ impl<'a> Extractor<'a> { self.termdag.app(*head, new_children) }; self.add_correspondence(new_term.clone(), nodeid.clone()); - *current_costs = - current_costs.insert(eclass.clone(), (new_term.clone(), unshared_cost)); + + if cost > NotNan::new(0.).unwrap() { + *current_costs = + current_costs.insert(eclass.clone(), (new_term.clone(), unshared_cost)); + } (new_term, cost) } @@ -736,30 +743,21 @@ impl<'a> Extractor<'a> { } } } if !add_to_shared { // now that we have which children are used, try to break up the inputs if let Some(broken_up_terms) = self.try_break_up_term(&child_set.term) { let mut new_input_children = vec![]; for (idx, input_tuple_term) in broken_up_terms.iter().enumerate() { - if used_children.contains(&idx) { - let (child_term, net_cost) = self.add_term_to_cost_set( - info, - &mut costs, - input_tuple_term.clone(), 
&child_set.costs, - ); - shared_total += net_cost; - new_input_children.push(child_term); - } else { - let deadcode_term = self - .termdag - .app("DeadCode".into(), vec![input_tuple_term.clone()]); - new_input_children.push(deadcode_term.clone()); - let old_term_node = - self.correspondence.get(input_tuple_term).unwrap(); - self.add_correspondence(deadcode_term, old_term_node.clone()); - } + let (child_term, net_cost) = self.add_term_to_cost_set( + info, + &mut costs, + input_tuple_term.clone(), + &child_set.costs, + !used_children.contains(&idx), + ); + shared_total += net_cost; + new_input_children.push(child_term); } let (new_term, children_used) = self.build_concat(child_set.term.clone(), &new_input_children); @@ -779,6 +777,7 @@ impl<'a> Extractor<'a> { &mut costs, child_set.term.clone(), &child_set.costs, + false, ); shared_total += net_cost; children_terms.push(child_term); diff --git a/dag_in_context/src/interpreter.rs b/dag_in_context/src/interpreter.rs index 8fc491a5e..a47b758b9 100644 --- a/dag_in_context/src/interpreter.rs +++ b/dag_in_context/src/interpreter.rs @@ -460,10 +460,6 @@ impl<'a> VirtualMachine<'a> { self.interpret_call(func_name, &e_val) } Expr::Symbolic(_, _ty) => panic!("found symbolic"), - // dead code marks it as dead, but we can still evaluate it - // soundly - // dead code should be removed by the dead code pass after extraction - Expr::DeadCode(subexpr) => self.interpret_expr(subexpr, arg), }; self.eval_cache.insert(Rc::as_ptr(expr), res.clone()); res diff --git a/dag_in_context/src/linearity.rs b/dag_in_context/src/linearity.rs index f3e352775..b8f3f41b8 100644 --- a/dag_in_context/src/linearity.rs +++ b/dag_in_context/src/linearity.rs @@ -188,7 +188,6 @@ impl<'a> Extractor<'a> { } Expr::Const(_, _, _) => panic!("Const has no effect"), Expr::Symbolic(_, _ty) => panic!("found symbolic"), - Expr::DeadCode(_subexpr) => panic!("found dead code"), } } diff --git a/dag_in_context/src/pretty_print.rs b/dag_in_context/src/pretty_print.rs 
index 5beec1b96..b3df70bf1 100644 --- a/dag_in_context/src/pretty_print.rs +++ b/dag_in_context/src/pretty_print.rs @@ -408,7 +408,6 @@ impl Expr { Expr::Arg(..) => "arg".into(), Expr::Function(name, ..) => "fun_".to_owned() + name, Expr::Symbolic(var, _ty) => "symbolic_".to_owned() + var, - Expr::DeadCode(_subexpr) => todo!("dead code pretty print"), } } @@ -494,7 +493,6 @@ impl Expr { ) } Expr::Symbolic(str, _ty) => format!("{str}.clone()"), - Expr::DeadCode(_subexpr) => todo!("dead code ast"), } } } diff --git a/dag_in_context/src/remove_dead_code_nodes.rs b/dag_in_context/src/remove_dead_code_nodes.rs index 797feda0a..2720052e1 100644 --- a/dag_in_context/src/remove_dead_code_nodes.rs +++ b/dag_in_context/src/remove_dead_code_nodes.rs @@ -1,4 +1,7 @@ -use std::{collections::HashMap, rc::Rc}; +use std::{ + collections::{HashMap, HashSet}, + rc::Rc, +}; use crate::{ ast::parallel_vec_ty, @@ -65,25 +68,65 @@ fn try_split_inputs(expr: RcExpr) -> Option> { } } +fn indices_used(expr: RcExpr) -> HashSet { + let mut res = HashSet::new(); + match expr.as_ref() { + Expr::Get(expr, index) => match expr.as_ref() { + Expr::Arg(_ty, _ctx) => { + res.insert(*index); + } + _ => { + res.extend(indices_used(expr.clone())); + } + }, + Expr::Arg(ty, _ctx) => { + // all of them are used, add one per length of tuple + match ty { + Type::TupleT(vec) => { + for i in 0..vec.len() { + res.insert(i); + } + } + _ => { + res.insert(0); + } + } + } + _ => { + for expr in expr.children_same_scope() { + res.extend(indices_used(expr)); + } + } + } + res +} + /// given a vector of inputs, add the non-dead ones to a new vector /// and return the indicies of the dead ones fn partition_inputs_and_remove_dead_code( inputs: Vec, + regions: Vec, memo: &mut HashMap<(*const Expr, Vec), RcExpr>, current_dead: &Vec, ) -> (Vec, Vec) { + let indices_used = regions + .iter() + .map(|region| indices_used(region.clone())) + .fold(HashSet::new(), |mut acc, used| { + acc.extend(used); + acc + }); + let mut 
new_inputs = vec![]; let mut new_dead_indicies = vec![]; - for (index, input) in inputs.iter().enumerate() { - match input.as_ref() { - Expr::DeadCode(_subexpr) => { - new_dead_indicies.push(index); - } - _ => { - new_inputs.push(remove_dead_code_expr(input.clone(), memo, current_dead)); - } + for (i, input) in inputs.iter().enumerate() { + if indices_used.contains(&i) { + new_inputs.push(remove_dead_code_expr(input.clone(), memo, current_dead)); + } else { + new_dead_indicies.push(i); } } + (new_inputs, new_dead_indicies) } @@ -106,6 +149,11 @@ fn remove_dead_code_expr( // check if the expr is an argument match expr.as_ref() { Expr::Arg(ty, ctx) => { + // if the index is dead, panic + if dead_indicies.contains(index) { + panic!("Found dead code in argument"); + } + let new_ty = remove_dead_code_ty(ty.clone(), dead_indicies); let mut new_index = *index; for dead_index in dead_indicies { @@ -141,8 +189,12 @@ fn remove_dead_code_expr( } Expr::If(pred, inputs, then, else_case) => { if let Some(split_inputs) = try_split_inputs(inputs.clone()) { - let (new_inputs, new_dead_indicies) = - partition_inputs_and_remove_dead_code(split_inputs, memo, dead_indicies); + let (new_inputs, new_dead_indicies) = partition_inputs_and_remove_dead_code( + split_inputs, + vec![then.clone(), else_case.clone()], + memo, + dead_indicies, + ); let new_pred = remove_dead_code_expr(pred.clone(), memo, dead_indicies); RcExpr::new(Expr::If( new_pred.clone(), @@ -161,8 +213,12 @@ fn remove_dead_code_expr( } Expr::Switch(pred, inputs, branches) => { if let Some(split_inputs) = try_split_inputs(inputs.clone()) { - let (new_inputs, new_dead_indicies) = - partition_inputs_and_remove_dead_code(split_inputs, memo, dead_indicies); + let (new_inputs, new_dead_indicies) = partition_inputs_and_remove_dead_code( + split_inputs, + branches.clone(), + memo, + dead_indicies, + ); let mut new_branches = vec![]; for branch in branches.iter() { new_branches.push(remove_dead_code_expr( @@ -188,9 +244,6 @@ fn 
remove_dead_code_expr( )) } } - Expr::DeadCode(_subexpr) => { - panic!("Reached dead code without being in inputs of control flow node"); - } Expr::Function(_, _, _, _expr) => panic!("Found function inside of function"), _ => { expr.map_expr_children(|expr| remove_dead_code_expr(expr.clone(), memo, dead_indicies)) diff --git a/dag_in_context/src/schema.rs b/dag_in_context/src/schema.rs index 935852868..c9bfd8e72 100644 --- a/dag_in_context/src/schema.rs +++ b/dag_in_context/src/schema.rs @@ -124,9 +124,6 @@ pub enum Expr { DoWhile(RcExpr, RcExpr), Arg(Type, Assumption), Function(String, Type, Type, RcExpr), - // marks a subexpression as dead code that can be dropped - // used by extraction, then dead code is removed by a dead code pass - DeadCode(RcExpr), // optionally, the type of this symbol for typechecking Symbolic(String, Option), } diff --git a/dag_in_context/src/schema_helpers.rs b/dag_in_context/src/schema_helpers.rs index 30daed0af..656643117 100644 --- a/dag_in_context/src/schema_helpers.rs +++ b/dag_in_context/src/schema_helpers.rs @@ -132,7 +132,6 @@ impl Expr { Expr::Alloc(..) => Constructor::Alloc, Expr::Top(..) 
=> Constructor::Top, Expr::Symbolic(_, _ty) => panic!("found symbolic"), - Expr::DeadCode(_subexpr) => panic!("found dead code"), } } pub fn func_name(&self) -> Option { @@ -223,7 +222,6 @@ impl Expr { Expr::Empty(_, _) => vec![], Expr::Arg(_, _) => vec![], Expr::Symbolic(_, _ty) => vec![], - Expr::DeadCode(subexpr) => vec![subexpr.clone()], } } @@ -253,7 +251,6 @@ impl Expr { Expr::DoWhile(inputs, _body) => vec![inputs.clone()], Expr::Arg(_, _) => vec![], Expr::Symbolic(_, _ty) => vec![], - Expr::DeadCode(subexpr) => vec![subexpr.clone()], } } @@ -275,7 +272,6 @@ impl Expr { Expr::Arg(ty, _) => ty.clone(), Expr::Function(_, ty, _, _) => ty.clone(), Expr::Symbolic(_, _ty) => panic!("found symbolic"), - Expr::DeadCode(subexpr) => subexpr.get_arg_type(), } } @@ -375,7 +371,6 @@ impl Expr { Expr::Arg(_, ctx) => ctx, Expr::Function(_, _, _, x) => x.get_ctx(), Expr::Symbolic(_, _ty) => panic!("found symbolic"), - Expr::DeadCode(subexpr) => subexpr.get_ctx(), } } @@ -511,14 +506,6 @@ impl Expr { } Expr::Empty(_, _) => Rc::new(Expr::Empty(arg_ty.clone(), arg_ctx.clone())), Expr::Symbolic(_, _ty) => panic!("found symbolic"), - Expr::DeadCode(subexpr) => Rc::new(Expr::DeadCode(Self::subst_with_cache( - arg, - arg_ty, - arg_ctx, - subexpr, - subst_cache, - context_cache, - ))), }; // Add the substituted to cache diff --git a/dag_in_context/src/to_egglog.rs b/dag_in_context/src/to_egglog.rs index 5adf001d0..a4dfc4d60 100644 --- a/dag_in_context/src/to_egglog.rs +++ b/dag_in_context/src/to_egglog.rs @@ -262,7 +262,6 @@ impl Expr { term_dag.app("Function".into(), vec![name_lit, ty_in, ty_out, body]) } Expr::Symbolic(name, _ty) => term_dag.var(name.into()), - Expr::DeadCode(_subexpr) => panic!("Dead code should not be converted to egglog"), }; term_dag diff --git a/dag_in_context/src/typechecker.rs b/dag_in_context/src/typechecker.rs index e2c4f29cb..f95bfc92f 100644 --- a/dag_in_context/src/typechecker.rs +++ b/dag_in_context/src/typechecker.rs @@ -527,10 +527,6 @@ impl<'a> 
TypeChecker<'a> { } (found_ty.clone(), expr.clone()) } - Expr::DeadCode(subexpr) => { - let (ty, new_subexpr) = self.add_arg_types_to_expr(subexpr.clone(), arg_tys); - (ty, RcExpr::new(Expr::DeadCode(new_subexpr))) - } Expr::Function(_, _, _, _) => panic!("Expected expression, got function"), Expr::Symbolic(_, ty) => ( ty.clone() @@ -561,7 +557,6 @@ impl<'a> TypeChecker<'a> { pub(crate) fn get_arg_type(expr: &RcExpr) -> Type { match expr.as_ref() { Expr::Arg(ty, _) => ty.clone(), - Expr::DeadCode(subexpr) => Self::get_arg_type(subexpr), Expr::Const(_, ty, _) => ty.clone(), Expr::Top(_, rc, _, _) => Self::get_arg_type(rc), Expr::Bop(_, left, _) => Self::get_arg_type(left), diff --git a/src/rvsdg/from_dag.rs b/src/rvsdg/from_dag.rs index 8ee2b50dd..afb182133 100644 --- a/src/rvsdg/from_dag.rs +++ b/src/rvsdg/from_dag.rs @@ -463,9 +463,6 @@ impl<'a> TreeToRvsdg<'a> { res } Expr::Symbolic(_, _ty) => panic!("symbolic not supported"), - Expr::DeadCode(_) => { - panic!("dead code should have been removed at the end of extraction!") - } }; self.translation_cache .insert(Rc::as_ptr(&expr), res.clone());