From 015198f349d61912dae5fd0288690b074e0e4bbf Mon Sep 17 00:00:00 2001 From: Minsung Date: Wed, 20 Sep 2023 20:55:51 -0400 Subject: [PATCH 1/4] nothing so far --- src/repr/bdd.rs | 66 ++++++++++++++++++++++++++- src/util/semirings/expectation.rs | 2 + src/util/semirings/semiring_traits.rs | 2 + 3 files changed, 69 insertions(+), 1 deletion(-) diff --git a/src/repr/bdd.rs b/src/repr/bdd.rs index 613a1a2a..163ae439 100644 --- a/src/repr/bdd.rs +++ b/src/repr/bdd.rs @@ -6,7 +6,7 @@ use crate::{ repr::WmcParams, repr::{DDNNFPtr, DDNNF}, repr::{Literal, VarLabel, VarSet}, - util::semirings::ExpectedUtility, + util::semirings::{ExpectedUtility, MeetSemilattice, LatticeWithChoose}, util::semirings::{BBSemiring, FiniteField, JoinSemilattice, RealSemiring}, }; use bit_set::BitSet; @@ -784,6 +784,56 @@ impl<'a> BddPtr<'a> { partial_join_acc * v } + /// lower-bounding the expected utility, for meu_h + fn bb_lb( + &self, + partial_join_assgn: &PartialModel, + join_vars: &BitSet, + wmc: &WmcParams, + ) -> T + where + T: 'static, + { + let mut partial_join_acc = T::one(); + for lit in partial_join_assgn.assignment_iter() { + let (l, h) = wmc.var_weight(lit.label()); + if lit.polarity() { + partial_join_acc = partial_join_acc * (*h); + } else { + partial_join_acc = partial_join_acc * (*l); + } + } + // top-down LB calculation via bdd_fold + let v = self.bdd_fold( + &|varlabel, low: T, high: T| { + // get True and False weights for node VarLabel + let (w_l, w_h) = wmc.var_weight(varlabel); + // Check if our partial model has already assigned the node. + match partial_join_assgn.get(varlabel) { + // If not... + None => { + // If it's a meet variable, (w_l * low) n (w_h * high) + if join_vars.contains(varlabel.value_usize()) { + let lhs = *w_l * low; + let rhs = *w_h * high; + MeetSemilattice::meet(&lhs, &rhs) + // Otherwise it is a sum variables, so + } else { + (*w_l * low) + (*w_h * high) + } + } + // If our node has already been assigned, then we + // reached a base case. We return the accumulated value. + Some(true) => high, + Some(false) => low, + } + }, + wmc.zero, + wmc.one, + ); + partial_join_acc * v + } + fn bb_h( &self, cur_lb: T, @@ -874,6 +924,20 @@ impl<'a> BddPtr<'a> { ) } + /// branch and bound with incremental evidence gain. + pub fn bb_with_evidence( + &self, + evidence : BddPtr, + vars: &[VarLabel], + num_vars: usize, + wmc: &WmcParams, + ) -> (T, PartialModel) + where + T: 'static, + { + + } + /// performs a semantic hash and caches the result on the node pub fn cached_semantic_hash( &self, diff --git a/src/util/semirings/expectation.rs b/src/util/semirings/expectation.rs index 4eff5b29..9229bab8 100644 --- a/src/util/semirings/expectation.rs +++ b/src/util/semirings/expectation.rs @@ -96,4 +96,6 @@ impl MeetSemilattice for ExpectedUtility { impl Lattice for ExpectedUtility {} +impl LatticeWithChoose for ExpectedUtility { } + impl EdgeboundingRing for ExpectedUtility {} diff --git a/src/util/semirings/semiring_traits.rs b/src/util/semirings/semiring_traits.rs index 655da261..aff4785b 100644 --- a/src/util/semirings/semiring_traits.rs +++ b/src/util/semirings/semiring_traits.rs @@ -42,4 +42,6 @@ pub trait MeetSemilattice: PartialOrd { pub trait Lattice: JoinSemilattice + MeetSemilattice {} +pub trait LatticeWithChoose : BBSemiring + MeetSemilattice {} + pub trait EdgeboundingRing: Lattice + BBRing {} From 5c0c15b0ee32090cda61841b198fb032d5d76d3b Mon Sep 17 00:00:00 2001 From: Minsung Date: Thu, 21 Sep 2023 14:36:47 -0400 Subject: [PATCH 2/4] algorithm --- src/repr/bdd.rs | 236 +++++++++++++++--------------- src/util/semirings/expectation.rs | 14 +- tests/network_example.rs | 5 +- 3 files changed, 132 insertions(+), 123 deletions(-) diff --git a/src/repr/bdd.rs b/src/repr/bdd.rs index 163ae439..00e16890 100644 --- a/src/repr/bdd.rs +++ b/src/repr/bdd.rs @@ -6,7 +6,7 @@ use crate::{ repr::WmcParams, repr::{DDNNFPtr, DDNNF}, repr::{Literal, VarLabel, VarSet}, - util::semirings::{ExpectedUtility, MeetSemilattice, LatticeWithChoose}, + util::semirings::{ExpectedUtility, MeetSemilattice, LatticeWithChoose, Lattice}, util::semirings::{BBSemiring, FiniteField, JoinSemilattice, RealSemiring}, }; use bit_set::BitSet; @@ -17,7 +17,7 @@ use std::{ collections::HashMap, hash::{Hash, Hasher}, iter::FromIterator, - ptr, + ptr, thread, }; use BddPtr::*; @@ -655,6 +655,7 @@ impl<'a> BddPtr<'a> { fn meu_h( &self, + evidence : BddPtr, cur_lb: ExpectedUtility, cur_best: PartialModel, decision_vars: &[VarLabel], @@ -666,7 +667,8 @@ impl<'a> BddPtr<'a> { [] => { // Run the eu ub let decision_bitset = BitSet::new(); - let possible_best = self.eu_ub(&cur_assgn, &decision_bitset, wmc); + let possible_best = + self.eu_ub(&cur_assgn, &decision_bitset, wmc) / evidence.bb_lb(&cur_assgn, &decision_bitset, wmc); // If it's a better lb, update. if possible_best.1 > cur_lb.1 { (possible_best, cur_assgn) @@ -687,8 +689,14 @@ impl<'a> BddPtr<'a> { false_model.set(*x, false); // and calculate their respective upper bounds. - let true_ub = self.eu_ub(&true_model, &margvar_bits, wmc); - let false_ub = self.eu_ub(&false_model, &margvar_bits, wmc); + let true_ub_num = self.eu_ub(&true_model, &margvar_bits, wmc); + let false_ub_num = self.eu_ub(&false_model, &margvar_bits, wmc); + + let true_ub_dec = evidence.bb_lb(&true_model, &margvar_bits, wmc); + let false_ub_dec = evidence.bb_lb(&false_model, &margvar_bits, wmc); + + let true_ub = true_ub_num / true_ub_dec; + let false_ub = false_ub_num / false_ub_dec; // branch on the greater upper-bound first let order = if true_ub.1 > false_ub.1 { @@ -700,7 +708,7 @@ impl<'a> BddPtr<'a> { // branch + bound if upper_bound.1 > best_lb.1 { (best_lb, best_model) = - self.meu_h(best_lb, best_model, end, wmc, partialmodel.clone()) + self.meu_h(evidence, best_lb, best_model, end, wmc, partialmodel.clone()) } else { } } @@ -709,22 +717,26 @@ impl<'a> BddPtr<'a> { } } - /// maximum expected utility calc + /// maximum expected utility calc, scaled for evidence. + /// introduced in Section 5 of the daPPL paper pub fn meu( &self, + evidence : BddPtr, decision_vars: &[VarLabel], num_vars: usize, wmc: &WmcParams, ) -> (ExpectedUtility, PartialModel) { - // Initialize all the decision variables to be true, partially instantianted resp. to this + // Initialize all the decision variables to be true let all_true: Vec = decision_vars .iter() .map(|x| Literal::new(*x, true)) .collect(); let cur_assgn = PartialModel::from_litvec(&all_true, num_vars); // Calculate bound wrt the partial instantiation. - let lower_bound = self.eu_ub(&cur_assgn, &BitSet::new(), wmc); + let lower_bound = + self.eu_ub(&cur_assgn, &BitSet::new(), wmc) / evidence.bb_lb(&cur_assgn, &BitSet::new(), wmc); self.meu_h( + evidence, lower_bound, cur_assgn, decision_vars, @@ -785,7 +797,7 @@ impl<'a> BddPtr<'a> { } /// lower-bounding the expected utility, for meu_h - fn bb_lb( + fn bb_lb( &self, partial_join_assgn: &PartialModel, join_vars: &BitSet, @@ -828,115 +840,101 @@ impl<'a> BddPtr<'a> { Some(false) => low, } }, - wmc.zero, - wmc.one, - ); - partial_join_acc * v - } - - fn bb_h( - &self, - cur_lb: T, - cur_best: PartialModel, - join_vars: &[VarLabel], - wmc: &WmcParams, - cur_assgn: PartialModel, - ) -> (T, PartialModel) - where - T: 'static, - { - match join_vars { - // If all join variables are assigned, - [] => { - // Run the bb_ub - let empty_join_vars = BitSet::new(); - let possible_best = self.bb_ub(&cur_assgn, &empty_join_vars, wmc); - // If it's a better lb, update. - let best = BBSemiring::choose(&cur_lb, &possible_best); - if cur_lb == best { - (cur_lb, cur_best) - } else { - (possible_best, cur_assgn) - } - } - // If there exists an unassigned decision variable, - [x, end @ ..] => { - let mut best_model = cur_best.clone(); - let mut best_lb = cur_lb; - let join_vars_bits = BitSet::from_iter(end.iter().map(|x| x.value_usize())); - // Consider the assignment of it to true... - let mut true_model = cur_assgn.clone(); - true_model.set(*x, true); - // ... and false... - let mut false_model = cur_assgn; - false_model.set(*x, false); - - // and calculate their respective upper bounds. - let true_ub = self.bb_ub(&true_model, &join_vars_bits, wmc); - let false_ub = self.bb_ub(&false_model, &join_vars_bits, wmc); - - // arbitrarily order the T/F bounds - let order = if true_ub == BBSemiring::choose(&true_ub, &false_ub) { - [(true_ub, true_model), (false_ub, false_model)] - } else { - [(false_ub, false_model), (true_ub, true_model)] - }; - // the actual branching and bounding - for (upper_bound, partialmodel) in order { - // if upper_bound == BBAlgebra::choose(&upper_bound, &best_lb) { - if !PartialOrd::le(&upper_bound, &cur_lb) { - let (rec, rec_pm) = - self.bb_h(best_lb, best_model.clone(), end, wmc, partialmodel.clone()); - let new_lb = BBSemiring::choose(&cur_lb, &rec); - if new_lb == rec { - (best_lb, best_model) = (rec, rec_pm); - } else { - (best_lb, best_model) = (cur_lb, cur_best.clone()); - } - } - } - (best_lb, best_model) - } - } - } - - /// branch and bound generic over T a BBAlgebra. - pub fn bb( - &self, - join_vars: &[VarLabel], - num_vars: usize, - wmc: &WmcParams, - ) -> (T, PartialModel) - where - T: 'static, - { - // Initialize all the decision variables to be true, partially instantianted resp. to this - let all_true: Vec = join_vars.iter().map(|x| Literal::new(*x, true)).collect(); - let cur_assgn = PartialModel::from_litvec(&all_true, num_vars); - // Calculate bound wrt the partial instantiation. - let lower_bound = self.bb_ub(&cur_assgn, &BitSet::new(), wmc); - self.bb_h( - lower_bound, - cur_assgn, - join_vars, - wmc, - PartialModel::from_litvec(&[], num_vars), - ) - } - - /// branch and bound with incremental evidence gain. - pub fn bb_with_evidence( - &self, - evidence : BddPtr, - vars: &[VarLabel], - num_vars: usize, - wmc: &WmcParams, - ) -> (T, PartialModel) - where - T: 'static, - { - - } + wmc.zero, + wmc.one, + ); + partial_join_acc * v + } + + fn bb_h( + &self, + cur_lb: T, + cur_best: PartialModel, + join_vars: &[VarLabel], + wmc: &WmcParams, + cur_assgn: PartialModel, + ) -> (T, PartialModel) + where + T: 'static, + { + match join_vars { + // If all join variables are assigned, + [] => { + // Run the bb_ub + let empty_join_vars = BitSet::new(); + let possible_best = self.bb_ub(&cur_assgn, &empty_join_vars, wmc); + // If it's a better lb, update. + let best = BBSemiring::choose(&cur_lb, &possible_best); + if cur_lb == best { + (cur_lb, cur_best) + } else { + (possible_best, cur_assgn) + } + } + // If there exists an unassigned decision variable, + [x, end @ ..] => { + let mut best_model = cur_best.clone(); + let mut best_lb = cur_lb; + let join_vars_bits = BitSet::from_iter(end.iter().map(|x| x.value_usize())); + // Consider the assignment of it to true... + let mut true_model = cur_assgn.clone(); + true_model.set(*x, true); + // ... and false... + let mut false_model = cur_assgn; + false_model.set(*x, false); + + // and calculate their respective upper bounds. + let true_ub = self.bb_ub(&true_model, &join_vars_bits, wmc); + let false_ub = self.bb_ub(&false_model, &join_vars_bits, wmc); + + // arbitrarily order the T/F bounds + let order = if true_ub == BBSemiring::choose(&true_ub, &false_ub) { + [(true_ub, true_model), (false_ub, false_model)] + } else { + [(false_ub, false_model), (true_ub, true_model)] + }; + // the actual branching and bounding + for (upper_bound, partialmodel) in order { + // if upper_bound == BBAlgebra::choose(&upper_bound, &best_lb) { + if !PartialOrd::le(&upper_bound, &cur_lb) { + let (rec, rec_pm) = + self.bb_h(best_lb, best_model.clone(), end, wmc, partialmodel.clone()); + let new_lb = BBSemiring::choose(&cur_lb, &rec); + if new_lb == rec { + (best_lb, best_model) = (rec, rec_pm); + } else { + (best_lb, best_model) = (cur_lb, cur_best.clone()); + } + } + } + (best_lb, best_model) + } + } + } + + /// branch and bound generic over T a BBAlgebra. + pub fn bb( + &self, + join_vars: &[VarLabel], + num_vars: usize, + wmc: &WmcParams, + ) -> (T, PartialModel) + where + T: 'static, + { + // Initialize all the decision variables to be true, partially instantianted resp. to this + let all_true: Vec = join_vars.iter().map(|x| Literal::new(*x, true)).collect(); + let cur_assgn = PartialModel::from_litvec(&all_true, num_vars); + // Calculate bound wrt the partial instantiation. + let lower_bound = self.bb_ub(&cur_assgn, &BitSet::new(), wmc); + self.bb_h( + lower_bound, + cur_assgn, + join_vars, + wmc, + PartialModel::from_litvec(&[], num_vars), + ) + } /// performs a semantic hash and caches the result on the node pub fn cached_semantic_hash( diff --git a/src/util/semirings/expectation.rs b/src/util/semirings/expectation.rs index 9229bab8..a5e208c8 100644 --- a/src/util/semirings/expectation.rs +++ b/src/util/semirings/expectation.rs @@ -1,5 +1,4 @@ // Expected Utility Semiring. - use super::semiring_traits::*; use std::{cmp::Ordering, fmt::Display, ops}; @@ -99,3 +98,16 @@ impl Lattice for ExpectedUtility {} impl LatticeWithChoose for ExpectedUtility { } impl EdgeboundingRing for ExpectedUtility {} + +impl ops::Div for ExpectedUtility { + type Output = ExpectedUtility; + + fn div(self, rhs: ExpectedUtility) -> Self::Output { + let y = rhs.0; + if y == 0.0 { + ExpectedUtility(self.0 / y, self.1 / y) + } else { + ExpectedUtility::zero() + } + } +} diff --git a/tests/network_example.rs b/tests/network_example.rs index fc17f834..0f98f33f 100644 --- a/tests/network_example.rs +++ b/tests/network_example.rs @@ -186,11 +186,10 @@ fn gen() { let wmc = WmcParams::new(eu_map); let now = Instant::now(); - let (meu_num, pm) = end.meu(&vars, builder.num_vars(), &wmc); - let (meu_dec, _) = network_fail.meu(&vars, builder.num_vars(), &wmc); + let (meu_num, pm) = end.meu(network_fail, &vars, builder.num_vars(), &wmc); println!( "Regular MEU: {} \nPM : {:?}", - meu_num.1 / meu_dec.0, + meu_num.1, pm.true_assignments ); let elapsed = now.elapsed(); From e57e89342464b1c1a50dc7bfd94e089df38dd294 Mon Sep 17 00:00:00 2001 From: Minsung Date: Thu, 21 Sep 2023 15:36:07 -0400 Subject: [PATCH 3/4] whoops! --- src/util/semirings/expectation.rs | 2 +- tests/test.rs | 69 ++++++++++++++++--------------- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/src/util/semirings/expectation.rs b/src/util/semirings/expectation.rs index a5e208c8..d5d66ccb 100644 --- a/src/util/semirings/expectation.rs +++ b/src/util/semirings/expectation.rs @@ -104,7 +104,7 @@ impl ops::Div for ExpectedUtility { fn div(self, rhs: ExpectedUtility) -> Self::Output { let y = rhs.0; - if y == 0.0 { + if y != 0.0 { ExpectedUtility(self.0 / y, self.1 / y) } else { ExpectedUtility::zero() diff --git a/tests/test.rs b/tests/test.rs index 194d9b8f..1ea9f9d8 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -577,9 +577,10 @@ mod test_bdd_builder { // set up wmc, run meu let vars = decisions.clone(); let wmc = WmcParams::new(weight_map); + let (meu , _meu_assgn) = cnf.meu(builder.true_ptr(), &vars, builder.num_vars(), &wmc); + let (meu_bb, _meu_assgn_bb) = cnf.bb(&vars, builder.num_vars(), &wmc); - let (meu , meu_assgn) = cnf.meu(&vars, builder.num_vars(), &wmc); - let (meu_bb, meu_assgn_bb) = cnf.bb(&vars, builder.num_vars(), &wmc); + println!("meu = {}, bb = {}\n", meu, meu_bb); // brute-force meu let assignments = vec![(true, true, true), (true, true, false), (true, false, true), (true, false, false), @@ -610,40 +611,40 @@ mod test_bdd_builder { // the below tests (specifically, the bool pm_check) // check that the partial models evaluate to the correct meu. // these pms can be different b/c of symmetries/dead literals in the CNF. - let mut pm_check = true; - let extract = |ob : Option| -> bool { - match ob { - Some(b) => b, - None => panic!("none encountered") - } - }; - let v : Vec = (0..3).map(|x| extract(meu_assgn.get(decisions[x]))).collect(); - let w : Vec = (0..3).map(|x| extract(meu_assgn_bb.get(decisions[x]))).collect(); - // if v != w { - // println!("{:?},{:?}",v,w); + // let mut pm_check = true; + // let extract = |ob : Option| -> bool { + // match ob { + // Some(b) => b, + // None => panic!("none encountered") + // } + // }; + // let v : Vec = (0..3).map(|x| extract(meu_assgn.get(decisions[x]))).collect(); + // let w : Vec = (0..3).map(|x| extract(meu_assgn_bb.get(decisions[x]))).collect(); + // // if v != w { + // // println!("{:?},{:?}",v,w); + // // } + // let v0 = builder.var(decisions[0], v[0]); + // let v1 = builder.var(decisions[1], v[1]); + // let v2 = builder.var(decisions[2], v[2]); + // let mut conj = builder.and(v0, v1); + // conj = builder.and(conj, v2); + // conj = builder.and(conj, cnf); + // let poss_max = conj.unsmoothed_wmc(&wmc); + // if f64::abs(poss_max.1 - max) > 0.0001 { + // pm_check = false; + // } + // let w0 = builder.var(decisions[0], w[0]); + // let w1 = builder.var(decisions[1], w[1]); + // let w2 = builder.var(decisions[2], w[2]); + // let mut conj2 = builder.and(w0, w1); + // conj2 = builder.and(conj2, w2); + // builder.and(conj2, cnf); + // let poss_max2 = conj.unsmoothed_wmc(&wmc); + // if f64::abs(poss_max2.1 - max) > 0.0001 { + // pm_check = false; // } - let v0 = builder.var(decisions[0], v[0]); - let v1 = builder.var(decisions[1], v[1]); - let v2 = builder.var(decisions[2], v[2]); - let mut conj = builder.and(v0, v1); - conj = builder.and(conj, v2); - conj = builder.and(conj, cnf); - let poss_max = conj.unsmoothed_wmc(&wmc); - if f64::abs(poss_max.1 - max) > 0.0001 { - pm_check = false; - } - let w0 = builder.var(decisions[0], w[0]); - let w1 = builder.var(decisions[1], w[1]); - let w2 = builder.var(decisions[2], w[2]); - let mut conj2 = builder.and(w0, w1); - conj2 = builder.and(conj2, w2); - builder.and(conj2, cnf); - let poss_max2 = conj.unsmoothed_wmc(&wmc); - if f64::abs(poss_max2.1 - max) > 0.0001 { - pm_check = false; - } - TestResult::from_bool(pr_check1 && pr_check2 && pm_check) + TestResult::from_bool(pr_check1 && pr_check2) } } From 70f8f9cf41e8d23f2242bc3d9d8cd30289cff26e Mon Sep 17 00:00:00 2001 From: Minsung Date: Thu, 21 Sep 2023 15:37:03 -0400 Subject: [PATCH 4/4] run linty lint --- src/builder/sdd/builder.rs | 2 +- src/repr/bdd.rs | 221 +++++++++++++------------- src/util/btree.rs | 1 - src/util/semirings/expectation.rs | 20 +-- src/util/semirings/semiring_traits.rs | 2 +- tests/network_example.rs | 3 +- 6 files changed, 126 insertions(+), 123 deletions(-) diff --git a/src/builder/sdd/builder.rs b/src/builder/sdd/builder.rs index f5da1fd1..5290f0b0 100644 --- a/src/builder/sdd/builder.rs +++ b/src/builder/sdd/builder.rs @@ -198,7 +198,7 @@ pub trait SddBuilder<'a>: BottomUpBuilder<'a, SddPtr<'a>> { // return self.unique_or(v, r.vtree()); // TODO optimize this for special cases - let b = vec![ + let b = [ SddAnd::new(d, SddPtr::true_ptr()), SddAnd::new(d.neg(), SddPtr::false_ptr()), ]; diff --git a/src/repr/bdd.rs b/src/repr/bdd.rs index 00e16890..4d2a0253 100644 --- a/src/repr/bdd.rs +++ b/src/repr/bdd.rs @@ -6,8 +6,8 @@ use crate::{ repr::WmcParams, repr::{DDNNFPtr, DDNNF}, repr::{Literal, VarLabel, VarSet}, - util::semirings::{ExpectedUtility, MeetSemilattice, LatticeWithChoose, Lattice}, util::semirings::{BBSemiring, FiniteField, JoinSemilattice, RealSemiring}, + util::semirings::{ExpectedUtility, LatticeWithChoose, MeetSemilattice}, }; use bit_set::BitSet; use core::fmt::Debug; @@ -17,7 +17,7 @@ use std::{ collections::HashMap, hash::{Hash, Hasher}, iter::FromIterator, - ptr, thread, + ptr, }; use BddPtr::*; @@ -655,7 +655,7 @@ impl<'a> BddPtr<'a> { fn meu_h( &self, - evidence : BddPtr, + evidence: BddPtr, cur_lb: ExpectedUtility, cur_best: PartialModel, decision_vars: &[VarLabel], @@ -667,8 +667,8 @@ impl<'a> BddPtr<'a> { [] => { // Run the eu ub let decision_bitset = BitSet::new(); - let possible_best = - self.eu_ub(&cur_assgn, &decision_bitset, wmc) / evidence.bb_lb(&cur_assgn, &decision_bitset, wmc); + let possible_best = self.eu_ub(&cur_assgn, &decision_bitset, wmc) + / evidence.bb_lb(&cur_assgn, &decision_bitset, wmc); // If it's a better lb, update. if possible_best.1 > cur_lb.1 { (possible_best, cur_assgn) @@ -707,9 +707,14 @@ impl<'a> BddPtr<'a> { for (upper_bound, partialmodel) in order { // branch + bound if upper_bound.1 > best_lb.1 { - (best_lb, best_model) = - self.meu_h(evidence, best_lb, best_model, end, wmc, partialmodel.clone()) - } else { + (best_lb, best_model) = self.meu_h( + evidence, + best_lb, + best_model, + end, + wmc, + partialmodel.clone(), + ) } } (best_lb, best_model) @@ -717,11 +722,11 @@ impl<'a> BddPtr<'a> { } } - /// maximum expected utility calc, scaled for evidence. + /// maximum expected utility calc, scaled for evidence. /// introduced in Section 5 of the daPPL paper pub fn meu( &self, - evidence : BddPtr, + evidence: BddPtr, decision_vars: &[VarLabel], num_vars: usize, wmc: &WmcParams, @@ -733,8 +738,8 @@ impl<'a> BddPtr<'a> { .collect(); let cur_assgn = PartialModel::from_litvec(&all_true, num_vars); // Calculate bound wrt the partial instantiation. - let lower_bound = - self.eu_ub(&cur_assgn, &BitSet::new(), wmc) / evidence.bb_lb(&cur_assgn, &BitSet::new(), wmc); + let lower_bound = self.eu_ub(&cur_assgn, &BitSet::new(), wmc) + / evidence.bb_lb(&cur_assgn, &BitSet::new(), wmc); self.meu_h( evidence, lower_bound, @@ -840,102 +845,102 @@ impl<'a> BddPtr<'a> { Some(false) => low, } }, - wmc.zero, - wmc.one, - ); - partial_join_acc * v - } - - fn bb_h( - &self, - cur_lb: T, - cur_best: PartialModel, - join_vars: &[VarLabel], - wmc: &WmcParams, - cur_assgn: PartialModel, - ) -> (T, PartialModel) - where - T: 'static, - { - match join_vars { - // If all join variables are assigned, - [] => { - // Run the bb_ub - let empty_join_vars = BitSet::new(); - let possible_best = self.bb_ub(&cur_assgn, &empty_join_vars, wmc); - // If it's a better lb, update. - let best = BBSemiring::choose(&cur_lb, &possible_best); - if cur_lb == best { - (cur_lb, cur_best) - } else { - (possible_best, cur_assgn) - } - } - // If there exists an unassigned decision variable, - [x, end @ ..] => { - let mut best_model = cur_best.clone(); - let mut best_lb = cur_lb; - let join_vars_bits = BitSet::from_iter(end.iter().map(|x| x.value_usize())); - // Consider the assignment of it to true... - let mut true_model = cur_assgn.clone(); - true_model.set(*x, true); - // ... and false... - let mut false_model = cur_assgn; - false_model.set(*x, false); - - // and calculate their respective upper bounds. - let true_ub = self.bb_ub(&true_model, &join_vars_bits, wmc); - let false_ub = self.bb_ub(&false_model, &join_vars_bits, wmc); - - // arbitrarily order the T/F bounds - let order = if true_ub == BBSemiring::choose(&true_ub, &false_ub) { - [(true_ub, true_model), (false_ub, false_model)] - } else { - [(false_ub, false_model), (true_ub, true_model)] - }; - // the actual branching and bounding - for (upper_bound, partialmodel) in order { - // if upper_bound == BBAlgebra::choose(&upper_bound, &best_lb) { - if !PartialOrd::le(&upper_bound, &cur_lb) { - let (rec, rec_pm) = - self.bb_h(best_lb, best_model.clone(), end, wmc, partialmodel.clone()); - let new_lb = BBSemiring::choose(&cur_lb, &rec); - if new_lb == rec { - (best_lb, best_model) = (rec, rec_pm); - } else { - (best_lb, best_model) = (cur_lb, cur_best.clone()); - } - } - } - (best_lb, best_model) - } - } - } - - /// branch and bound generic over T a BBAlgebra. - pub fn bb( - &self, - join_vars: &[VarLabel], - num_vars: usize, - wmc: &WmcParams, - ) -> (T, PartialModel) - where - T: 'static, - { - // Initialize all the decision variables to be true, partially instantianted resp. to this - let all_true: Vec = join_vars.iter().map(|x| Literal::new(*x, true)).collect(); - let cur_assgn = PartialModel::from_litvec(&all_true, num_vars); - // Calculate bound wrt the partial instantiation. - let lower_bound = self.bb_ub(&cur_assgn, &BitSet::new(), wmc); - self.bb_h( - lower_bound, - cur_assgn, - join_vars, - wmc, - PartialModel::from_litvec(&[], num_vars), - ) - } - + wmc.zero, + wmc.one, + ); + partial_join_acc * v + } + + fn bb_h( + &self, + cur_lb: T, + cur_best: PartialModel, + join_vars: &[VarLabel], + wmc: &WmcParams, + cur_assgn: PartialModel, + ) -> (T, PartialModel) + where + T: 'static, + { + match join_vars { + // If all join variables are assigned, + [] => { + // Run the bb_ub + let empty_join_vars = BitSet::new(); + let possible_best = self.bb_ub(&cur_assgn, &empty_join_vars, wmc); + // If it's a better lb, update. + let best = BBSemiring::choose(&cur_lb, &possible_best); + if cur_lb == best { + (cur_lb, cur_best) + } else { + (possible_best, cur_assgn) + } + } + // If there exists an unassigned decision variable, + [x, end @ ..] => { + let mut best_model = cur_best.clone(); + let mut best_lb = cur_lb; + let join_vars_bits = BitSet::from_iter(end.iter().map(|x| x.value_usize())); + // Consider the assignment of it to true... + let mut true_model = cur_assgn.clone(); + true_model.set(*x, true); + // ... and false... + let mut false_model = cur_assgn; + false_model.set(*x, false); + + // and calculate their respective upper bounds. + let true_ub = self.bb_ub(&true_model, &join_vars_bits, wmc); + let false_ub = self.bb_ub(&false_model, &join_vars_bits, wmc); + + // arbitrarily order the T/F bounds + let order = if true_ub == BBSemiring::choose(&true_ub, &false_ub) { + [(true_ub, true_model), (false_ub, false_model)] + } else { + [(false_ub, false_model), (true_ub, true_model)] + }; + // the actual branching and bounding + for (upper_bound, partialmodel) in order { + // if upper_bound == BBAlgebra::choose(&upper_bound, &best_lb) { + if !PartialOrd::le(&upper_bound, &cur_lb) { + let (rec, rec_pm) = + self.bb_h(best_lb, best_model.clone(), end, wmc, partialmodel.clone()); + let new_lb = BBSemiring::choose(&cur_lb, &rec); + if new_lb == rec { + (best_lb, best_model) = (rec, rec_pm); + } else { + (best_lb, best_model) = (cur_lb, cur_best.clone()); + } + } + } + (best_lb, best_model) + } + } + } + + /// branch and bound generic over T a BBAlgebra. + pub fn bb( + &self, + join_vars: &[VarLabel], + num_vars: usize, + wmc: &WmcParams, + ) -> (T, PartialModel) + where + T: 'static, + { + // Initialize all the decision variables to be true, partially instantianted resp. to this + let all_true: Vec = join_vars.iter().map(|x| Literal::new(*x, true)).collect(); + let cur_assgn = PartialModel::from_litvec(&all_true, num_vars); + // Calculate bound wrt the partial instantiation. + let lower_bound = self.bb_ub(&cur_assgn, &BitSet::new(), wmc); + self.bb_h( + lower_bound, + cur_assgn, + join_vars, + wmc, + PartialModel::from_litvec(&[], num_vars), + ) + } + /// performs a semantic hash and caches the result on the node pub fn cached_semantic_hash( &self, diff --git a/src/util/btree.rs b/src/util/btree.rs index 21fa7c15..a1a8d38f 100644 --- a/src/util/btree.rs +++ b/src/util/btree.rs @@ -121,7 +121,6 @@ where BTree::Leaf(ref l) => { if f(l) { return Some(idx); - } else { } } } diff --git a/src/util/semirings/expectation.rs b/src/util/semirings/expectation.rs index d5d66ccb..27a3017d 100644 --- a/src/util/semirings/expectation.rs +++ b/src/util/semirings/expectation.rs @@ -95,19 +95,19 @@ impl MeetSemilattice for ExpectedUtility { impl Lattice for ExpectedUtility {} -impl LatticeWithChoose for ExpectedUtility { } +impl LatticeWithChoose for ExpectedUtility {} impl EdgeboundingRing for ExpectedUtility {} impl ops::Div for ExpectedUtility { - type Output = ExpectedUtility; - - fn div(self, rhs: ExpectedUtility) -> Self::Output { - let y = rhs.0; - if y != 0.0 { - ExpectedUtility(self.0 / y, self.1 / y) - } else { - ExpectedUtility::zero() + type Output = ExpectedUtility; + + fn div(self, rhs: ExpectedUtility) -> Self::Output { + let y = rhs.0; + if y != 0.0 { + ExpectedUtility(self.0 / y, self.1 / y) + } else { + ExpectedUtility::zero() + } } - } } diff --git a/src/util/semirings/semiring_traits.rs b/src/util/semirings/semiring_traits.rs index aff4785b..7a27a48d 100644 --- a/src/util/semirings/semiring_traits.rs +++ b/src/util/semirings/semiring_traits.rs @@ -42,6 +42,6 @@ pub trait MeetSemilattice: PartialOrd { pub trait Lattice: JoinSemilattice + MeetSemilattice {} -pub trait LatticeWithChoose : BBSemiring + MeetSemilattice {} +pub trait LatticeWithChoose: BBSemiring + MeetSemilattice {} pub trait EdgeboundingRing: Lattice + BBRing {} diff --git a/tests/network_example.rs b/tests/network_example.rs index 0f98f33f..7321978e 100644 --- a/tests/network_example.rs +++ b/tests/network_example.rs @@ -189,8 +189,7 @@ fn gen() { let (meu_num, pm) = end.meu(network_fail, &vars, builder.num_vars(), &wmc); println!( "Regular MEU: {} \nPM : {:?}", - meu_num.1, - pm.true_assignments + meu_num.1, pm.true_assignments ); let elapsed = now.elapsed(); println!("Elapsed: {:.2?}", elapsed);