bug fix for gradient norm
zhongkaifu committed Dec 1, 2023
1 parent ce3f1d1 commit 0c87e77
Showing 8 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/Image2Seq.cs
@@ -116,7 +116,7 @@ private bool CreateTrainableParameters(IModel model)
isTrainable: true, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType: elementType), DeviceIds);

m_posEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { 1024, model.HiddenDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "PositionalEmbedding",
- learningRateFactor: m_options.EncoderStartLearningRateFactor, isTrainable: true, dtype: elementType), raDeviceIds.ToArray());
+ learningRateFactor: m_options.EncoderStartLearningRateFactor, isTrainable: true, dtype: elementType, needGradientNorm: false), raDeviceIds.ToArray());

// m_cls = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { 1, model.HiddenDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "CLS", learningRateFactor: m_options.EncoderStartLearningRateFactor,
// isTrainable: true, dtype: elementType), raDeviceIds.ToArray());
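Note: this commit threads a new per-tensor flag, needGradientNorm, through the embedding constructors. Tensors that keep the default behavior have their gradients normalized by the token count; tensors created with needGradientNorm: false (positional, source, and segment embeddings in this and the files below) are normalized by the batch size instead. A minimal sketch of the divisor choice, using the names from the optimizer diffs further down (an illustrative helper, not part of the repository):

    // Divisor implied by this commit: token count for ordinary weights,
    // batch size for tensors that opt out of token-level gradient norm.
    static int NormFactor(bool needGradientNorm, int tokenSize, int batchSize)
        => needGradientNorm ? tokenSize : batchSize;

The same needGradientNorm: false argument is applied to SrcEmbeddings in SeqClassification.cs and SeqLabel.cs next.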
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/SeqClassification.cs
@@ -84,7 +84,7 @@ private bool CreateTrainableParameters(IModel model)
Logger.WriteLine(Logger.Level.debug, $"Creating embeddings. Shape = '({model.SrcVocab.Count} ,{model.EncoderEmbeddingDim})'");

m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { model.SrcVocab.Count, model.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SrcEmbeddings",
- isTrainable: m_options.IsEmbeddingTrainable), DeviceIds);
+ isTrainable: m_options.IsEmbeddingTrainable, needGradientNorm: false), DeviceIds);

return true;
}
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/SeqLabel.cs
@@ -100,7 +100,7 @@ private bool CreateTrainableParameters(IModel model)
m_ffLayer = new MultiProcessorNetworkWrapper<FeedForwardLayer>(new FeedForwardLayer("FeedForward", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(), isTrainable: true), DeviceIds);

m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { model.SrcVocab.Count, model.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "SrcEmbeddings",
- isTrainable: true), DeviceIds);
+ isTrainable: true, needGradientNorm: false), DeviceIds);
(m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, m_options.MaxSentLength, model, createAPE: (model.PEType == PositionEmbeddingEnums.APE));

return true;
10 changes: 5 additions & 5 deletions Seq2SeqSharp/Optimizer/AdamOptimizer.cs
@@ -48,7 +48,7 @@ public AdamOptimizer(float clipval, float beta1 = 0.9f, float beta2 = 0.98f, boo
m_checkTensorCorrupted = checkTensorCorrupted;
}

- public void UpdateWeights(List<IWeightTensor> model, int tokenSize, float step_size, float regc, int iter)
+ public void UpdateWeights(List<IWeightTensor> model, int batchSize, int tokenSize, float step_size, float regc, int iter)
{
Dictionary<int, List<IWeightTensor>> id2Models = new Dictionary<int, List<IWeightTensor>>();
Dictionary<string, IWeightTensor> name2tensor = new Dictionary<string, IWeightTensor>();
@@ -94,13 +94,13 @@ public void UpdateWeights(List<IWeightTensor> model, int tokenSize, float step_s
foreach (IWeightTensor item in kv.Value)
{
WeightTensor m = item as WeightTensor;
- UpdateWeightsTensor(m, m.NeedGradient ? tokenSize : 1, step_size * m.LearningRateFactor, regc, iter);
+ UpdateWeightsTensor(m, m.NeedGradientNorm ? tokenSize : batchSize, step_size * m.LearningRateFactor, regc, iter);
}
});
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
- private void UpdateWeightsTensor(WeightTensor m, int tokenSize, float step_size, float regc, int iter)
+ private void UpdateWeightsTensor(WeightTensor m, int normFactor, float step_size, float regc, int iter)
{
try
{
@@ -113,7 +113,7 @@ private void UpdateWeightsTensor(WeightTensor m, int tokenSize, float step_size,
Ops.Copy(t2, m_cacheName2M[m.Name]);


- Ops.Adam(m.TWeight, m.TGradient, t1, t2, tokenSize, step_size, m_clipval, regc, m_beta2, m_beta1, iter, m_smoothEps);
+ Ops.Adam(m.TWeight, m.TGradient, t1, t2, normFactor, step_size, m_clipval, regc, m_beta2, m_beta1, iter, m_smoothEps);

Ops.Copy(m_cacheName2V[m.Name], t1);
t1.Dispose();
@@ -123,7 +123,7 @@ private void UpdateWeightsTensor(WeightTensor m, int tokenSize, float step_size,
}
else
{
- Ops.Adam(m.TWeight, m.TGradient, m_cacheName2V[m.Name], m_cacheName2M[m.Name], tokenSize, step_size, m_clipval, regc, m_beta2, m_beta1, iter, m_smoothEps);
+ Ops.Adam(m.TWeight, m.TGradient, m_cacheName2V[m.Name], m_cacheName2M[m.Name], normFactor, step_size, m_clipval, regc, m_beta2, m_beta1, iter, m_smoothEps);
}

}
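The substantive fix in AdamOptimizer.cs is the divisor passed to Ops.Adam: the old code keyed on m.NeedGradient and fell back to a constant divisor of 1, while the new code keys on m.NeedGradientNorm and falls back to the batch size, with the parameter renamed to normFactor accordingly. The diff does not show Ops.Adam's kernel, so the following is only a hedged, array-based sketch of what an Adam step with such a normalization factor conventionally does (names and operation order are assumptions, kept close to the textbook update):

    // Hedged sketch of an Adam step with a per-tensor normalization factor.
    static void AdamStep(float[] w, float[] grad, float[] v, float[] m,
                         int normFactor, float stepSize, float clipval,
                         float beta1, float beta2, int iter, float eps)
    {
        float biasCorr1 = 1f - MathF.Pow(beta1, iter);   // bias correction, 1st moment
        float biasCorr2 = 1f - MathF.Pow(beta2, iter);   // bias correction, 2nd moment
        for (int i = 0; i < w.Length; i++)
        {
            float g = grad[i] / normFactor;              // the divisor this commit fixes
            g = Math.Clamp(g, -clipval, clipval);        // gradient clipping
            m[i] = beta1 * m[i] + (1f - beta1) * g;      // 1st moment estimate
            v[i] = beta2 * v[i] + (1f - beta2) * g * g;  // 2nd moment estimate
            w[i] -= stepSize * (m[i] / biasCorr1) / (MathF.Sqrt(v[i] / biasCorr2) + eps);
        }
    }

The point of the fix is the first line of the loop: a tensor flagged needGradientNorm: false is divided by the batch size rather than the much larger token count, and, unlike the old code, never by a constant 1.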
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Optimizer/IOptimizer.cs
@@ -7,6 +7,6 @@ namespace Seq2SeqSharp.Optimizer
{
public interface IOptimizer
{
- void UpdateWeights(List<IWeightTensor> model, int tokenSize, float step_size, float regc, int iter);
+ void UpdateWeights(List<IWeightTensor> model, int batchSize, int tokenSize, float step_size, float regc, int iter);
}
}
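With the widened interface, every call site must now supply both counts. A hedged usage sketch (values are illustrative; models would come from the framework, as in the BaseSeq2SeqFramework.cs diff below):

    // Illustrative call, mirroring the call site in TrainOneEpoch further down.
    List<IWeightTensor> models = GetParametersFromDefaultDevice();
    IOptimizer solver = new AdamOptimizer(clipval: 5.0f);
    solver.UpdateWeights(models, batchSize: 32, tokenSize: 1024,
                         step_size: 1e-4f, regc: 1e-6f, iter: 1);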
8 changes: 4 additions & 4 deletions Seq2SeqSharp/Optimizer/RMSPropOptimizer.cs
@@ -27,7 +27,7 @@ public RMSPropOptimizer(float clipval, float decayRate = 0.999f)
m_decayRate = decayRate;
}

- public void UpdateWeights(List<IWeightTensor> model, int tokenSize, float step_size, float regc, int iter)
+ public void UpdateWeights(List<IWeightTensor> model, int batchSize, int tokenSize, float step_size, float regc, int iter)
{
Dictionary<int, List<IWeightTensor>> id2Models = new Dictionary<int, List<IWeightTensor>>();
Dictionary<string, IWeightTensor> name2tensor = new Dictionary<string, IWeightTensor>();
@@ -65,17 +65,17 @@ public void UpdateWeights(List<IWeightTensor> model, int tokenSize, float step_s
foreach (IWeightTensor item in kv.Value)
{
WeightTensor m = item as WeightTensor;
- UpdateWeightsTensor(m, m.NeedGradient ? tokenSize : 1, step_size, regc, iter);
+ UpdateWeightsTensor(m, m.NeedGradientNorm ? tokenSize : batchSize, step_size, regc, iter);
}
});
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
- private void UpdateWeightsTensor(WeightTensor m, int tokenSize, float step_size, float regc, int iter)
+ private void UpdateWeightsTensor(WeightTensor m, int normFactor, float step_size, float regc, int iter)
{
try
{
- Ops.RMSProp(m.TWeight, m.TGradient, m_cacheName2V[m.Name], tokenSize, step_size, m_clipval, regc, m_decayRate, m_smoothEps);
+ Ops.RMSProp(m.TWeight, m.TGradient, m_cacheName2V[m.Name], normFactor, step_size, m_clipval, regc, m_decayRate, m_smoothEps);
}
catch (Exception err)
{
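RMSPropOptimizer.cs gets the same treatment: the divisor for Ops.RMSProp switches from m.NeedGradient ? tokenSize : 1 to m.NeedGradientNorm ? tokenSize : batchSize. As with Adam, the kernel itself is not in the diff; a hedged array-based sketch of an RMSProp step with this factor:

    // Hedged sketch only; Ops.RMSProp's actual kernel may differ in detail.
    static void RMSPropStep(float[] w, float[] grad, float[] v,
                            int normFactor, float stepSize, float clipval,
                            float decayRate, float eps)
    {
        for (int i = 0; i < w.Length; i++)
        {
            float g = Math.Clamp(grad[i] / normFactor, -clipval, clipval);
            v[i] = decayRate * v[i] + (1f - decayRate) * g * g; // running mean of squared grads
            w[i] -= stepSize * g / MathF.Sqrt(v[i] + eps);      // scale step by RMS
        }
    }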
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
@@ -494,7 +494,7 @@ internal void TrainOneEpoch(int ep, ICorpus<IPairBatch> trainCorpus, ICorpus<IPa
List<IWeightTensor> models = GetParametersFromDefaultDevice();

m_weightsUpdateCount++;
- solver.UpdateWeights(models, Math.Max(sWordCnt, tWordCnt), lr, m_regc, m_weightsUpdateCount);
+ solver.UpdateWeights(models, processedLine, Math.Max(sWordCnt, tWordCnt), lr, m_regc, m_weightsUpdateCount);

costInTotal += cost;
updatesInOneEpoch++;
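At the call site, processedLine (which, judging by its name, counts sentence pairs processed in the batch; the diff itself does not confirm this) now supplies the batch size, while Math.Max(sWordCnt, tWordCnt) remains the token count. With illustrative numbers:

    // 32 sentence pairs, 1000 source tokens, 1200 target tokens:
    int batchSize = 32;                    // processedLine
    int tokenSize = Math.Max(1000, 1200);  // = 1200
    // ordinary weights divide their gradients by 1200;
    // tensors built with needGradientNorm: false divide by 32.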
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Utils/Misc.cs
@@ -144,7 +144,7 @@ public static (MultiProcessorNetworkWrapper<IWeightTensor>, MultiProcessorNetwor
if (modelMetaData.EnableSegmentEmbeddings)
{
segmentEmbeddings = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { modelMetaData.MaxSegmentNum, modelMetaData.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "SegmentEmbedding",
- isTrainable: isTrainable, dtype: elementType), raDeviceIds.ToArray());
+ isTrainable: isTrainable, dtype: elementType, needGradientNorm: false), raDeviceIds.ToArray());
}
}

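All of the tensors that opt out in this commit are embedding tables (positional, source, and segment embeddings). A plausible motivation, not stated in the commit message, is that an embedding row only accumulates gradient from the tokens that actually looked it up, so dividing every row by the full token count shrinks sparse rows' updates more than intended; normalizing by batch size is the alternative this commit chooses. An illustrative sketch of that accumulation pattern:

    // Each occurrence of a token id contributes one gradient to its row;
    // rows never looked up stay zero. Purely illustrative code.
    static void AccumulateEmbeddingGrad(float[,] grad, int[] tokenIds, float[,] upstream)
    {
        int dim = grad.GetLength(1);
        for (int t = 0; t < tokenIds.Length; t++)
            for (int d = 0; d < dim; d++)
                grad[tokenIds[t], d] += upstream[t, d];
    }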
