Skip to content

Commit

Permalink
Test deep weighted logical rule.
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed May 21, 2024
1 parent 1851af2 commit 3c229bf
Show file tree
Hide file tree
Showing 10 changed files with 151 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,10 @@ protected void relaxHardConstraints() {
for (Rule rule : rules) {
if (rule instanceof WeightedRule) {
Weight weight = ((WeightedRule)rule).getWeight();
if (weight.getValue() > largestWeight) {
largestWeight = weight.getValue();

if ((!weight.isDeep()) && (1.0f > largestWeight)) {
// 1.0f is the largest possible value of a deep weight.
largestWeight = 1.0f;
}
} else {
hasUnweightedRule = true;
Expand Down
4 changes: 3 additions & 1 deletion psl-core/src/main/java/org/linqs/psl/model/rule/Weight.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.util.HashCode;

import java.io.Serializable;

/**
* A weight for a rule.
* A weight is a constant value and can be associated with a GroundAtom.
* The value of the weight is the constant value multiplied by the value of the GroundAtom.
*/
public class Weight {
public class Weight implements Serializable {
private float constantValue;
private Atom atom;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class WeightedArithmeticRule extends AbstractArithmeticRule implements We
protected boolean squared;

public WeightedArithmeticRule(ArithmeticRuleExpression expression, Weight weight, boolean squared) {
this(expression, weight, squared, expression.toString());
this(expression, weight, squared, weight.toString() + ": " + expression.toString());
}

public WeightedArithmeticRule(ArithmeticRuleExpression expression, Weight weight, boolean squared, String name) {
Expand Down Expand Up @@ -74,7 +74,7 @@ protected AbstractGroundArithmeticRule makeGroundRule(
);

WeightedArithmeticRule groundedDeepWeightedRule = new WeightedArithmeticRule(
newExpression, groundedWeight, squared, groundedWeight.getAtom().toString() + ": " + name
newExpression, groundedWeight, squared
);

groundedDeepWeightedRule.setParentHashCode(hashCode());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ protected AbstractLogicalRule(Formula formula, String name, int hashcode) {
this.hashcode = hashcode;

this.formula = formula;
groundingResourcesKey = AbstractLogicalRule.class.getName() + ";" + formula + ";GroundingResources";
groundingResourcesKey = AbstractLogicalRule.class.getName() + ";" + name + ";GroundingResources";

// Do the formula analysis so we know what atoms to query for grounding.
// We will query for all positive atoms in the negated DNF.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void parseNegatedDNF(FormulaAnalysis.DNFClause negatedDNF, Weight weight)
if ((weight != null) && (weight.isDeep())) {
assert (weight.getAtom() instanceof QueryAtom);

weightQueryAtom = (QueryAtom)weight.getAtom();
weightQueryAtom = (QueryAtom) weight.getAtom();
weightArgumentsBuffer = new Constant[weightQueryAtom.getArity()];
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ protected GroundRule ground(Constant[] constants, Map<Variable, Integer> variabl
@Override
public WeightedRule relax(Weight weight, boolean squared) {
unregister();
return new WeightedLogicalRule(formula, weight, squared, name);
return new WeightedLogicalRule(formula, weight, squared);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class WeightedLogicalRule extends AbstractLogicalRule implements Weighted
protected boolean squared;

public WeightedLogicalRule(Formula formula, Weight weight, boolean squared) {
this(formula, weight, squared, formula.toString());
this(formula, weight, squared, weight.toString() + ": " + formula.toString());
}

public WeightedLogicalRule(Formula formula, Weight weight, boolean squared, String name) {
Expand Down Expand Up @@ -67,7 +67,7 @@ protected WeightedGroundLogicalRule groundFormulaInstance(List<GroundAtom> posLi
if (groundedWeight == null) {
return new WeightedGroundLogicalRule(this, posLiterals, negLiterals);
} else {
WeightedLogicalRule groundedDeepWeightedRule = new WeightedLogicalRule(formula, groundedWeight, squared, groundedWeight.getAtom().toString() + ": " + name);
WeightedLogicalRule groundedDeepWeightedRule = new WeightedLogicalRule(formula, groundedWeight, squared);
groundedDeepWeightedRule.setParentHashCode(hashCode());
addChildHashCode(groundedDeepWeightedRule.hashCode());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ private void mergeAtomComponents(GroundAtom atom1, GroundAtom atom2) {
return;
}

if (atom1RootIndex == -1 || atom2RootIndex == -1) {
throw new IllegalArgumentException("Atoms must be in the atom store before they can be merged.");
}

GroundAtom atom1Root = atomStore.getAtom(atom1RootIndex);
GroundAtom atom2Root = atomStore.getAtom(atom2RootIndex);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,17 @@ public void testSimpleModels() {

inference.close();
inferDB.close();

// Exogenous model with observed atom weight.
info = TestModel.getExogenousModelWithObservedAtomWeight();
inferDB = info.dataStore.getDatabase(info.targetPartition, new HashSet<StandardPredicate>(), info.observationPartition);
inference = getInference(info.model.getRules(), inferDB);

// Test the inference application is able to find the MAP state.
assertEquals(0.0, inference.inference(), 0.1f);

inference.close();
inferDB.close();
}

/**
Expand All @@ -459,7 +470,7 @@ public void reasonerEvaluateTest() {

@Test
public void testAtomWithConstant() {
// Nice(A) & Nice(B) & Friends('Alice', B) && (A != B) -> Friends(A, B)
// 1.0: Nice(A) & Nice(B) & Friends('Alice', B) && (A != B) -> Friends(A, B)
info.model.addRule(new WeightedLogicalRule(
new Implication(
new Conjunction(
Expand Down
121 changes: 121 additions & 0 deletions psl-core/src/test/java/org/linqs/psl/test/TestModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,127 @@ public static ModelInformation getExogenousModel() {
return getModel(DatabaseTestUtil.getH2Driver(), predicates, rules, observations, targets, truths);
}

/**
* A model with only a single exogenous rule with a observed atom weight.
* Weight(A, B): Nice(A) & Nice(B) & (A != B) -> Friends(A, B) ^2
* Such that the weight is 1.0 for all instance of the rule.
*/
public static ModelInformation getExogenousModelWithObservedAtomWeight() {
// Define Predicates
Map<String, ConstantType[]> predicatesInfo = new HashMap<String, ConstantType[]>();
predicatesInfo.put("Nice", new ConstantType[]{ConstantType.UniqueStringID});
predicatesInfo.put("Friends", new ConstantType[]{ConstantType.UniqueStringID, ConstantType.UniqueStringID});
predicatesInfo.put("Weight", new ConstantType[]{ConstantType.UniqueStringID, ConstantType.UniqueStringID});

Map<String, StandardPredicate> predicates = new HashMap<String, StandardPredicate>();
for (Map.Entry<String, ConstantType[]> predicateEntry : predicatesInfo.entrySet()) {
StandardPredicate predicate = StandardPredicate.get(predicateEntry.getKey(), predicateEntry.getValue());
predicates.put(predicateEntry.getKey(), predicate);
}

// Define Rules
List<Rule> rules = new ArrayList<Rule>();
rules.add(new WeightedLogicalRule(
new Implication(
new Conjunction(
new QueryAtom(predicates.get("Nice"), new Variable("A")),
new QueryAtom(predicates.get("Nice"), new Variable("B")),
new QueryAtom(GroundingOnlyPredicate.NotEqual, new Variable("A"), new Variable("B"))
),
new QueryAtom(predicates.get("Friends"), new Variable("A"), new Variable("B"))
),
new Weight(1.0f, new QueryAtom(predicates.get("Weight"), new Variable("A"), new Variable("B"))),
true));

// Data
Map<StandardPredicate, List<PredicateData>> observations = new HashMap<StandardPredicate, List<PredicateData>>();
Map<StandardPredicate, List<PredicateData>> targets = new HashMap<StandardPredicate, List<PredicateData>>();
Map<StandardPredicate, List<PredicateData>> truths = new HashMap<StandardPredicate, List<PredicateData>>();

// Nice
observations.put(predicates.get("Nice"), new ArrayList<PredicateData>(Arrays.asList(
new PredicateData(1.0, new Object[]{"Alice"}),
new PredicateData(1.0, new Object[]{"Bob"}),
new PredicateData(1.0, new Object[]{"Charlie"}),
new PredicateData(1.0, new Object[]{"Derek"}),
new PredicateData(1.0, new Object[]{"Eugene"})
)));

// Weight
observations.put(predicates.get("Weight"), new ArrayList<PredicateData>(Arrays.asList(
new PredicateData(1.0, new Object[]{"Alice", "Bob"}),
new PredicateData(1.0, new Object[]{"Bob", "Alice"}),
new PredicateData(1.0, new Object[]{"Alice", "Charlie"}),
new PredicateData(1.0, new Object[]{"Charlie", "Alice"}),
new PredicateData(1.0, new Object[]{"Alice", "Derek"}),
new PredicateData(1.0, new Object[]{"Derek", "Alice"}),
new PredicateData(1.0, new Object[]{"Alice", "Eugene"}),
new PredicateData(1.0, new Object[]{"Eugene", "Alice"}),
new PredicateData(1.0, new Object[]{"Bob", "Charlie"}),
new PredicateData(1.0, new Object[]{"Charlie", "Bob"}),
new PredicateData(1.0, new Object[]{"Bob", "Derek"}),
new PredicateData(1.0, new Object[]{"Derek", "Bob"}),
new PredicateData(1.0, new Object[]{"Bob", "Eugene"}),
new PredicateData(1.0, new Object[]{"Eugene", "Bob"}),
new PredicateData(1.0, new Object[]{"Charlie", "Derek"}),
new PredicateData(1.0, new Object[]{"Derek", "Charlie"}),
new PredicateData(1.0, new Object[]{"Charlie", "Eugene"}),
new PredicateData(1.0, new Object[]{"Eugene", "Charlie"}),
new PredicateData(1.0, new Object[]{"Derek", "Eugene"}),
new PredicateData(1.0, new Object[]{"Eugene", "Derek"})
)));

// Friends
targets.put(predicates.get("Friends"), new ArrayList<PredicateData>(Arrays.asList(
new PredicateData(new Object[]{"Alice", "Bob"}),
new PredicateData(new Object[]{"Bob", "Alice"}),
new PredicateData(new Object[]{"Alice", "Charlie"}),
new PredicateData(new Object[]{"Charlie", "Alice"}),
new PredicateData(new Object[]{"Alice", "Derek"}),
new PredicateData(new Object[]{"Derek", "Alice"}),
new PredicateData(new Object[]{"Alice", "Eugene"}),
new PredicateData(new Object[]{"Eugene", "Alice"}),
new PredicateData(new Object[]{"Bob", "Charlie"}),
new PredicateData(new Object[]{"Charlie", "Bob"}),
new PredicateData(new Object[]{"Bob", "Derek"}),
new PredicateData(new Object[]{"Derek", "Bob"}),
new PredicateData(new Object[]{"Bob", "Eugene"}),
new PredicateData(new Object[]{"Eugene", "Bob"}),
new PredicateData(new Object[]{"Charlie", "Derek"}),
new PredicateData(new Object[]{"Derek", "Charlie"}),
new PredicateData(new Object[]{"Charlie", "Eugene"}),
new PredicateData(new Object[]{"Eugene", "Charlie"}),
new PredicateData(new Object[]{"Derek", "Eugene"}),
new PredicateData(new Object[]{"Eugene", "Derek"})
)));

truths.put(predicates.get("Friends"), new ArrayList<PredicateData>(Arrays.asList(
new PredicateData(1, new Object[]{"Alice", "Bob"}),
new PredicateData(1, new Object[]{"Bob", "Alice"}),
new PredicateData(1, new Object[]{"Alice", "Charlie"}),
new PredicateData(1, new Object[]{"Charlie", "Alice"}),
new PredicateData(1, new Object[]{"Alice", "Derek"}),
new PredicateData(1, new Object[]{"Derek", "Alice"}),
new PredicateData(1, new Object[]{"Alice", "Eugene"}),
new PredicateData(1, new Object[]{"Eugene", "Alice"}),
new PredicateData(1, new Object[]{"Bob", "Charlie"}),
new PredicateData(1, new Object[]{"Charlie", "Bob"}),
new PredicateData(1, new Object[]{"Bob", "Derek"}),
new PredicateData(1, new Object[]{"Derek", "Bob"}),
new PredicateData(0, new Object[]{"Bob", "Eugene"}),
new PredicateData(0, new Object[]{"Eugene", "Bob"}),
new PredicateData(1, new Object[]{"Charlie", "Derek"}),
new PredicateData(1, new Object[]{"Derek", "Charlie"}),
new PredicateData(0, new Object[]{"Charlie", "Eugene"}),
new PredicateData(0, new Object[]{"Eugene", "Charlie"}),
new PredicateData(0, new Object[]{"Derek", "Eugene"}),
new PredicateData(0, new Object[]{"Eugene", "Derek"})
)));

return getModel(DatabaseTestUtil.getH2Driver(), predicates, rules, observations, targets, truths);
}


/**
* A model with only a single symmetry rule.
* 10: Person(A) & Person(B) & Friends(A, B) & (A != B) -> Friends(B, A) ^2
Expand Down

0 comments on commit 3c229bf

Please sign in to comment.