Update gradient normalization in optimizer
zhongkaifu committed Dec 1, 2023
1 parent 5bd2a37 commit ab5d708
Showing 8 changed files with 57 additions and 44 deletions.
30 changes: 17 additions & 13 deletions Seq2SeqSharp/Corpus/MonoCorpus.cs

@@ -226,19 +226,23 @@ public List<Dictionary<string, long>> CountTokenFreqs()
 
         public long GetNextLength(Dictionary<long, long> len2counts, long totalRecordsNum)
         {
-            long rndItems = rnd.NextInt64(totalRecordsNum);
-            long totalItems = 0;
-            foreach (var pair in len2counts)
-            {
-                long length = pair.Value;
-                if (totalItems <= rndItems && totalItems + length >= rndItems)
-                {
-                    return pair.Key;
-                }
-                totalItems += length;
-            }
-
-            return -1;
+            long[] keys = len2counts.Keys.ToArray();
+            int rndIdx = rnd.Next(keys.Length);
+            return keys[rndIdx];
+
+            //long rndItems = rnd.NextInt64(totalRecordsNum);
+            //long totalItems = 0;
+            //foreach (var pair in len2counts)
+            //{
+            //    long length = pair.Value;
+            //    if (totalItems <= rndItems && totalItems + length >= rndItems)
+            //    {
+            //        return pair.Key;
+            //    }
+            //    totalItems += length;
+            //}
+
+            //return -1;
         }
 
         public void PrepareDataSet()
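
The hunk above replaces count-weighted sampling of sequence lengths with uniform sampling over the distinct lengths: previously a length was drawn with probability proportional to how many records have that length, now every distinct length is equally likely. Below is a minimal standalone sketch of the two strategies, written for this note rather than taken from the repository; the len2counts contents are made up.

    using System;
    using System.Collections.Generic;
    using System.Linq;

    class LengthSamplingSketch
    {
        static readonly Random rnd = new Random();

        // Old behavior: a length is drawn in proportion to its record count.
        static long NextLengthWeighted(Dictionary<long, long> len2counts, long totalRecordsNum)
        {
            long rndItems = rnd.NextInt64(totalRecordsNum);
            long totalItems = 0;
            foreach (var pair in len2counts)
            {
                if (totalItems <= rndItems && totalItems + pair.Value >= rndItems)
                    return pair.Key;
                totalItems += pair.Value;
            }
            return -1;
        }

        // New behavior: every distinct length is equally likely,
        // no matter how many records have that length.
        static long NextLengthUniform(Dictionary<long, long> len2counts)
        {
            long[] keys = len2counts.Keys.ToArray();
            return keys[rnd.Next(keys.Length)];
        }

        static void Main()
        {
            var len2counts = new Dictionary<long, long> { { 16, 900 }, { 512, 100 } };
            // Weighted: ~90% of draws return 16. Uniform: each length ~50%.
            Console.WriteLine(NextLengthWeighted(len2counts, 1000));
            Console.WriteLine(NextLengthUniform(len2counts));
        }
    }

ParallelCorpus.cs below receives the identical change.
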
30 changes: 17 additions & 13 deletions Seq2SeqSharp/Corpus/ParallelCorpus.cs

@@ -315,19 +315,23 @@ public interface ICorpus<out T> : IEnumerable<T>
 
         public long GetNextLength(Dictionary<long, long> len2counts, long totalRecordsNum)
         {
-            long rndItems = rnd.NextInt64(totalRecordsNum);
-            long totalItems = 0;
-            foreach (var pair in len2counts)
-            {
-                long length = pair.Value;
-                if (totalItems <= rndItems && totalItems + length >= rndItems)
-                {
-                    return pair.Key;
-                }
-                totalItems += length;
-            }
-
-            return -1;
+            long[] keys = len2counts.Keys.ToArray();
+            int rndIdx = rnd.Next(keys.Length);
+            return keys[rndIdx];
+
+            //long rndItems = rnd.NextInt64(totalRecordsNum);
+            //long totalItems = 0;
+            //foreach (var pair in len2counts)
+            //{
+            //    long length = pair.Value;
+            //    if (totalItems <= rndItems && totalItems + length >= rndItems)
+            //    {
+            //        return pair.Key;
+            //    }
+            //    totalItems += length;
+            //}
+
+            //return -1;
         }
 
         public void PrepareDataSet()
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Corpus/VisionTextCorpusBatch.cs

@@ -48,7 +48,7 @@ public class VisionTextCorpusBatch : IVisionSntPairBatch
 
         public int BatchSize => SrcBatchPaths.Count;
 
-        public int SrcTokenCount { get; set; } = 768;
+        public int SrcTokenCount { get; set; } = 256;
         public int TgtTokenCount { get; set; }
 
         public IPairBatch CloneSrcTokens()
10 changes: 5 additions & 5 deletions Seq2SeqSharp/Optimizer/AdamOptimizer.cs

@@ -48,7 +48,7 @@ public AdamOptimizer(float clipval, float beta1 = 0.9f, float beta2 = 0.98f, boo
             m_checkTensorCorrupted = checkTensorCorrupted;
         }
 
-        public void UpdateWeights(List<IWeightTensor> model, int batchSize, float step_size, float regc, int iter)
+        public void UpdateWeights(List<IWeightTensor> model, int tokenSize, float step_size, float regc, int iter)
        {
            Dictionary<int, List<IWeightTensor>> id2Models = new Dictionary<int, List<IWeightTensor>>();
            Dictionary<string, IWeightTensor> name2tensor = new Dictionary<string, IWeightTensor>();
@@ -94,13 +94,13 @@ public void UpdateWeights(List<IWeightTensor> model, int batchSize, float step_s
                foreach (IWeightTensor item in kv.Value)
                {
                    WeightTensor m = item as WeightTensor;
-                    UpdateWeightsTensor(m, batchSize, step_size * m.LearningRateFactor, regc, iter);
+                    UpdateWeightsTensor(m, m.NeedGradientNorm ? tokenSize : 1, step_size * m.LearningRateFactor, regc, iter);
                }
            });
        }
 
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private void UpdateWeightsTensor(WeightTensor m, int batchSize, float step_size, float regc, int iter)
+        private void UpdateWeightsTensor(WeightTensor m, int tokenSize, float step_size, float regc, int iter)
        {
            try
            {
@@ -113,7 +113,7 @@ private void UpdateWeightsTensor(WeightTensor m, int batchSize, float step_size,
                Ops.Copy(t2, m_cacheName2M[m.Name]);
 
 
-                Ops.Adam(m.TWeight, m.TGradient, t1, t2, batchSize, step_size, m_clipval, regc, m_beta2, m_beta1, iter, m_smoothEps);
+                Ops.Adam(m.TWeight, m.TGradient, t1, t2, tokenSize, step_size, m_clipval, regc, m_beta2, m_beta1, iter, m_smoothEps);
 
                Ops.Copy(m_cacheName2V[m.Name], t1);
                t1.Dispose();
@@ -123,7 +123,7 @@ private void UpdateWeightsTensor(WeightTensor m, int batchSize, float step_size,
            }
            else
            {
-                Ops.Adam(m.TWeight, m.TGradient, m_cacheName2V[m.Name], m_cacheName2M[m.Name], batchSize, step_size, m_clipval, regc, m_beta2, m_beta1, iter, m_smoothEps);
+                Ops.Adam(m.TWeight, m.TGradient, m_cacheName2V[m.Name], m_cacheName2M[m.Name], tokenSize, step_size, m_clipval, regc, m_beta2, m_beta1, iter, m_smoothEps);
            }
 
        }
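
In the optimizer, batchSize is renamed to tokenSize and each tensor now picks its own normalization divisor: tokenSize when the tensor participates in gradient normalization, otherwise 1. The sketch below is a plausible scalar rendering of such an update, under the assumption that Ops.Adam divides the accumulated gradient by the token count before clipping and applying the usual Adam moments; it is an illustration, not the TensorSharp kernel.

    using System;

    static class AdamSketch
    {
        // Hypothetical per-element Adam step with token-count gradient
        // normalization (an assumption about Ops.Adam, not its source).
        public static float Step(float w, float g, ref float v, ref float m,
                                 int tokenSize, float stepSize, float clipval,
                                 float beta1, float beta2, int iter, float eps)
        {
            g /= tokenSize;                        // normalize the summed gradient
            g = Math.Clamp(g, -clipval, clipval);  // gradient clipping
            v = beta1 * v + (1 - beta1) * g;       // first-moment estimate
            m = beta2 * m + (1 - beta2) * g * g;   // second-moment estimate
            float vHat = v / (1 - MathF.Pow(beta1, iter));  // bias correction
            float mHat = m / (1 - MathF.Pow(beta2, iter));
            return w - stepSize * vHat / (MathF.Sqrt(mHat) + eps);
        }
    }

For a tensor constructed with needGradientNorm: false (see WeightTensor.cs below), the divisor collapses to 1 and the accumulated gradient is used at full magnitude.
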
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Optimizer/IOptimizer.cs

@@ -7,6 +7,6 @@ namespace Seq2SeqSharp.Optimizer
 {
     public interface IOptimizer
     {
-        void UpdateWeights(List<IWeightTensor> model, int batchSize, float step_size, float regc, int iter);
+        void UpdateWeights(List<IWeightTensor> model, int tokenSize, float step_size, float regc, int iter);
     }
 }
8 changes: 4 additions & 4 deletions Seq2SeqSharp/Optimizer/RMSPropOptimizer.cs

@@ -27,7 +27,7 @@ public RMSPropOptimizer(float clipval, float decayRate = 0.999f)
             m_decayRate = decayRate;
         }
 
-        public void UpdateWeights(List<IWeightTensor> model, int batchSize, float step_size, float regc, int iter)
+        public void UpdateWeights(List<IWeightTensor> model, int tokenSize, float step_size, float regc, int iter)
        {
            Dictionary<int, List<IWeightTensor>> id2Models = new Dictionary<int, List<IWeightTensor>>();
            Dictionary<string, IWeightTensor> name2tensor = new Dictionary<string, IWeightTensor>();
@@ -65,17 +65,17 @@ public void UpdateWeights(List<IWeightTensor> model, int batchSize, float step_s
                foreach (IWeightTensor item in kv.Value)
                {
                    WeightTensor m = item as WeightTensor;
-                    UpdateWeightsTensor(m, batchSize, step_size, regc, iter);
+                    UpdateWeightsTensor(m, m.NeedGradientNorm ? tokenSize : 1, step_size, regc, iter);
                }
            });
        }
 
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private void UpdateWeightsTensor(WeightTensor m, int batchSize, float step_size, float regc, int iter)
+        private void UpdateWeightsTensor(WeightTensor m, int tokenSize, float step_size, float regc, int iter)
        {
            try
            {
-                Ops.RMSProp(m.TWeight, m.TGradient, m_cacheName2V[m.Name], batchSize, step_size, m_clipval, regc, m_decayRate, m_smoothEps);
+                Ops.RMSProp(m.TWeight, m.TGradient, m_cacheName2V[m.Name], tokenSize, step_size, m_clipval, regc, m_decayRate, m_smoothEps);
            }
            catch (Exception err)
            {
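
RMSProp gets the same divisor selection. For completeness, a hypothetical scalar RMSProp step under the same normalize-then-clip assumption; again an illustration, not the Ops.RMSProp source.

    using System;

    static class RmsPropSketch
    {
        // Hypothetical per-element RMSProp step with token-count
        // gradient normalization.
        public static float Step(float w, float g, ref float cache,
                                 int tokenSize, float stepSize, float clipval,
                                 float decayRate, float eps)
        {
            g /= tokenSize;                        // normalize the summed gradient
            g = Math.Clamp(g, -clipval, clipval);  // gradient clipping
            cache = decayRate * cache + (1 - decayRate) * g * g;
            return w - stepSize * g / (MathF.Sqrt(cache) + eps);
        }
    }
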
10 changes: 5 additions & 5 deletions Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs

@@ -350,7 +350,7 @@ protected T LoadModelRoutine<ProtoBuf_T>(Func<T, bool> initializeParametersFunc,
                Logger.WriteLine(Logger.Level.debug, $"Creating shared embeddings for both source side and target side. Shape = '({modelMetaData.SrcVocab.Count} ,{modelMetaData.EncoderEmbeddingDim})'");
 
                srcEmbeddings = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { modelMetaData.SrcVocab.Count, modelMetaData.EncoderEmbeddingDim },
-                    raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SharedEmbeddings", isTrainable: isSrcEmbeddingTrainable, learningRateFactor: encoderStartLearningRateFactor, dtype: elementType), DeviceIds);
+                    raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SharedEmbeddings", isTrainable: isSrcEmbeddingTrainable, learningRateFactor: encoderStartLearningRateFactor, dtype: elementType, needGradientNorm: false), DeviceIds);
 
                tgtEmbeddings = null;
            }
@@ -359,12 +359,12 @@ protected T LoadModelRoutine<ProtoBuf_T>(Func<T, bool> initializeParametersFunc,
                Logger.WriteLine(Logger.Level.debug, $"Creating embeddings for source side. Shape = '({modelMetaData.SrcVocab.Count} ,{modelMetaData.EncoderEmbeddingDim})'");
 
                srcEmbeddings = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { modelMetaData.SrcVocab.Count, modelMetaData.EncoderEmbeddingDim },
-                    raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SrcEmbeddings", isTrainable: isSrcEmbeddingTrainable, learningRateFactor: encoderStartLearningRateFactor, dtype: elementType), DeviceIds);
+                    raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SrcEmbeddings", isTrainable: isSrcEmbeddingTrainable, learningRateFactor: encoderStartLearningRateFactor, dtype: elementType, needGradientNorm: false), DeviceIds);
 
                Logger.WriteLine(Logger.Level.debug, $"Creating embeddings for target side. Shape = '({modelMetaData.TgtVocab.Count} ,{modelMetaData.DecoderEmbeddingDim})'");
 
                tgtEmbeddings = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { modelMetaData.TgtVocab.Count, modelMetaData.DecoderEmbeddingDim },
-                    raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType), DeviceIds);
+                    raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType, needGradientNorm: false), DeviceIds);
            }
 
            return (srcEmbeddings, tgtEmbeddings);
@@ -375,7 +375,7 @@ internal MultiProcessorNetworkWrapper<IWeightTensor> CreateTgtEmbeddings(IModel
            Logger.WriteLine(Logger.Level.debug, $"Creating embeddings for target side. Shape = '({modelMetaData.TgtVocab.Count} ,{modelMetaData.DecoderEmbeddingDim})'");
 
            var tgtEmbeddings = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { modelMetaData.TgtVocab.Count, modelMetaData.DecoderEmbeddingDim },
-                raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType), DeviceIds);
+                raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType, needGradientNorm: false), DeviceIds);
 
            return tgtEmbeddings;
        }
@@ -494,7 +494,7 @@ internal void TrainOneEpoch(int ep, ICorpus<IPairBatch> trainCorpus, ICorpus<IPa
                        List<IWeightTensor> models = GetParametersFromDefaultDevice();
 
                        m_weightsUpdateCount++;
-                        solver.UpdateWeights(models, sWordCnt + tWordCnt, lr, m_regc, m_weightsUpdateCount);
+                        solver.UpdateWeights(models, Math.Max(sWordCnt, tWordCnt), lr, m_regc, m_weightsUpdateCount);
 
                        costInTotal += cost;
                        updatesInOneEpoch++;
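
Two things change in the framework: embeddings are now constructed with needGradientNorm: false, and the divisor handed to the optimizer switches from the sum of source and target token counts to the larger of the two. A short illustration with made-up counts (top-level C# statements):

    using System;

    int sWordCnt = 1200;   // hypothetical source tokens in the batch
    int tWordCnt = 800;    // hypothetical target tokens in the batch

    Console.WriteLine(sWordCnt + tWordCnt);           // old divisor: 2000
    Console.WriteLine(Math.Max(sWordCnt, tWordCnt));  // new divisor: 1200

Since the sum is never smaller than the maximum, the new divisor is equal or smaller, so normalized per-token updates come out the same size or larger.
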
9 changes: 7 additions & 2 deletions Seq2SeqSharp/Tools/WeightTensor.cs

@@ -68,6 +68,9 @@ public int Columns
 
        private readonly DType m_elementType = DType.Float32;
 
+        public readonly bool NeedGradientNorm = true;
+
+
        public long ElementCount
        {
            get
@@ -158,7 +161,7 @@ public Tensor TGradient
 
        public DType ElementType => m_elementType;
 
-        public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainable = false, RandomInitType initType = RandomInitType.None, bool fanIn = false, bool fanOut = false, float learningRateFactor = 1.0f, IComputeGraph graphToBind = null, bool needGradient = true, DType dtype = DType.Float32)
+        public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainable = false, RandomInitType initType = RandomInitType.None, bool fanIn = false, bool fanOut = false, float learningRateFactor = 1.0f, IComputeGraph graphToBind = null, bool needGradient = true, DType dtype = DType.Float32, bool needGradientNorm = true)
        {
            Name = name;
            DeviceId = deviceId;
@@ -171,6 +174,7 @@ public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainab
            m_fanOut = fanOut;
            m_normType = initType;
            m_elementType= dtype;
+            NeedGradientNorm = needGradientNorm;
 
            if (graphToBind != null)
            {
@@ -205,7 +209,7 @@ public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainab
        }
 
 
-        public WeightTensor(long[] sizes, float c, int deviceId, string name = "", bool isTrainable = false, float learningRateFactor = 1.0f, bool needGradient = true, DType dtype = DType.Float32)
+        public WeightTensor(long[] sizes, float c, int deviceId, string name = "", bool isTrainable = false, float learningRateFactor = 1.0f, bool needGradient = true, DType dtype = DType.Float32, bool needGradientNorm = true)
        {
            Name = name;
            DeviceId = deviceId;
@@ -214,6 +218,7 @@ public WeightTensor(long[] sizes, float c, int deviceId, string name = "", bool
            LearningRateFactor = learningRateFactor;
            Sizes = sizes;
            m_elementType = dtype;
+            NeedGradientNorm= needGradientNorm;
            m_allocator = TensorAllocator.Allocator(DeviceId);
 
            var tensor = new Tensor(m_allocator, m_elementType, Sizes);
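
The new NeedGradientNorm flag defaults to true and is settable through both constructors. A minimal usage sketch, assuming the optimizer consults the flag as in the diffs above; the dimensions, device id, and token count are placeholders:

    long vocabSize = 32000, embeddingDim = 512;  // hypothetical sizes
    int deviceId = 0, tokenSize = 4096;          // hypothetical device / token count

    // Embedding-style tensor opting out of token-count gradient normalization.
    WeightTensor embeddings = new WeightTensor(
        new long[2] { vocabSize, embeddingDim }, deviceId,
        name: "Embeddings", isTrainable: true,
        initType: RandomInitType.Uniform, fanOut: true,
        needGradientNorm: false);

    // In the optimizers above, the flag selects the normalization divisor:
    int divisor = embeddings.NeedGradientNorm ? tokenSize : 1;
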
