Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expected utility branch-and-bound with evidence #179

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/builder/sdd/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> {
// return self.unique_or(v, r.vtree());

// TODO optimize this for special cases
let b = vec![
let b = [
SddAnd::new(d, SddPtr::true_ptr()),
SddAnd::new(d.neg(), SddPtr::false_ptr()),
];
Expand Down
87 changes: 77 additions & 10 deletions src/repr/bdd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::{
repr::WmcParams,
repr::{DDNNFPtr, DDNNF},
repr::{Literal, VarLabel, VarSet},
util::semirings::ExpectedUtility,
util::semirings::{BBSemiring, FiniteField, JoinSemilattice, RealSemiring},
util::semirings::{ExpectedUtility, LatticeWithChoose, MeetSemilattice},
};
use bit_set::BitSet;
use core::fmt::Debug;
Expand Down Expand Up @@ -655,6 +655,7 @@ impl<'a> BddPtr<'a> {

fn meu_h(
&self,
evidence: BddPtr,
cur_lb: ExpectedUtility,
cur_best: PartialModel,
decision_vars: &[VarLabel],
Expand All @@ -666,7 +667,8 @@ impl<'a> BddPtr<'a> {
[] => {
// Run the eu ub
let decision_bitset = BitSet::new();
let possible_best = self.eu_ub(&cur_assgn, &decision_bitset, wmc);
let possible_best = self.eu_ub(&cur_assgn, &decision_bitset, wmc)
/ evidence.bb_lb(&cur_assgn, &decision_bitset, wmc);
// If it's a better lb, update.
if possible_best.1 > cur_lb.1 {
(possible_best, cur_assgn)
Expand All @@ -687,8 +689,14 @@ impl<'a> BddPtr<'a> {
false_model.set(*x, false);

// and calculate their respective upper bounds.
let true_ub = self.eu_ub(&true_model, &margvar_bits, wmc);
let false_ub = self.eu_ub(&false_model, &margvar_bits, wmc);
let true_ub_num = self.eu_ub(&true_model, &margvar_bits, wmc);
let false_ub_num = self.eu_ub(&false_model, &margvar_bits, wmc);

let true_ub_dec = evidence.bb_lb(&true_model, &margvar_bits, wmc);
let false_ub_dec = evidence.bb_lb(&false_model, &margvar_bits, wmc);

let true_ub = true_ub_num / true_ub_dec;
let false_ub = false_ub_num / false_ub_dec;

// branch on the greater upper-bound first
let order = if true_ub.1 > false_ub.1 {
Expand All @@ -699,32 +707,41 @@ impl<'a> BddPtr<'a> {
for (upper_bound, partialmodel) in order {
// branch + bound
if upper_bound.1 > best_lb.1 {
(best_lb, best_model) =
self.meu_h(best_lb, best_model, end, wmc, partialmodel.clone())
} else {
(best_lb, best_model) = self.meu_h(
evidence,
best_lb,
best_model,
end,
wmc,
partialmodel.clone(),
)
}
}
(best_lb, best_model)
}
}
}

/// maximum expected utility calc
/// maximum expected utility calc, scaled for evidence.
/// introduced in Section 5 of the daPPL paper
pub fn meu(
&self,
evidence: BddPtr,
decision_vars: &[VarLabel],
num_vars: usize,
wmc: &WmcParams<ExpectedUtility>,
) -> (ExpectedUtility, PartialModel) {
// Initialize all the decision variables to be true, partially instantianted resp. to this
// Initialize all the decision variables to be true
let all_true: Vec<Literal> = decision_vars
.iter()
.map(|x| Literal::new(*x, true))
.collect();
let cur_assgn = PartialModel::from_litvec(&all_true, num_vars);
// Calculate bound wrt the partial instantiation.
let lower_bound = self.eu_ub(&cur_assgn, &BitSet::new(), wmc);
let lower_bound = self.eu_ub(&cur_assgn, &BitSet::new(), wmc)
/ evidence.bb_lb(&cur_assgn, &BitSet::new(), wmc);
self.meu_h(
evidence,
lower_bound,
cur_assgn,
decision_vars,
Expand Down Expand Up @@ -784,6 +801,56 @@ impl<'a> BddPtr<'a> {
partial_join_acc * v
}

/// lower-bounding the expected utility, for meu_h
fn bb_lb<T: LatticeWithChoose>(
&self,
partial_join_assgn: &PartialModel,
join_vars: &BitSet,
wmc: &WmcParams<T>,
) -> T
where
T: 'static,
{
let mut partial_join_acc = T::one();
for lit in partial_join_assgn.assignment_iter() {
let (l, h) = wmc.var_weight(lit.label());
if lit.polarity() {
partial_join_acc = partial_join_acc * (*h);
} else {
partial_join_acc = partial_join_acc * (*l);
}
}
// top-down LB calculation via bdd_fold
let v = self.bdd_fold(
&|varlabel, low: T, high: T| {
// get True and False weights for node VarLabel
let (w_l, w_h) = wmc.var_weight(varlabel);
// Check if our partial model has already assigned the node.
match partial_join_assgn.get(varlabel) {
// If not...
None => {
// If it's a meet variable, (w_l * low) n (w_h * high)
if join_vars.contains(varlabel.value_usize()) {
let lhs = *w_l * low;
let rhs = *w_h * high;
MeetSemilattice::meet(&lhs, &rhs)
// Otherwise it is a sum variables, so
} else {
(*w_l * low) + (*w_h * high)
}
}
// If our node has already been assigned, then we
// reached a base case. We return the accumulated value.
Some(true) => high,
Some(false) => low,
}
},
wmc.zero,
wmc.one,
);
partial_join_acc * v
}

fn bb_h<T: BBSemiring>(
&self,
cur_lb: T,
Expand Down
1 change: 0 additions & 1 deletion src/util/btree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ where
BTree::Leaf(ref l) => {
if f(l) {
return Some(idx);
} else {
}
}
}
Expand Down
16 changes: 15 additions & 1 deletion src/util/semirings/expectation.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// Expected Utility Semiring.

use super::semiring_traits::*;
use std::{cmp::Ordering, fmt::Display, ops};

Expand Down Expand Up @@ -96,4 +95,19 @@ impl MeetSemilattice for ExpectedUtility {

impl Lattice for ExpectedUtility {}

impl LatticeWithChoose for ExpectedUtility {}

impl EdgeboundingRing for ExpectedUtility {}

impl ops::Div<ExpectedUtility> for ExpectedUtility {
type Output = ExpectedUtility;

fn div(self, rhs: ExpectedUtility) -> Self::Output {
let y = rhs.0;
if y != 0.0 {
ExpectedUtility(self.0 / y, self.1 / y)
} else {
ExpectedUtility::zero()
}
}
}
2 changes: 2 additions & 0 deletions src/util/semirings/semiring_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@ pub trait MeetSemilattice: PartialOrd {

pub trait Lattice: JoinSemilattice + MeetSemilattice {}

pub trait LatticeWithChoose: BBSemiring + MeetSemilattice {}

pub trait EdgeboundingRing: Lattice + BBRing {}
6 changes: 2 additions & 4 deletions tests/network_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,10 @@ fn gen() {
let wmc = WmcParams::new(eu_map);

let now = Instant::now();
let (meu_num, pm) = end.meu(&vars, builder.num_vars(), &wmc);
let (meu_dec, _) = network_fail.meu(&vars, builder.num_vars(), &wmc);
let (meu_num, pm) = end.meu(network_fail, &vars, builder.num_vars(), &wmc);
println!(
"Regular MEU: {} \nPM : {:?}",
meu_num.1 / meu_dec.0,
pm.true_assignments
meu_num.1, pm.true_assignments
);
let elapsed = now.elapsed();
println!("Elapsed: {:.2?}", elapsed);
Expand Down
69 changes: 35 additions & 34 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,9 +577,10 @@ mod test_bdd_builder {
// set up wmc, run meu
let vars = decisions.clone();
let wmc = WmcParams::new(weight_map);
let (meu , _meu_assgn) = cnf.meu(builder.true_ptr(), &vars, builder.num_vars(), &wmc);
let (meu_bb, _meu_assgn_bb) = cnf.bb(&vars, builder.num_vars(), &wmc);

let (meu , meu_assgn) = cnf.meu(&vars, builder.num_vars(), &wmc);
let (meu_bb, meu_assgn_bb) = cnf.bb(&vars, builder.num_vars(), &wmc);
println!("meu = {}, bb = {}\n", meu, meu_bb);

// brute-force meu
let assignments = vec![(true, true, true), (true, true, false), (true, false, true), (true, false, false),
Expand Down Expand Up @@ -610,40 +611,40 @@ mod test_bdd_builder {
// the below tests (specifically, the bool pm_check)
// check that the partial models evaluate to the correct meu.
// these pms can be different b/c of symmetries/dead literals in the CNF.
let mut pm_check = true;
let extract = |ob : Option<bool>| -> bool {
match ob {
Some(b) => b,
None => panic!("none encountered")
}
};
let v : Vec<bool> = (0..3).map(|x| extract(meu_assgn.get(decisions[x]))).collect();
let w : Vec<bool> = (0..3).map(|x| extract(meu_assgn_bb.get(decisions[x]))).collect();
// if v != w {
// println!("{:?},{:?}",v,w);
// let mut pm_check = true;
// let extract = |ob : Option<bool>| -> bool {
// match ob {
// Some(b) => b,
// None => panic!("none encountered")
// }
// };
// let v : Vec<bool> = (0..3).map(|x| extract(meu_assgn.get(decisions[x]))).collect();
// let w : Vec<bool> = (0..3).map(|x| extract(meu_assgn_bb.get(decisions[x]))).collect();
// // if v != w {
// // println!("{:?},{:?}",v,w);
// // }
// let v0 = builder.var(decisions[0], v[0]);
// let v1 = builder.var(decisions[1], v[1]);
// let v2 = builder.var(decisions[2], v[2]);
// let mut conj = builder.and(v0, v1);
// conj = builder.and(conj, v2);
// conj = builder.and(conj, cnf);
// let poss_max = conj.unsmoothed_wmc(&wmc);
// if f64::abs(poss_max.1 - max) > 0.0001 {
// pm_check = false;
// }
// let w0 = builder.var(decisions[0], w[0]);
// let w1 = builder.var(decisions[1], w[1]);
// let w2 = builder.var(decisions[2], w[2]);
// let mut conj2 = builder.and(w0, w1);
// conj2 = builder.and(conj2, w2);
// builder.and(conj2, cnf);
// let poss_max2 = conj.unsmoothed_wmc(&wmc);
// if f64::abs(poss_max2.1 - max) > 0.0001 {
// pm_check = false;
// }
let v0 = builder.var(decisions[0], v[0]);
let v1 = builder.var(decisions[1], v[1]);
let v2 = builder.var(decisions[2], v[2]);
let mut conj = builder.and(v0, v1);
conj = builder.and(conj, v2);
conj = builder.and(conj, cnf);
let poss_max = conj.unsmoothed_wmc(&wmc);
if f64::abs(poss_max.1 - max) > 0.0001 {
pm_check = false;
}
let w0 = builder.var(decisions[0], w[0]);
let w1 = builder.var(decisions[1], w[1]);
let w2 = builder.var(decisions[2], w[2]);
let mut conj2 = builder.and(w0, w1);
conj2 = builder.and(conj2, w2);
builder.and(conj2, cnf);
let poss_max2 = conj.unsmoothed_wmc(&wmc);
if f64::abs(poss_max2.1 - max) > 0.0001 {
pm_check = false;
}

TestResult::from_bool(pr_check1 && pr_check2 && pm_check)
TestResult::from_bool(pr_check1 && pr_check2)
}
}

Expand Down