Skip to content

Commit

Permalink
fix up code to check for deadness later
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Feb 3, 2025
1 parent df21ada commit c2031a6
Show file tree
Hide file tree
Showing 12 changed files with 94 additions and 79 deletions.
3 changes: 0 additions & 3 deletions dag_in_context/src/add_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,6 @@ impl Expr {
body.add_ctx_with_cache(current_ctx, cache),
)),
Expr::Symbolic(s, ty) => RcExpr::new(Expr::Symbolic(s.clone(), ty.clone())),
Expr::DeadCode(subexpr) => RcExpr::new(Expr::DeadCode(
subexpr.add_ctx_with_cache(current_ctx, cache),
)),
};
cache
.with_ctx
Expand Down
3 changes: 0 additions & 3 deletions dag_in_context/src/from_egglog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,6 @@ impl<'a> FromEgglog<'a> {
self.expr_from_egglog(expr.clone()),
))
}
("DeadCode", [subexpr]) => {
Rc::new(Expr::DeadCode(self.expr_from_egglog(self.termdag.get(*subexpr).clone())))
}
_ => panic!("Invalid expr: {:?}", expr),
});

Expand Down
50 changes: 25 additions & 25 deletions dag_in_context/src/greedy_dag_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,12 +462,15 @@ impl<'a> Extractor<'a> {
/// violating the invariant that we only extract one term per eclass.
/// This function would be called with `Neg(b)` as `term`, and would return `Neg(a)` as the new term.
/// This restores the invariant that we only extract one term per eclass.
///
/// When is_free is true, return 0 for the cost and don't add new nodes to the cost set.
fn add_term_to_cost_set(
&mut self,
info: &EgraphInfo,
current_costs: &mut HashTrieMap<ClassId, (Term, Cost)>,
term: Term,
other_costs: &HashTrieMap<ClassId, (Term, Cost)>,
is_free: bool,
) -> (Term, Cost) {
match &term {
Term::Lit(_) => {
Expand All @@ -480,21 +483,21 @@ impl<'a> Extractor<'a> {
return (term, NotNan::new(0.).unwrap());
}

if head.to_string() == "DeadCode" {
// no need to add to cost set
return (term, NotNan::new(0.).unwrap());
}

let nodeid = &self.term_node(&term);
let eclass = info.egraph.nid_to_cid(nodeid);
if let Some((existing_term, _existing_cost)) = current_costs.get(eclass) {
(existing_term.clone(), NotNan::new(0.).unwrap())
} else {
let unshared_cost = match other_costs.get(eclass) {
Some((_, cost)) => *cost,
// no cost stored, so it's free
None => NotNan::new(0.).unwrap(),
};
let mut cost = unshared_cost;
if is_free {
cost = NotNan::new(0.).unwrap();
}

let new_term = {
let mut new_children = vec![];
for child in children {
Expand All @@ -504,15 +507,19 @@ impl<'a> Extractor<'a> {
current_costs,
child.clone(),
other_costs,
is_free,
);
new_children.push(new_child);
cost += child_cost;
}
self.termdag.app(*head, new_children)
};
self.add_correspondence(new_term.clone(), nodeid.clone());
*current_costs =
current_costs.insert(eclass.clone(), (new_term.clone(), unshared_cost));

if cost > NotNan::new(0.).unwrap() {
*current_costs =
current_costs.insert(eclass.clone(), (new_term.clone(), unshared_cost));
}

(new_term, cost)
}
Expand Down Expand Up @@ -736,30 +743,22 @@ impl<'a> Extractor<'a> {
}
}
}
eprintln!("used children: {:?}", used_children);

if !add_to_shared {
// now that we have which children are used, try to break up the inputs
if let Some(broken_up_terms) = self.try_break_up_term(&child_set.term) {
let mut new_input_children = vec![];
for (idx, input_tuple_term) in broken_up_terms.iter().enumerate() {
if used_children.contains(&idx) {
let (child_term, net_cost) = self.add_term_to_cost_set(
info,
&mut costs,
input_tuple_term.clone(),
&child_set.costs,
);
shared_total += net_cost;
new_input_children.push(child_term);
} else {
let deadcode_term = self
.termdag
.app("DeadCode".into(), vec![input_tuple_term.clone()]);
new_input_children.push(deadcode_term.clone());
let old_term_node =
self.correspondence.get(input_tuple_term).unwrap();
self.add_correspondence(deadcode_term, old_term_node.clone());
}
let (child_term, net_cost) = self.add_term_to_cost_set(
info,
&mut costs,
input_tuple_term.clone(),
&child_set.costs,
!used_children.contains(&idx),
);
shared_total += net_cost;
new_input_children.push(child_term);
}
let (new_term, children_used) =
self.build_concat(child_set.term.clone(), &new_input_children);
Expand All @@ -779,6 +778,7 @@ impl<'a> Extractor<'a> {
&mut costs,
child_set.term.clone(),
&child_set.costs,
false,
);
shared_total += net_cost;
children_terms.push(child_term);
Expand Down
4 changes: 0 additions & 4 deletions dag_in_context/src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,6 @@ impl<'a> VirtualMachine<'a> {
self.interpret_call(func_name, &e_val)
}
Expr::Symbolic(_, _ty) => panic!("found symbolic"),
// dead code marks it as dead, but we can still evaluate it
// soundly
// dead code should be removed by the dead code pass after extraction
Expr::DeadCode(subexpr) => self.interpret_expr(subexpr, arg),
};
self.eval_cache.insert(Rc::as_ptr(expr), res.clone());
res
Expand Down
1 change: 0 additions & 1 deletion dag_in_context/src/linearity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ impl<'a> Extractor<'a> {
}
Expr::Const(_, _, _) => panic!("Const has no effect"),
Expr::Symbolic(_, _ty) => panic!("found symbolic"),
Expr::DeadCode(_subexpr) => panic!("found dead code"),
}
}

Expand Down
2 changes: 0 additions & 2 deletions dag_in_context/src/pretty_print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,6 @@ impl Expr {
Expr::Arg(..) => "arg".into(),
Expr::Function(name, ..) => "fun_".to_owned() + name,
Expr::Symbolic(var, _ty) => "symbolic_".to_owned() + var,
Expr::DeadCode(_subexpr) => todo!("dead code pretty print"),
}
}

