Skip to content

Commit

Permalink
Fix the connection from FlatZinc to the linear propagator
Browse files Browse the repository at this point in the history
  • Loading branch information
Dekker1 committed May 3, 2024
1 parent 729b660 commit dc776fd
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 28 deletions.
37 changes: 32 additions & 5 deletions crates/huub/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod reformulate;
use std::ops::AddAssign;

use flatzinc_serde::RangeList;
use itertools::Itertools;
use pindakaas::{
solver::{PropagatorAccess, Solver as SolverTrait},
ClauseDatabase, Cnf, Lit as RawLit, Valuation as SatValuation, Var as RawVar,
Expand All @@ -18,13 +19,13 @@ use self::{
};
use crate::{
model::{int::IntVarDef, reformulate::ReifContext},
propagator::{all_different::AllDifferentValue, linear::LinearLE},
propagator::{all_different::AllDifferentValue, int_lin_le::LinearLE},
solver::{
engine::int_var::IntVar as SlvIntVar,
view::{BoolViewInner, SolverView},
SatSolver,
},
Solver,
IntVal, Solver,
};

#[derive(Debug, Default)]
Expand Down Expand Up @@ -96,7 +97,8 @@ impl ClauseDatabase for Model {
pub enum Constraint {
Clause(Vec<BoolExpr>),
AllDifferent(Vec<IntExpr>),
LinearLE(Vec<i64>, Vec<IntExpr>, i64),
IntLinLessEq(Vec<IntVal>, Vec<IntExpr>, IntVal),
IntLinEq(Vec<IntVal>, Vec<IntExpr>, IntVal),
}

impl Constraint {
Expand Down Expand Up @@ -130,12 +132,37 @@ impl Constraint {
.collect();
slv.add_propagator(AllDifferentValue::new(vars));
}
Constraint::LinearLE(coeffs, vars, c) => {
Constraint::IntLinLessEq(coeffs, vars, c) => {
let vars: Vec<_> = vars
.iter()
.zip_eq(coeffs.iter())
.map(|(v, &c)| {
v.to_arg(
if c >= 0 {
ReifContext::Pos
} else {
ReifContext::Neg
},
slv,
map,
)
})
.collect();
slv.add_propagator(LinearLE::new(coeffs, vars, *c));
}
Constraint::IntLinEq(coeffs, vars, c) => {
let vars: Vec<_> = vars
.iter()
.map(|v| v.to_arg(ReifContext::Mixed, slv, map))
.collect();
slv.add_propagator(LinearLE::new(coeffs, vars, c));
// coeffs * vars <= c
slv.add_propagator(LinearLE::new(coeffs, vars.clone(), *c));
// coeffs * vars >= c <=> -coeffs * vars <= -c
slv.add_propagator(LinearLE::new(
&coeffs.iter().map(|c| -c).collect_vec(),
vars,
-c,
))
}
}
}
Expand Down
11 changes: 8 additions & 3 deletions crates/huub/src/model/flatzinc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ impl Model {
});
}
}
"int_linear_le" => {
"int_lin_le" | "int_lin_eq" => {
let is_eq = c.id.deref() == "int_lin_eq";
if let [coeffs, vars, rhs] = c.args.as_slice() {
let coeffs = arg_array(fzn, coeffs)?;
let vars = arg_array(fzn, vars)?;
Expand All @@ -84,10 +85,14 @@ impl Model {
.iter()
.map(|l| lit_int(fzn, &mut prb, &mut map, l))
.collect();
prb += Constraint::LinearLE(coeffs?, vars?, *rhs);
prb += if is_eq {
Constraint::IntLinEq
} else {
Constraint::IntLinLessEq
}(coeffs?, vars?, *rhs);
} else {
return Err(FlatZincError::InvalidNumArgs {
name: "int_linear_le",
name: if is_eq { "int_lin_eq" } else { "int_linear_le" },
found: c.args.len(),
expected: 3,
});
Expand Down
6 changes: 3 additions & 3 deletions crates/huub/src/model/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ use crate::{
view::{IntView, IntViewInner, SolverView},
SatSolver,
},
Solver, Variable,
IntVal, Solver, Variable,
};

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum IntExpr {
Var(IntVar),
Val(i64),
Val(IntVal),
}

impl IntExpr {
Expand Down Expand Up @@ -47,5 +47,5 @@ pub struct IntVar(pub(crate) u32);

#[derive(Debug)]
pub(crate) struct IntVarDef {
pub(crate) domain: RangeList<i64>,
pub(crate) domain: RangeList<IntVal>,
}
2 changes: 1 addition & 1 deletion crates/huub/src/propagator.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pub(crate) mod all_different;
pub(crate) mod conflict;
pub(crate) mod int_event;
pub(crate) mod linear;
pub(crate) mod int_lin_le;
pub(crate) mod reason;

use std::fmt::Debug;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,17 @@ use crate::{
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct LinearLE {
vars: Vec<IntView>, // Variables in the linear inequality
rhs: i64, // Lower bound of the linear inequality
rhs: IntVal, // Lower bound of the linear inequality
action_list: Vec<u32>, // List of variables that have been modified since the last propagation
}

impl LinearLE {
pub(crate) fn new<V: Into<IntView>, VI: IntoIterator<Item = V>>(
coeffs: &[IntVal],
vars: VI,
rhs: &IntVal,
mut max_sum: IntVal,
) -> Self {
let vars: Vec<IntView> = vars.into_iter().map(Into::into).collect();
let mut max_sum = *rhs;
let scaled_vars: Vec<IntView> =
vars.iter()
.enumerate()
Expand Down Expand Up @@ -127,7 +126,7 @@ mod tests {
Cnf,
};

use crate::{propagator::linear::LinearLE, solver::engine::int_var::IntVar, Solver, Value};
use crate::{propagator::int_lin_le::LinearLE, solver::engine::int_var::IntVar, Solver, Value};

#[test]
fn test_linear_le_sat() {
Expand All @@ -136,7 +135,7 @@ mod tests {
let b = IntVar::new_in(&mut slv, RangeList::from_iter([1..=2]), true);
let c = IntVar::new_in(&mut slv, RangeList::from_iter([1..=2]), true);

slv.add_propagator(LinearLE::new(&[2, 1, 1], vec![a, b, c], &10));
slv.add_propagator(LinearLE::new(&[2, 1, 1], vec![a, b, c], 10));
let result = slv.solve(|val| {
let Value::Int(a_val) = val(a.into()).unwrap() else {
panic!()
Expand All @@ -159,7 +158,7 @@ mod tests {
let b = IntVar::new_in(&mut slv, RangeList::from_iter([1..=4]), true);
let c = IntVar::new_in(&mut slv, RangeList::from_iter([1..=4]), true);

slv.add_propagator(LinearLE::new(&[2, 1, 1], vec![a, b, c], &3));
slv.add_propagator(LinearLE::new(&[2, 1, 1], vec![a, b, c], 3));
assert_eq!(slv.solve(|_| {}), SolveResult::Unsat)
}

Expand All @@ -170,7 +169,7 @@ mod tests {
let b = IntVar::new_in(&mut slv, RangeList::from_iter([1..=4]), true);
let c = IntVar::new_in(&mut slv, RangeList::from_iter([1..=4]), true);

slv.add_propagator(LinearLE::new(&[-2, -1, -1], vec![a, b, c], &-3));
slv.add_propagator(LinearLE::new(&[-2, -1, -1], vec![a, b, c], -3));
let result = slv.solve(|val| {
let Value::Int(a_val) = val(a.into()).unwrap() else {
panic!()
Expand All @@ -193,7 +192,7 @@ mod tests {
let b = IntVar::new_in(&mut slv, RangeList::from_iter([1..=2]), true);
let c = IntVar::new_in(&mut slv, RangeList::from_iter([1..=2]), true);

slv.add_propagator(LinearLE::new(&[-2, -1, -1], vec![a, b, c], &-10));
slv.add_propagator(LinearLE::new(&[-2, -1, -1], vec![a, b, c], -10));
assert_eq!(slv.solve(|_| {}), SolveResult::Unsat)
}
}
16 changes: 8 additions & 8 deletions crates/huub/src/solver/engine/int_var.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl IntVar {
Sat: SatSolver + SolverTrait<ValueFn = Sol>,
>(
slv: &mut Solver<Sat>,
domain: RangeList<i64>,
domain: RangeList<IntVal>,
direct_encoding: bool,
) -> IntView {
let orig_domain_len = (&domain)
Expand Down Expand Up @@ -122,11 +122,11 @@ impl IntVar {
let direct_offset = self.orig_domain_len - 1;

let meaning = if offset < direct_offset {
LitMeaning::GreaterEq(*self.orig_domain.lower_bound().unwrap() + 1 + offset as i64)
LitMeaning::GreaterEq(*self.orig_domain.lower_bound().unwrap() + 1 + offset as IntVal)
} else {
debug_assert!(self.has_direct);
let offset = offset - direct_offset;
LitMeaning::Eq(*self.orig_domain.lower_bound().unwrap() + 1 + offset as i64)
LitMeaning::Eq(*self.orig_domain.lower_bound().unwrap() + 1 + offset as IntVal)
};
if lit.is_negated() {
!meaning
Expand Down Expand Up @@ -211,7 +211,7 @@ impl IntVar {
BoolView(BoolViewInner::Lit(if negate { !lit } else { lit }))
}

pub(crate) fn get_value<V: SatValuation + ?Sized>(&self, model: &V) -> i64 {
pub(crate) fn get_value<V: SatValuation + ?Sized>(&self, model: &V) -> IntVal {
let mut val_iter = self.orig_domain.clone().into_iter().flatten();
for l in self.order_vars() {
match model.value(l.into()) {
Expand All @@ -228,10 +228,10 @@ impl IntVar {

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum LitMeaning {
Eq(i64),
NotEq(i64),
GreaterEq(i64),
Less(i64),
Eq(IntVal),
NotEq(IntVal),
GreaterEq(IntVal),
Less(IntVal),
}

impl Not for LitMeaning {
Expand Down

0 comments on commit dc776fd

Please sign in to comment.