Skip to content

Commit

Permalink
Update PolicyGradient.java
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed Dec 11, 2023
1 parent 26d79c7 commit aae3f56
Showing 1 changed file with 20 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ protected void setBatch(int batch) {
protected void computeIterationStatistics() {
Arrays.fill(scores, 0.0f);

computeMAPInferenceStatistics();
computeSupervisedLoss();
computeLatentInferenceStatistics();

float preSampleScore = computeScore();

for (DeepModelPredicate deepModelPredicate : trainDeepModelPredicates) {
Map<String, ArrayList<RandomVariableAtom>> atomIdentiferToCategories = deepModelPredicate.getAtomIdentiferToCategories();
for (Map.Entry<String, ArrayList<RandomVariableAtom>> entry : atomIdentiferToCategories.entrySet()) {
Expand All @@ -153,10 +159,16 @@ protected void computeIterationStatistics() {
computeSupervisedLoss();
computeLatentInferenceStatistics();

float score = computeScore();
float sampleScore = computeScore();
for (RandomVariableAtom category : categories) {
int atomIndex = trainInferenceApplication.getTermStore().getAtomStore().getAtomIndex(category);
scores[atomIndex] = score;
if (policySampledDeepAtomValues[atomIndex] == 1.0f) {
scores[atomIndex] = preSampleScore - sampleScore;
// log.trace("Deep Atom: {} Score: {}",
// trainInferenceApplication.getTermStore().getAtomStore().getAtom(atomIndex), scores[atomIndex]);
} else {
scores[atomIndex] = 0.0f;
}
}

resetDeepAtomValues(categories);
Expand Down Expand Up @@ -300,17 +312,21 @@ private void computeCategoricalAtomGradient(int atomIndex) {

switch (policyUpdate) {
case REINFORCE:
deepAtomGradient[atomIndex] += scores[atomIndex] / sampleProbabilities[atomIndex];
deepAtomGradient[atomIndex] -= scores[atomIndex] / sampleProbabilities[atomIndex];
break;
case REINFORCE_BASELINE:
deepAtomGradient[atomIndex] += (scores[atomIndex] - scoreMovingAverage) / sampleProbabilities[atomIndex];
deepAtomGradient[atomIndex] -= (scores[atomIndex] - scoreMovingAverage) / sampleProbabilities[atomIndex];
break;
default:
throw new IllegalArgumentException("Unknown policy update: " + policyUpdate);
}

// log.trace("Deep Atom: {} Score: {}, Deep Atom Gradient: {}",
// trainInferenceApplication.getTermStore().getAtomStore().getAtom(atomIndex), scores[atomIndex], deepAtomGradient[atomIndex]);
}

private float computeScore() {
// log.trace("Latent Inference Energy: " + latentInferenceEnergy + " Supervised Loss: " + supervisedLoss);
return energyLossCoefficient * latentInferenceEnergy + supervisedLoss;
}

Expand Down

0 comments on commit aae3f56

Please sign in to comment.