Merge branch 'main' into experimental_minimizer_based_learning
dickensc committed Aug 25, 2023
2 parents 1ffba60 + 5bbd231 commit 1d54995
Showing 34 changed files with 514 additions and 487 deletions.
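Most of the hunks below make one recurring change: the AtomStore is now reached through the term store rather than the database, and term stores are constructed from database.getAtomStore(). Below is a minimal sketch of the new access pattern, using only class and method names that appear in the hunks; the import path for InferenceApplication and the helper class itself are illustrative, not code from this repository.

import java.util.Arrays;

import org.linqs.psl.application.inference.InferenceApplication;  // Assumed package path.
import org.linqs.psl.database.AtomStore;

// Hypothetical helper showing how callers reach atom values after this commit.
public final class AtomStoreAccessSketch {
    // Before: inferenceApplication.getDatabase().getAtomStore()
    // After:  inferenceApplication.getTermStore().getAtomStore()
    public static float[] snapshotAtomValues(InferenceApplication inferenceApplication) {
        AtomStore atomStore = inferenceApplication.getTermStore().getAtomStore();
        float[] atomValues = atomStore.getAtomValues();

        // Copy the current values so they can be restored for a warm start,
        // mirroring what postInitGroundModel() does in the diff below.
        return Arrays.copyOf(atomValues, atomValues.length);
    }
}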
@@ -114,7 +114,7 @@ protected Reasoner createReasoner() {
}

protected TermStore createTermStore() {
- return new ADMMTermStore(database);
+ return new ADMMTermStore(database.getAtomStore());
}

/**
@@ -124,7 +124,7 @@ protected TermStore createTermStore() {
*/
protected void completeInitialize() {
log.info("Grounding out model.");
- long termCount = Grounding.groundAll(rules, termStore);
+ long termCount = Grounding.groundAll(rules, termStore, database);
log.info("Grounding complete.");
log.debug("Generated {} terms.", termCount);
}
========== next file ==========
@@ -41,6 +41,6 @@ protected Reasoner createReasoner() {

@Override
protected TermStore createTermStore() {
- return new ADMMTermStore(database);
+ return new ADMMTermStore(database.getAtomStore());
}
}
========== next file ==========
@@ -41,6 +41,6 @@ protected Reasoner createReasoner() {

@Override
protected TermStore createTermStore() {
- return new DualLCQPTermStore(database);
+ return new DualLCQPTermStore(database.getAtomStore());
}
}
========== next file ==========
@@ -41,6 +41,6 @@ protected Reasoner createReasoner() {

@Override
protected TermStore createTermStore() {
- return new SGDTermStore(database);
+ return new SGDTermStore(database.getAtomStore());
}
}
========== next file ==========
@@ -22,10 +22,7 @@
import org.linqs.psl.config.Options;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.deep.DeepModelPredicate;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.InitialValue;
@@ -34,7 +31,6 @@
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
@@ -164,7 +160,7 @@ protected void postInitGroundModel() {
throw new IllegalArgumentException("If validation is being run, then an evaluator must be specified for predicates.");
}

- if (!((!runValidation) || (validationInferenceApplication.getDatabase().getAtomStore().size() > 0))) {
+ if (!((!runValidation) || (validationInferenceApplication.getTermStore().getAtomStore().size() > 0))) {
throw new IllegalStateException("If validation is being run, then validation data must be provided in the runtime.json file.");
}

@@ -178,10 +174,10 @@ protected void postInitGroundModel() {
trainMAPTermState = trainInferenceApplication.getTermStore().saveState();
validationMAPTermState = validationInferenceApplication.getTermStore().saveState();

- float[] trainAtomValues = trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();
+ float[] trainAtomValues = trainInferenceApplication.getTermStore().getAtomStore().getAtomValues();
trainMAPAtomValueState = Arrays.copyOf(trainAtomValues, trainAtomValues.length);

- float[] validationAtomValues = validationInferenceApplication.getDatabase().getAtomStore().getAtomValues();
+ float[] validationAtomValues = validationInferenceApplication.getTermStore().getAtomStore().getAtomValues();
validationMAPAtomValueState = Arrays.copyOf(validationAtomValues, validationAtomValues.length);

rvAtomGradient = new float[trainAtomValues.length];
@@ -221,18 +217,12 @@ protected void doLearn() {
boolean breakGD = false;
float objective = 0.0f;
float oldObjective = Float.POSITIVE_INFINITY;

- float[] bestWeights = new float[mutableRules.size()];


log.info("Gradient Descent Weight Learning Start.");
initForLearning();

- for (int i = 0; i < bestWeights.length; i++) {
- bestWeights[i] = mutableRules.get(i).getWeight();
- }

long totalTime = 0;
- int iteration = 0;
+ int epoch = 0;
while (!breakGD) {
long start = System.currentTimeMillis();

@@ -261,7 +251,7 @@ protected void doLearn() {
}
}

- computeIterationStatistics();
+ computeIterationStatistics(epoch);

objective = computeTotalLoss();

@@ -271,26 +261,26 @@ protected void doLearn() {
clipWeightGradient();
}

- gradientStep(iteration);
+ gradientStep(epoch);

long end = System.currentTimeMillis();

totalTime += end - start;
oldObjective = objective;

- breakGD = breakOptimization(iteration, objective, oldObjective);
+ breakGD = breakOptimization(epoch, objective, oldObjective);
log.trace("Iteration {} -- Weight Learning Objective: {}, Gradient Magnitude: {}, Parameter Movement: {}, Iteration Time: {}",
- iteration, objective, computeGradientNorm(), parameterMovement, (end - start));
+ epoch, objective, computeGradientNorm(), parameterMovement, (end - start));

- iteration++;
+ epoch++;
}

log.info("Gradient Descent Weight Learning Finished.");

if (saveBestValidationWeights) {
// Reset rule weights to bestWeights.
for (int i = 0; i < mutableRules.size(); i++) {
- mutableRules.get(i).setWeight(bestWeights[i]);
+ mutableRules.get(i).setWeight(bestValidationWeights[i]);
}
}

@@ -681,7 +671,7 @@ protected void computeMAPStateWithWarmStart(InferenceApplic
TermState[] warmStartTermState, float[] warmStartAtomValueState) {
// Warm start inference with previous termState.
inferenceApplication.getTermStore().loadState(warmStartTermState);
- AtomStore atomStore = inferenceApplication.getDatabase().getAtomStore();
+ AtomStore atomStore = inferenceApplication.getTermStore().getAtomStore();
float[] atomValues = atomStore.getAtomValues();
for (int i = 0; i < atomStore.size(); i++) {
if (atomStore.getAtom(i).isFixed()) {
@@ -697,7 +687,7 @@ protected void computeMAPStateWithWarmStart(InferenceApplic

// Save the MPE state for future warm starts.
inferenceApplication.getTermStore().saveState(warmStartTermState);
- float[] mpeAtomValues = inferenceApplication.getDatabase().getAtomStore().getAtomValues();
+ float[] mpeAtomValues = inferenceApplication.getTermStore().getAtomStore().getAtomValues();
System.arraycopy(mpeAtomValues, 0, warmStartAtomValueState, 0, mpeAtomValues.length);
}

@@ -708,7 +698,7 @@ protected void computeCurrentIncompatibility(float[] incompatibilityArray) {
// Zero out the incompatibility first.
Arrays.fill(incompatibilityArray, 0.0f);

- float[] atomValues = trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();
+ float[] atomValues = trainInferenceApplication.getTermStore().getAtomStore().getAtomValues();

// Sums up the incompatibilities.
for (Object rawTerm : trainInferenceApplication.getTermStore()) {
@@ -732,7 +722,7 @@ protected void computeCurrentIncompatibility(float[] incompatibilityArray) {
* Method called at the start of every gradient descent iteration to
* compute statistics needed for loss and gradient computations.
*/
- protected abstract void computeIterationStatistics();
+ protected abstract void computeIterationStatistics(int epoch);

/**
* Method for computing the total regularized loss.
========== next file ==========
@@ -39,7 +39,7 @@ public BinaryCrossEntropy(List<Rule> rules, Database trainTargetDatabase, Databa

@Override
protected float computeSupervisedLoss() {
- AtomStore atomStore = trainInferenceApplication.getDatabase().getAtomStore();
+ AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();

float supervisedLoss = 0.0f;
for (Map.Entry<RandomVariableAtom, ObservedAtom> entry: trainingMap.getLabelMap().entrySet()) {
@@ -58,7 +58,7 @@ protected float computeSupervisedLoss() {

@Override
protected void addSupervisedProxRuleObservedAtomValueGradient() {
- AtomStore atomStore = trainInferenceApplication.getDatabase().getAtomStore();
+ AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();

for (Map.Entry<RandomVariableAtom, ObservedAtom> entry: trainingMap.getLabelMap().entrySet()) {
RandomVariableAtom randomVariableAtom = entry.getKey();
========== next file ==========
@@ -81,7 +81,6 @@ public abstract class Minimizer extends GradientDescent {
protected float constraintTolerance;
protected float finalConstraintTolerance;

- protected boolean initializedProxRuleConstants;
protected int outerIteration;

protected final float initialSquaredPenaltyCoefficient;
@@ -124,13 +123,12 @@ public Minimizer(List<Rule> rules, Database trainTargetDatabase, Database trainT
finalParameterMovementTolerance = Options.MINIMIZER_FINAL_PARAMETER_MOVEMENT_CONVERGENCE_TOLERANCE.getFloat();
constraintTolerance = (float)(1.0f / Math.pow(initialSquaredPenaltyCoefficient, 0.1f));
finalConstraintTolerance = Options.MINIMIZER_OBJECTIVE_DIFFERENCE_TOLERANCE.getFloat();
- initializedProxRuleConstants = false;
outerIteration = 1;
}

@Override
protected void postInitGroundModel() {
- AtomStore atomStore = trainInferenceApplication.getTermStore().getDatabase().getAtomStore();
+ AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();

// Create and add the augmented inference proximity terms.
int unFixedAtomCount = 0;
@@ -204,7 +202,7 @@ protected void postInitGroundModel() {
super.postInitGroundModel();

// Initialize latent and augmented inference warm start state objects.
- float[] atomValues = trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();
+ float[] atomValues = trainInferenceApplication.getTermStore().getAtomStore().getAtomValues();

latentInferenceTermState = trainInferenceApplication.getTermStore().saveState();
latentInferenceAtomValueState = Arrays.copyOf(atomValues, atomValues.length);
@@ -273,7 +271,7 @@ protected void gradientStep(int iteration) {
protected float internalParameterGradientStep(int iteration) {
float proxRuleObservedAtomsValueMovement = 0.0f;
// Take a step in the direction of the negative gradient of the proximity rule constants and project back onto box constraints.
- float[] atomValues = trainInferenceApplication.getTermStore().getDatabase().getAtomStore().getAtomValues();
+ float[] atomValues = trainInferenceApplication.getTermStore().getAtomStore().getAtomValues();
for (int i = 0; i < proxRules.length; i++) {
float newProxRuleObservedAtomsValue = Math.min(Math.max(
proxRuleObservedAtoms[i].getValue() - proxRuleObservedAtomValueStepSize * proxRuleObservedAtomValueGradient[i], 0.0f), 1.0f);
@@ -291,13 +289,13 @@ protected void initializeProximityRuleConstants() {
// Initialize the proximity rule constants to the truth if it exists or the latent MAP state.
fixLabeledRandomVariables();

log.trace("Performing Latent Inference.");
log.trace("Running Latent Inference.");
computeMAPStateWithWarmStart(trainInferenceApplication, latentInferenceTermState, latentInferenceAtomValueState);
inTrainingMAPState = true;

unfixLabeledRandomVariables();

- AtomStore atomStore = trainInferenceApplication.getDatabase().getAtomStore();
+ AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();
float[] atomValues = atomStore.getAtomValues();

System.arraycopy(latentInferenceAtomValueState, 0, augmentedInferenceAtomValueState, 0, latentInferenceAtomValueState.length);
@@ -320,15 +318,13 @@ protected void initializeProximityRuleConstants() {
atomValues[proxRuleObservedAtomIndexes[proxRuleIndex]] = observedAtom.getValue();
augmentedInferenceAtomValueState[proxRuleObservedAtomIndexes[proxRuleIndex]] = observedAtom.getValue();
}

- initializedProxRuleConstants = true;
}

@Override
- protected void computeIterationStatistics() {
+ protected void computeIterationStatistics(int epoch) {
computeFullInferenceStatistics();

- if (!initializedProxRuleConstants) {
+ if (epoch == 0) {
initializeProximityRuleConstants();
}

@@ -341,7 +337,7 @@ protected void computeTotalAtomGradient() {
protected void computeTotalAtomGradient() {
float totalEnergyDifference = computeObjectiveDifference();

- for (int i = 0; i < trainInferenceApplication.getDatabase().getAtomStore().size(); i++) {
+ for (int i = 0; i < trainInferenceApplication.getTermStore().getAtomStore().size(); i++) {
float rvGradientDifference = augmentedRVAtomGradient[i] - MAPRVAtomGradient[i];
float deepGradientDifference = augmentedDeepAtomGradient[i] - MAPDeepAtomGradient[i];

@@ -487,10 +483,10 @@ private float computeTotalEnergyDifference(float[] incompatibilityDifference){
}
}

- GroundAtom[] atoms = trainInferenceApplication.getDatabase().getAtomStore().getAtoms();
+ GroundAtom[] atoms = trainInferenceApplication.getTermStore().getAtomStore().getAtoms();
float augmentedInferenceLCQPRegularization = 0.0f;
float fullInferenceLCQPRegularization = 0.0f;
- for (int i = 0; i < trainInferenceApplication.getDatabase().getAtomStore().size(); i++) {
+ for (int i = 0; i < trainInferenceApplication.getTermStore().getAtomStore().size(); i++) {
if (atoms[i].isFixed()) {
continue;
}
@@ -543,7 +539,7 @@ protected void computeCurrentSquaredIncompatibility(float[] incompatibilityArray
// Zero out the incompatibility first.
Arrays.fill(incompatibilityArray, 0.0f);

- float[] atomValues = trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();
+ float[] atomValues = trainInferenceApplication.getTermStore().getAtomStore().getAtomValues();

// Sums up the incompatibilities.
for (Object rawTerm : trainInferenceApplication.getTermStore()) {
@@ -570,7 +566,7 @@ protected void computeCurrentSquaredIncompatibility(float[] incompatibilityArray
* with the same predicates and arguments having the same hash.
*/
protected void fixLabeledRandomVariables() {
- AtomStore atomStore = trainInferenceApplication.getTermStore().getDatabase().getAtomStore();
+ AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();

for (Map.Entry<RandomVariableAtom, ObservedAtom> entry: trainingMap.getLabelMap().entrySet()) {
RandomVariableAtom randomVariableAtom = entry.getKey();
@@ -592,7 +588,7 @@ protected void fixLabeledRandomVariables() {
* with the same predicates and arguments having the same hash.
*/
protected void unfixLabeledRandomVariables() {
- AtomStore atomStore = trainInferenceApplication.getDatabase().getAtomStore();
+ AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();

for (Map.Entry<RandomVariableAtom, ObservedAtom> entry: trainingMap.getLabelMap().entrySet()) {
RandomVariableAtom randomVariableAtom = entry.getKey();
========== next file ==========
@@ -38,7 +38,7 @@ public SquaredError(List<Rule> rules, Database trainTargetDatabase, Database tra

@Override
protected float computeSupervisedLoss() {
- AtomStore atomStore = trainInferenceApplication.getDatabase().getAtomStore();
+ AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();

float supervisedLoss = 0.0f;
for (Map.Entry<RandomVariableAtom, ObservedAtom> entry: trainingMap.getLabelMap().entrySet()) {
@@ -54,7 +54,7 @@ protected float computeSupervisedLoss() {

@Override
protected void addSupervisedProxRuleObservedAtomValueGradient() {
- AtomStore atomStore = trainInferenceApplication.getDatabase().getAtomStore();
+ AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();

for (Map.Entry<RandomVariableAtom, ObservedAtom> entry: trainingMap.getLabelMap().entrySet()) {
RandomVariableAtom randomVariableAtom = entry.getKey();
========== next file ==========
@@ -45,7 +45,7 @@ protected float computeLearningLoss() {
}

@Override
- protected void computeIterationStatistics() {
+ protected void computeIterationStatistics(int epoch) {
computeLatentInferenceIncompatibility();
}

========== next file ==========
@@ -64,7 +64,7 @@ protected void postInitGroundModel() {

// Initialize latent inference warm start state objects.
latentInferenceTermState = trainInferenceApplication.getTermStore().saveState();
- float[] atomValues = trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();
+ float[] atomValues = trainInferenceApplication.getTermStore().getAtomStore().getAtomValues();
latentInferenceAtomValueState = Arrays.copyOf(atomValues, atomValues.length);

rvLatentAtomGradient = new float[atomValues.length];
@@ -98,7 +98,7 @@ protected void computeLatentInferenceIncompatibility() {
* with the same predicates and arguments having the same hash.
*/
protected void fixLabeledRandomVariables() {
- AtomStore atomStore = trainInferenceApplication.getTermStore().getDatabase().getAtomStore();
+ AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();

for (Map.Entry<RandomVariableAtom, ObservedAtom> entry: trainingMap.getLabelMap().entrySet()) {
RandomVariableAtom randomVariableAtom = entry.getKey();
@@ -120,7 +120,7 @@ protected void fixLabeledRandomVariables() {
* with the same predicates and arguments having the same hash.
*/
protected void unfixLabeledRandomVariables() {
- AtomStore atomStore = trainInferenceApplication.getDatabase().getAtomStore();
+ AtomStore atomStore = trainInferenceApplication.getTermStore().getAtomStore();

for (Map.Entry<RandomVariableAtom, ObservedAtom> entry: trainingMap.getLabelMap().entrySet()) {
RandomVariableAtom randomVariableAtom = entry.getKey();
========== next file ==========
@@ -37,7 +37,7 @@ public StructuredPerceptron(List<Rule> rules, Database trainTargetDatabase, Data
}

@Override
- protected void computeIterationStatistics() {
+ protected void computeIterationStatistics(int epoch) {
computeLatentInferenceIncompatibility();
computeFullInferenceIncompatibility();
}
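The other recurring change above threads the gradient-descent epoch into computeIterationStatistics(int epoch), so subclasses can gate one-time setup on epoch == 0 instead of keeping a separate flag (Minimizer previously used initializedProxRuleConstants for this). A small, self-contained sketch of that epoch-gated pattern; the class and field names are illustrative, not part of the repository.

import java.util.Arrays;

// Hypothetical example of gating first-epoch initialization on the epoch argument.
public class EpochGatedStatistics {
    private float[] statistics;

    public void computeIterationStatistics(int epoch) {
        if (epoch == 0) {
            // One-time allocation on the first epoch, replacing a boolean "initialized" flag.
            statistics = new float[8];
        }

        // Per-epoch recomputation happens on every call.
        Arrays.fill(statistics, 0.0f);
    }
}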
(Remaining changed files not loaded.)