Expand Down Expand Up @@ -494,7 +493,6 @@ impl Expr {
)
}
Expr::Symbolic(str, _ty) => format!("{str}.clone()"),
Expr::DeadCode(_subexpr) => todo!("dead code ast"),
}
}
}
Expand Down
85 changes: 69 additions & 16 deletions dag_in_context/src/remove_dead_code_nodes.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::{collections::HashMap, rc::Rc};
use std::{
collections::{HashMap, HashSet},
rc::Rc,
};

use crate::{
ast::parallel_vec_ty,
Expand Down Expand Up @@ -65,25 +68,65 @@ fn try_split_inputs(expr: RcExpr) -> Option<Vec<RcExpr>> {
}
}

fn indices_used(expr: RcExpr) -> HashSet<usize> {
let mut res = HashSet::new();
match expr.as_ref() {
Expr::Get(expr, index) => match expr.as_ref() {
Expr::Arg(_ty, _ctx) => {
res.insert(*index);
}
_ => {
res.extend(indices_used(expr.clone()));
}
},
Expr::Arg(ty, _ctx) => {
// all of them are used, add one per length of tuple
match ty {
Type::TupleT(vec) => {
for i in 0..vec.len() {
res.insert(i);
}
}
_ => {
res.insert(0);
}
}
}
_ => {
for expr in expr.children_same_scope() {
res.extend(indices_used(expr));
}
}
}
res
}

