Skip to content

Commit

Permalink
Add weighted_model_count binary (#161)
Browse files Browse the repository at this point in the history
Goal: unblock Lisa! Ref: #155.

This PR adds a new binary, `weighted_model_count`, that allows users to perform model counts over arbitrary boolean formulas (via BDD compilation). A brief sketch of the functionality:

- inputs:
  - a logical formula (specified with `-f`) in s-expression format; supports primitives (`And`, `Or`, `Not`) and helpers (`Ite`, `Xor`, `Iff`)
  - a `weights.json` (specified with `-w` that specifies `low` and `high` weights for each boolean variable. for now, they must be reals (i.e. compiled with `RealSemiring`); missing weights are inferred to be `0` (and are logged to the console)
  - a `config.json` (specified with `-c`) that:
      - optionally specifies a variable order; if none is provided, a linear one is used (inferring all variables from the formula)
      - optionally specifies a set of partial models to perform the WMC on; if none is provided, a WMC is done on the entire formula
- output (via JSON, path specified with `-o`):
  - the resulting size of the (unsmoothed) BDD
  - smoothed WMCs over each partial model
- optional flags: 
  - `-s`: silence all output
  - `-v`: more verbose output, includes timing statistics 

You can test out the binary using `cargo run` (see below). To compile it for production use, instead run:

```
$ cargo build --bin weighted_model_count --release --features="cli"
```

This should give you a binary `weighted_model_count` in `target/release`.

## usage: single WMC

To use, with:

`formula.sexp` (this formula represents `X XOR Y`):

```
(And (Or (Var X) (Var Y)) (Or (Not (Var X)) (Not (Var Y))))
```

`config.json`:

```json
{
  "order": ["Y", "X"]
}
```

`weights.json`:

```json
{
    "X": {
      "low": 0.3,
      "high": 0.7
    },
    "Y": {
      "low": 0.4,
      "high": 0.6
    }
}
```

Run:

```
$ cargo run --bin weighted_model_count --features="cli" -- -f formula.sexp -c config.json -w weights.json
0.45999999999999996
```

(floating-point rounding error!)

Run with `-v` to get more statistics.

## partial models

optionally, you can WMC over various partial models on one formula. 

`formula.sexp`:

```
(And (Or (Var X) (Var Z)) (Or (Not (Var X)) (Not (Var Y))))
```

`config.json`:

```json
{
  "partials": [
    { "X": true },
    { "X": false, "Y": true }
  ]
}
```

`weights.json`:

```json
{
    "Y": {
      "low": 0.4,
      "high": 0.6
    },
    "Z": {
      "low": 0.2,
      "high": 0.8
    }
}
```

running 

```
cargo run --bin weighted_model_count --features="cli" -- -f formula.sexp -c config.json -w weights.json -o output.json
```

gives:

```json
{
  "bdd_size": 3,
  "results": [
    {
      "partial_model": {
        "X": true
      },
      "wmc": 0.4
    },
    {
      "partial_model": {
        "X": false,
        "Y": true
      },
      "wmc": 0.8
    }
  ]
}
```

Observe that adding an empty partial model (with `{}`) recovers a WMC over the entire input formula.
  • Loading branch information
mattxwang authored Aug 15, 2023
1 parent a7d72c8 commit 53e5bd6
Show file tree
Hide file tree
Showing 3 changed files with 368 additions and 1 deletion.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ name = "bottomup_formula_to_bdd"
path = "bin/bottomup_formula_to_bdd.rs"
required-features = ["cli"]

[[bin]]
name = "weighted_model_count"
path = "bin/weighted_model_count.rs"
required-features = ["cli"]

[[example]]
name = "one_shot_benchmark"
path = "examples/one_shot_benchmark.rs"
Expand Down
362 changes: 362 additions & 0 deletions bin/weighted_model_count.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,362 @@
use std::{
collections::HashMap,
fmt::Debug,
fs::{self, File},
io::Write,
time::Instant,
};

use clap::Parser;
use rsdd::{
builder::{bdd::RobddBuilder, cache::LruIteTable, BottomUpBuilder},
constants::primes,
repr::{BddPtr, DDNNFPtr, LogicalExpr, PartialModel, VarLabel, VarOrder, WmcParams},
serialize::LogicalSExpr,
util::semirings::{FiniteField, RealSemiring, Semiring},
};
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct VariableWeight<T> {
low: T,
high: T,
}

#[derive(Serialize, Deserialize)]
struct Config {
order: Option<Vec<String>>,
partials: Option<Vec<HashMap<String, bool>>>,
}

#[derive(Serialize, Debug)]
struct PartialWmcResult<T: Semiring + Serialize + Debug> {
partial_model: HashMap<String, bool>,
wmc: T,
}

#[derive(Serialize)]
struct PartialWmcOutput<T: Semiring + Serialize + Debug> {
bdd_size: usize,
results: Vec<PartialWmcResult<T>>,
}

impl Config {
fn to_var_order(&self, mapping: &HashMap<&String, usize>) -> Option<VarOrder> {
self.order.as_ref().map(|o| {
VarOrder::new(
&o.iter()
.map(|var| {
VarLabel::new(*mapping.get(var).unwrap_or_else(|| {
panic!("Found unknown variable {} in order configuration", var)
}) as u64)
})
.collect::<Vec<_>>(),
)
})
}
}

#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
/// input logical expression in s expression form
#[clap(short, long, value_parser)]
file: String,

/// output file to write results to
#[clap(short, long, value_parser)]
output: Option<String>,

/// (optional) config file for variable ordering
#[clap(short, long, value_parser)]
config: Option<String>,

/// path to weights JSON file
#[clap(short, long, value_parser)]
weights: Option<String>,

/// show verbose output (including timing information, cache profiling, etc.)
#[clap(short, long, value_parser)]
verbose: bool,

/// silence all output; takes precedence over verbose
#[clap(short, long, value_parser)]
silent: bool,
}

fn generate_partial_assignments(
partials: &[HashMap<String, bool>],
inverse_mapping: &HashMap<usize, &String>,
num_vars: usize,
) -> Vec<PartialModel> {
partials
.iter()
.map(|assignments| {
PartialModel::from_assignments(
&(0..num_vars)
.map(|index| {
if let Some(str) = inverse_mapping.get(&index) {
if let Some(polarity) = assignments.get(*str) {
return Some(*polarity);
}
}
None
})
.collect::<Vec<_>>(),
)
})
.collect()
}

fn serialize_partial_model(
model: &PartialModel,
inverse_mapping: &HashMap<usize, &String>,
) -> HashMap<String, bool> {
let mut h = HashMap::new();

model.true_assignments.iter().for_each(|v| {
h.insert(
(*inverse_mapping.get(&(v.value() as usize)).unwrap()).clone(),
true,
);
});

model.false_assignments.iter().for_each(|v| {
h.insert(
(*inverse_mapping.get(&(v.value() as usize)).unwrap()).clone(),
false,
);
});

h
}

fn single_wmc(
expr: LogicalExpr,
num_vars: usize,
order: VarOrder,
params: WmcParams<RealSemiring>,
verbose: bool,
silent: bool,
) {
let builder = RobddBuilder::<LruIteTable<BddPtr>>::new(order.clone());

let unweighted_params: WmcParams<FiniteField<{ primes::U64_LARGEST }>> =
WmcParams::new(HashMap::from_iter(
(0..num_vars as u64)
.map(|v| (VarLabel::new(v), (FiniteField::one(), FiniteField::one()))),
));

let start = Instant::now();

let bdd = builder.compile_logical_expr(&expr);

let bdd = builder.smooth(bdd, num_vars);

let res = bdd.wmc(&order, &params);

let elapsed = start.elapsed();

if !silent {
println!(
"unweighted model count: {}\nweighted model count: {}",
builder
.smooth(bdd, num_vars)
.wmc(&order, &unweighted_params),
res
);
}

if verbose && !silent {
eprintln!("=== STATS ===");

let stats = builder.stats();
eprintln!("compilation time: {:.4}s", elapsed.as_secs_f64());
eprintln!("recursive calls: {}", stats.num_recursive_calls);
}
}

#[allow(clippy::too_many_arguments)]
fn partial_wmcs(
expr: LogicalExpr,
num_vars: usize,
order: &VarOrder,
params: &WmcParams<RealSemiring>,
partials: &[PartialModel],
inverse_mapping: &HashMap<usize, &String>,
verbose: bool,
silent: bool,
) -> PartialWmcOutput<RealSemiring> {
let builder = RobddBuilder::<LruIteTable<BddPtr>>::new(order.clone());
let mut results = Vec::new();

let start = Instant::now();

let bdd = builder.compile_logical_expr(&expr);

let init_compilation = start.elapsed();

for model in partials {
let conditioned = builder.condition_model(bdd, model);
let wmc = builder.smooth(conditioned, num_vars).wmc(order, params);

let res = PartialWmcResult {
partial_model: serialize_partial_model(model, inverse_mapping),
wmc,
};

if !silent {
println!("{:?}", res);
}

results.push(res);
}

let elapsed = start.elapsed();

if verbose && !silent {
eprintln!("=== STATS ===");

let stats = builder.stats();
eprintln!(
"initial compilation time: {:.4}s",
init_compilation.as_secs_f64()
);
eprintln!("total compilation time: {:.4}s", elapsed.as_secs_f64());
eprintln!(
"amortized partial model time: {:.4}s",
elapsed.as_secs_f64() / partials.len() as f64
);
eprintln!("recursive calls: {}", stats.num_recursive_calls);
}

PartialWmcOutput {
bdd_size: bdd.count_nodes(),
results,
}
}

fn main() {
let args = Args::parse();

let file = fs::read_to_string(&args.file)
.unwrap_or_else(|e| panic!("Error reading file {}; error: {}", args.file, e));

let config = if let Some(path_to_config) = args.config {
let config = fs::read_to_string(&path_to_config)
.unwrap_or_else(|e| panic!("Error reading file {}; error: {}", path_to_config, e));
serde_json::from_str::<Config>(&config).unwrap_or_else(|e| {
panic!(
"Error parsing {} as JSON config option; error: {}",
config, e
)
})
} else {
Config {
order: None,
partials: None,
}
};

let weights = if let Some(path_to_weights) = args.weights {
let config = fs::read_to_string(&path_to_weights)
.unwrap_or_else(|e| panic!("Error reading file {}; error: {}", path_to_weights, e));
serde_json::from_str::<HashMap<String, VariableWeight<f64>>>(&config)
.unwrap_or_else(|e| panic!("Error parsing {} as JSON weights; error: {}", config, e))
} else {
panic!("no weights file provided");
};

let sexpr = serde_sexpr::from_str::<LogicalSExpr>(&file).unwrap_or_else(|e| {
panic!(
"Error parsing {} as logical s-expression; error: {}",
file, e
)
});
let expr = LogicalExpr::from_sexpr(&sexpr);
let mut num_vars = sexpr.unique_variables().len();

let mut mapping = sexpr.variable_mapping();

let mut var_to_val = HashMap::from_iter(weights.iter().map(|(k, v)| {
let label = mapping.get(k);

match label {
None => {
let n = (
VarLabel::new(num_vars as u64),
(RealSemiring(v.low), RealSemiring(v.high)),
);
mapping.insert(k, num_vars);
num_vars += 1;
n
}
Some(index) => (
VarLabel::new(*index as u64),
(RealSemiring(v.low), RealSemiring(v.high)),
),
}
}));

let inverse_mapping: HashMap<usize, &String> =
HashMap::from_iter(mapping.iter().map(|(k, v)| (*v, *k)));

for index in 0..num_vars as u64 {
let label = VarLabel::new(index);
if var_to_val.get(&label).is_none() {
if !args.silent {
println!(
"Encountered variable {:?} with no associated weights. Assigning default: ({}, {})",
inverse_mapping.get(&(index as usize)),
RealSemiring::zero(),
RealSemiring::zero()
);
}
var_to_val.insert(label, (RealSemiring::zero(), RealSemiring::zero()));
}
}

let params: WmcParams<RealSemiring> = WmcParams::new(var_to_val);

let order = config.to_var_order(&mapping).unwrap_or_else(|| {
if !args.silent {
println!("No ordering in config; defaulting to linear order.")
}
VarOrder::linear_order(num_vars)
});

if let Some(partials) = config.partials {
let partials = generate_partial_assignments(&partials, &inverse_mapping, num_vars);
let output = partial_wmcs(
expr,
num_vars,
&order,
&params,
&partials,
&inverse_mapping,
args.verbose,
args.silent,
);

if let Some(path) = args.output {
let mut file = File::create(path).unwrap();
let r = file.write_all(serde_json::to_string_pretty(&output).unwrap().as_bytes());
assert!(r.is_ok(), "Error writing file");
}
} else {
single_wmc(
expr,
num_vars,
order.clone(),
params,
args.verbose,
args.silent,
);
}

if args.verbose && !args.silent {
eprintln!("=== METADATA ===");
eprintln!("variable mapping: {:?}", sexpr.variable_mapping());
eprintln!("variable ordering: {}", order);
}
}
2 changes: 1 addition & 1 deletion src/util/semirings/realsemiring.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::semiring_traits::*;
use std::{fmt::Display, ops};

#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Serialize, Deserialize)]
pub struct RealSemiring(pub f64);

impl Display for RealSemiring {
Expand Down

0 comments on commit 53e5bd6

Please sign in to comment.