Skip to content

Commit

Permalink
Shuffle batch ordering and epoch end operation for deep models.
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed Aug 29, 2023
1 parent 7400761 commit 70d2923
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,14 @@ protected void doLearn() {
i, batchObjective, computeGradientNorm(), (batchEnd - batchStart));
}
epoch++;

setFullTrainModel();

// Predict with the deep predicates again to ensure predictions are aligned with the full training model.
for (DeepPredicate deepPredicate : deepPredicates) {
deepPredicate.predictDeepModel(true);
}

measureEpochParameterMovement();
epochEnd(epoch);

Expand Down Expand Up @@ -414,12 +421,17 @@ protected void epochStart(int epoch) {
for (int i = 0; i < mutableRules.size(); i++) {
epochStartWeights[i] = mutableRules.get(i).getWeight();
}

batchGenerator.shuffle();
}

protected void epochEnd(int epoch) {
// By default, do nothing. Child classes can override this method to perform actions at the end of an epoch.
// This method is called after the epoch parameter movement is measured and the model is reset to the full training model
// but before measuring the stopping condition.
// Child classes should override this method to add additional functionality.
for (DeepPredicate deepPredicate : deepPredicates) {
deepPredicate.epochEnd();
}
}

protected void measureEpochParameterMovement() {
Expand Down Expand Up @@ -468,12 +480,9 @@ protected void setFullTrainModel() {
trainMAPTermState = trainFullMAPTermState;
trainMAPAtomValueState = trainFullMAPAtomValueState;

// Set the deep predicate atom store and predict with the deep predicates again
// to ensure predictions are aligned with the full training model.
for (int i = 0; i < deepPredicates.size(); i++) {
DeepPredicate deepPredicate = deepPredicates.get(i);
deepPredicate.setDeepModel(deepModelPredicates.get(i));
deepPredicate.predictDeepModel(false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.SimpleTermStore;
import org.linqs.psl.util.RandUtils;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -85,6 +86,10 @@ public void generateBatches() {
}
}

public void shuffle() {
RandUtils.pairedShuffle(batchTermStores, batchDeepModelPredicates);
}

public abstract void generateBatchTermStores();

public void clear() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.atom.UnmanagedObservedAtom;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.Rule;
Expand Down Expand Up @@ -323,7 +324,13 @@ protected void epochStart(int epoch) {

initializeProximityRuleConstants();
}

setFullTrainModel();

// Predict with the deep predicates again to ensure predictions are aligned with the full training model.
for (DeepPredicate deepPredicate : deepPredicates) {
deepPredicate.predictDeepModel(true);
}
}

proxRuleObservedAtomsValueEpochMovement = 0.0f;
Expand Down Expand Up @@ -537,8 +544,14 @@ private float computeTotalObjectiveDifference() {

totalObjectiveDifference += computeObjectiveDifference();
}

setFullTrainModel();

// Predict with the deep predicates again to ensure predictions are aligned with the full training model.
for (DeepPredicate deepPredicate : deepPredicates) {
deepPredicate.predictDeepModel(true);
}

return totalObjectiveDifference;
}

Expand Down
13 changes: 13 additions & 0 deletions psl-core/src/main/java/org/linqs/psl/model/deep/DeepModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,19 @@ public void fitDeepModel() {
log.debug("Fit deep model results for {} : {}", this, resultString);
}

public void epochEnd() {
log.debug("Epoch end deep model {}.", this);

JSONObject message = new JSONObject();
message.put("task", "epoch_end");
message.put("options", pythonOptions);

JSONObject response = sendSocketMessage(message);

String resultString = getResultString(response);
log.debug("Epoch end deep model results for {} : {}", this, resultString);
}

public float predictDeepModel(Boolean learning) {
log.debug("Predict deep model {}.", this);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ public void fitDeepPredicate(float[] symbolicGradients) {
deepModel.fitDeepModel();
}

public void epochEnd() {
deepModel.epochEnd();
}

public DeepModelPredicate getDeepModel() {
return deepModel;
}
Expand Down
6 changes: 6 additions & 0 deletions psl-python/pslpython/deeppsl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def internal_init_model(self, application, options = {}):
def internal_fit(self, data, gradients, options = {}):
raise NotImplementedError("internal_fit")

def internal_epoch_end(self, options = {}):
raise NotImplementedError("internal_epoch")

def internal_predict(self, data, options = {}):
raise NotImplementedError("internal_predict")

Expand Down Expand Up @@ -103,6 +106,9 @@ def fit_predicate(self, options = {}):

return self.internal_fit(data, gradients, options = options)

def epoch_end(self, options = {}):
return self.internal_epoch_end(options = options)

def predict_predicate(self, options = {}):
self._predict_predicate(False, options = options)

Expand Down
5 changes: 5 additions & 0 deletions psl-python/pslpython/deeppsl/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def handle_internal(self, connection, data):
result = self._init(request)
elif request['task'] == 'fit':
result = self._fit(request)
elif request['task'] == 'epoch_end':
result = self._epoch_end(request)
elif request['task'] == 'predict':
result = self._predict(request)
elif request['task'] == 'predict_learn':
Expand Down Expand Up @@ -114,6 +116,9 @@ def _fit(self, request):
else:
raise ValueError("Unknown deep model type in fit: '%s'." % (deep_model,))

def _epoch_end(self, request):
return self._model.epoch_end(options=request.get('options', {}))

def _predict(self, request):
deep_model = request['deep_model']
options = request.get('options', {})
Expand Down

0 comments on commit 70d2923

Please sign in to comment.