/// given a vector of inputs, add the non-dead ones to a new vector
/// and return the indicies of the dead ones
fn partition_inputs_and_remove_dead_code(
inputs: Vec<RcExpr>,
regions: Vec<RcExpr>,
memo: &mut HashMap<(*const Expr, Vec<usize>), RcExpr>,
current_dead: &Vec<usize>,
) -> (Vec<RcExpr>, Vec<usize>) {
let indices_used = regions
.iter()
.map(|region| indices_used(region.clone()))
.fold(HashSet::new(), |mut acc, used| {
acc.extend(used);
acc
});

let mut new_inputs = vec![];
let mut new_dead_indicies = vec![];
for (index, input) in inputs.iter().enumerate() {
match input.as_ref() {
Expr::DeadCode(_subexpr) => {
new_dead_indicies.push(index);
}
_ => {
new_inputs.push(remove_dead_code_expr(input.clone(), memo, current_dead));
}
for (i, input) in inputs.iter().enumerate() {
if indices_used.contains(&i) {
new_inputs.push(remove_dead_code_expr(input.clone(), memo, current_dead));
} else {
new_dead_indicies.push(i);
}
}

(new_inputs, new_dead_indicies)
}

Expand All @@ -106,6 +149,11 @@ fn remove_dead_code_expr(
// check if the expr is an argument
match expr.as_ref() {
Expr::Arg(ty, ctx) => {
// if the index is dead, panic
if dead_indicies.contains(index) {
panic!("Found dead code in argument");
}

let new_ty = remove_dead_code_ty(ty.clone(), dead_indicies);
let mut new_index = *index;
for dead_index in dead_indicies {
Expand Down Expand Up @@ -141,8 +189,12 @@ fn remove_dead_code_expr(
}
Expr::If(pred, inputs, then, else_case) => {
if let Some(split_inputs) = try_split_inputs(inputs.clone()) {
let (new_inputs, new_dead_indicies) =
partition_inputs_and_remove_dead_code(split_inputs, memo, dead_indicies);
let (new_inputs, new_dead_indicies) = partition_inputs_and_remove_dead_code(
split_inputs,
vec![then.clone(), else_case.clone()],
memo,
dead_indicies,
);
let new_pred = remove_dead_code_expr(pred.clone(), memo, dead_indicies);
RcExpr::new(Expr::If(
new_pred.clone(),
Expand All @@ -161,8 +213,12 @@ fn remove_dead_code_expr(
}
Expr::Switch(pred, inputs, branches) => {
if let Some(split_inputs) = try_split_inputs(inputs.clone()) {
let (new_inputs, new_dead_indicies) =
partition_inputs_and_remove_dead_code(split_inputs, memo, dead_indicies);
let (new_inputs, new_dead_indicies) = partition_inputs_and_remove_dead_code(
split_inputs,
branches.clone(),
memo,
dead_indicies,
);
let mut new_branches = vec![];
for branch in branches.iter() {
new_branches.push(remove_dead_code_expr(
Expand All @@ -188,9 +244,6 @@ fn remove_dead_code_expr(
))
}
}
Expr::DeadCode(_subexpr) => {
panic!("Reached dead code without being in inputs of control flow node");
}
Expr::Function(_, _, _, _expr) => panic!("Found function inside of function"),
_ => {
expr.map_expr_children(|expr| remove_dead_code_expr(expr.clone(), memo, dead_indicies))
Expand Down
3 changes: 0 additions & 3 deletions dag_in_context/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ pub enum Expr {
DoWhile(RcExpr, RcExpr),
Arg(Type, Assumption),
Function(String, Type, Type, RcExpr),
// marks a subexpression as dead code that can be dropped
// used by extraction, then dead code is removed by a dead code pass
DeadCode(RcExpr),
// optionally, the type of this symbol for typechecking
Symbolic(String, Option<Type>),
}
Expand Down
13 changes: 0 additions & 13 deletions dag_in_context/src/schema_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ impl Expr {
Expr::Alloc(..) => Constructor::Alloc,
Expr::Top(..) => Constructor::Top,
Expr::Symbolic(_, _ty) => panic!("found symbolic"),
Expr::DeadCode(_subexpr) => panic!("found dead code"),
}
}
pub fn func_name(&self) -> Option<String> {
Expand Down Expand Up @@ -223,7 +222,6 @@ impl Expr {
Expr::Empty(_, _) => vec![],
Expr::Arg(_, _) => vec![],
Expr::Symbolic(_, _ty) => vec![],
Expr::DeadCode(subexpr) => vec![subexpr.clone()],
}
}

Expand Down Expand Up @@ -253,7 +251,6 @@ impl Expr {
Expr::DoWhile(inputs, _body) => vec![inputs.clone()],
Expr::Arg(_, _) => vec![],
Expr::Symbolic(_, _ty) => vec![],
Expr::DeadCode(subexpr) => vec![subexpr.clone()],
}
}

Expand All @@ -275,7 +272,6 @@ impl Expr {
Expr::Arg(ty, _) => ty.clone(),
Expr::Function(_, ty, _, _) => ty.clone(),
Expr::Symbolic(_, _ty) => panic!("found symbolic"),
Expr::DeadCode(subexpr) => subexpr.get_arg_type(),
}
}

Expand Down Expand Up @@ -375,7 +371,6 @@ impl Expr {
Expr::Arg(_, ctx) => ctx,
Expr::Function(_, _, _, x) => x.get_ctx(),
Expr::Symbolic(_, _ty) => panic!("found symbolic"),
Expr::DeadCode(subexpr) => subexpr.get_ctx(),
}
}

Expand Down Expand Up @@ -511,14 +506,6 @@ impl Expr {
}
Expr::Empty(_, _) => Rc::new(Expr::Empty(arg_ty.clone(), arg_ctx.clone())),
Expr::Symbolic(_, _ty) => panic!("found symbolic"),
Expr::DeadCode(subexpr) => Rc::new(Expr::DeadCode(Self::subst_with_cache(
arg,
arg_ty,
arg_ctx,
subexpr,
subst_cache,
context_cache,
))),
};

// Add the substituted to cache
Expand Down
1 change: 0 additions & 1 deletion dag_in_context/src/to_egglog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ impl Expr {
term_dag.app("Function".into(), vec![name_lit, ty_in, ty_out, body])
}
Expr::Symbolic(name, _ty) => term_dag.var(name.into()),
Expr::DeadCode(_subexpr) => panic!("Dead code should not be converted to egglog"),
};

term_dag
Expand Down
5 changes: 0 additions & 5 deletions dag_in_context/src/typechecker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,6 @@ impl<'a> TypeChecker<'a> {
}
(found_ty.clone(), expr.clone())
}
Expr::DeadCode(subexpr) => {
let (ty, new_subexpr) = self.add_arg_types_to_expr(subexpr.clone(), arg_tys);
(ty, RcExpr::new(Expr::DeadCode(new_subexpr)))
}
Expr::Function(_, _, _, _) => panic!("Expected expression, got function"),
Expr::Symbolic(_, ty) => (
ty.clone()
Expand Down Expand Up @@ -561,7 +557,6 @@ impl<'a> TypeChecker<'a> {
pub(crate) fn get_arg_type(expr: &RcExpr) -> Type {
match expr.as_ref() {
Expr::Arg(ty, _) => ty.clone(),
Expr::DeadCode(subexpr) => Self::get_arg_type(subexpr),
Expr::Const(_, ty, _) => ty.clone(),
Expr::Top(_, rc, _, _) => Self::get_arg_type(rc),
Expr::Bop(_, left, _) => Self::get_arg_type(left),
Expand Down
Loading

0 comments on commit c2031a6

Please sign in to comment.