Skip to content

Commit

Permalink
Update policy gradient learning and learning logs.
Browse files Browse the repository at this point in the history
  • Loading branch information
dickensc committed Dec 11, 2023
1 parent c2a6e69 commit 26d79c7
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ public abstract class GradientDescent extends WeightLearningApplication {
* on the chosen loss with unit simplex constrained weights.
* GRADIENT_DESCENT: Perform standard gradient descent with only lower bound (>=0) constraints on the weights.
*/
public static enum GDExtension {
public static enum SymbolicWeightUpdate {
MIRROR_DESCENT,
PROJECTED_GRADIENT,
NONE
GRADIENT_DESCENT
}

protected GDExtension gdExtension;
protected boolean symbolicWeightLearning;
protected SymbolicWeightUpdate symbolicWeightUpdate;

protected Map<WeightedRule, Integer> ruleIndexMap;

protected float[] weightGradient;
Expand Down Expand Up @@ -124,7 +126,8 @@ public GradientDescent(List<Rule> rules, Database trainTargetDatabase, Database
Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);

gdExtension = GDExtension.valueOf(Options.WLA_GRADIENT_DESCENT_EXTENSION.getString().toUpperCase());
symbolicWeightLearning = Options.WLA_GRADIENT_DESCENT_SYMBOLIC_LEARNING.getBoolean();
symbolicWeightUpdate = SymbolicWeightUpdate.valueOf(Options.WLA_GRADIENT_DESCENT_EXTENSION.getString().toUpperCase());

ruleIndexMap = new HashMap<WeightedRule, Integer>(mutableRules.size());
for (int i = 0; i < mutableRules.size(); i++) {
Expand Down Expand Up @@ -286,7 +289,7 @@ protected void initializeGradients() {
}

protected void initForLearning() {
switch (gdExtension) {
switch (symbolicWeightUpdate) {
case MIRROR_DESCENT:
case PROJECTED_GRADIENT:
// Initialize weights to be centered on the unit simplex.
Expand Down Expand Up @@ -348,10 +351,12 @@ protected void doLearn() {

DeepPredicate.trainModeAllDeepPredicates();

int numBatches = 0;
float averageBatchObjective = 0.0f;
batchGenerator.permuteBatchOrdering();
int batchId = batchGenerator.epochStart();
while (!batchGenerator.isEpochComplete()) {
long batchStart = System.currentTimeMillis();
numBatches++;

setBatch(batchId);
DeepPredicate.predictAllDeepPredicates();
Expand All @@ -364,23 +369,23 @@ protected void doLearn() {
clipWeightGradient();
}

float batchObjective = computeTotalLoss();
averageBatchObjective += computeTotalLoss();

gradientStep(epoch);

if (epoch % trainingStopComputePeriod == 0) {
epochDeepAtomValueMovement += DeepPredicate.predictAllDeepPredicates();
}

long batchEnd = System.currentTimeMillis();

log.trace("Batch: {} -- Weight Learning Objective: {}, Gradient Magnitude: {}, Iteration Time: {}",
batchId, batchObjective, computeGradientNorm(), (batchEnd - batchStart));

batchId = batchGenerator.nextBatch();
}
batchGenerator.epochEnd();

if (numBatches > 0) {
// Average the objective across batches.
averageBatchObjective /= numBatches;
}

setFullModel();

long end = System.currentTimeMillis();
Expand All @@ -396,7 +401,7 @@ protected void doLearn() {
setFullModel();

epoch++;
log.trace("Epoch: {} -- Iteration Time: {}", epoch, (end - start));
log.trace("Epoch: {}, Weight Learning Objective: {}, Iteration Time: {}", epoch, averageBatchObjective, (end - start));
}
log.info("Gradient Descent Weight Learning Finished.");

Expand Down Expand Up @@ -655,9 +660,13 @@ protected void internalParameterGradientStep(int epoch) {
* Does nothing when symbolic weight learning is disabled. This method is void and returns no value.
*/
protected void weightGradientStep(int epoch) {
if (!symbolicWeightLearning) {
return;
}

float stepSize = computeStepSize(epoch);

switch (gdExtension) {
switch (symbolicWeightUpdate) {
case MIRROR_DESCENT:
float exponentiatedGradientSum = 0.0f;
for (int j = 0; j < mutableRules.size(); j++) {
Expand Down Expand Up @@ -713,7 +722,7 @@ protected float computeStepSize(int epoch) {
protected float computeGradientNorm() {
float norm = 0.0f;

switch (gdExtension) {
switch (symbolicWeightUpdate) {
case MIRROR_DESCENT:
norm = computeMirrorDescentNorm();
break;
Expand Down Expand Up @@ -931,8 +940,6 @@ protected float computeTotalLoss() {
float learningLoss = computeLearningLoss();
float regularization = computeRegularization();

log.trace("Learning Loss: {}, Regularization: {}", learningLoss, regularization);

return learningLoss + regularization;
}

Expand Down Expand Up @@ -963,6 +970,10 @@ protected float computeRegularization() {
/**
 * Reset the weight gradient buffer to zero and then, only when symbolic weight learning
 * is enabled, call addLearningLossWeightGradient() followed by addRegularizationWeightGradient().
 * When symbolic weight learning is disabled, the buffer is left filled with zeros.
 */
protected void computeTotalWeightGradient() {
    // Always start from a zeroed gradient, regardless of the learning mode.
    Arrays.fill(weightGradient, 0.0f);

    if (symbolicWeightLearning) {
        addLearningLossWeightGradient();
        addRegularizationWeightGradient();
    }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,10 +611,6 @@ protected float computeLearningLoss() {
float objectiveDifference = augmentedInferenceEnergy - mapEnergy;
float constraintViolation = Math.max(0.0f, objectiveDifference - constraintRelaxationConstant);
float supervisedLoss = computeSupervisedLoss();
float totalProxValue = computeTotalProxValue(new float[proxRuleObservedAtoms.length]);

log.trace("Prox Loss: {}, Objective difference: {}, Constraint Violation: {}, Supervised Loss: {}, Energy Loss: {}.",
totalProxValue, objectiveDifference, constraintViolation, supervisedLoss, latentInferenceEnergy);

return (squaredPenaltyCoefficient / 2.0f) * (float)Math.pow(constraintViolation, 2.0f)
+ linearPenaltyCoefficient * (constraintViolation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,10 +542,6 @@ protected float computeLearningLoss() {
float objectiveDifference = augmentedInferenceEnergy - mapEnergy;
float constraintViolation = Math.max(0.0f, objectiveDifference - constraintRelaxationConstant);
float supervisedLoss = computeSupervisedLoss();
float totalProxValue = computeTotalProxValue(new float[proxRuleObservedAtoms.length]);

log.trace("Prox Loss: {}, Objective difference: {}, Constraint Violation: {}, Supervised Loss: {}, Energy Loss: {}.",
totalProxValue, objectiveDifference, constraintViolation, supervisedLoss, latentInferenceEnergy);

return (squaredPenaltyCoefficient / 2.0f) * (float)Math.pow(constraintViolation, 2.0f)
+ linearPenaltyCoefficient * (constraintViolation)
Expand Down
Loading

0 comments on commit 26d79c7

Please sign in to comment.