From ab5d7088d68d68d6131c694eb8a5948f8c25ede6 Mon Sep 17 00:00:00 2001
From: Zhongkai Fu
Date: Thu, 30 Nov 2023 22:32:24 -0800
Subject: [PATCH] Update gradient normalization in optimizer

---
 Seq2SeqSharp/Corpus/MonoCorpus.cs            | 30 +++++++++++---------
 Seq2SeqSharp/Corpus/ParallelCorpus.cs        | 30 +++++++++++---------
 Seq2SeqSharp/Corpus/VisionTextCorpusBatch.cs |  2 +-
 Seq2SeqSharp/Optimizer/AdamOptimizer.cs      | 10 +++----
 Seq2SeqSharp/Optimizer/IOptimizer.cs         |  2 +-
 Seq2SeqSharp/Optimizer/RMSPropOptimizer.cs   |  8 +++---
 Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs   | 10 +++----
 Seq2SeqSharp/Tools/WeightTensor.cs           |  9 ++++--
 8 files changed, 57 insertions(+), 44 deletions(-)

diff --git a/Seq2SeqSharp/Corpus/MonoCorpus.cs b/Seq2SeqSharp/Corpus/MonoCorpus.cs
index 23bbe7e6..648759ab 100644
--- a/Seq2SeqSharp/Corpus/MonoCorpus.cs
+++ b/Seq2SeqSharp/Corpus/MonoCorpus.cs
@@ -226,19 +226,23 @@ public List<Dictionary<string, long>> CountTokenFreqs()
 
         public long GetNextLength(Dictionary<long, long> len2counts, long totalRecordsNum)
         {
-            long rndItems = rnd.NextInt64(totalRecordsNum);
-            long totalItems = 0;
-            foreach (var pair in len2counts)
-            {
-                long length = pair.Value;
-                if (totalItems <= rndItems && totalItems + length >= rndItems)
-                {
-                    return pair.Key;
-                }
-                totalItems += length;
-            }
-
-            return -1;
+            long[] keys = len2counts.Keys.ToArray();
+            int rndIdx = rnd.Next(keys.Length);
+            return keys[rndIdx];
+
+            //long rndItems = rnd.NextInt64(totalRecordsNum);
+            //long totalItems = 0;
+            //foreach (var pair in len2counts)
+            //{
+            //    long length = pair.Value;
+            //    if (totalItems <= rndItems && totalItems + length >= rndItems)
+            //    {
+            //        return pair.Key;
+            //    }
+            //    totalItems += length;
+            //}
+
+            //return -1;
         }
 
         public void PrepareDataSet()
diff --git a/Seq2SeqSharp/Corpus/ParallelCorpus.cs b/Seq2SeqSharp/Corpus/ParallelCorpus.cs
index 554a572d..70dabe03 100644
--- a/Seq2SeqSharp/Corpus/ParallelCorpus.cs
+++ b/Seq2SeqSharp/Corpus/ParallelCorpus.cs
@@ -315,19 +315,23 @@ public interface ICorpus<T> : IEnumerable<T>
 
         public long GetNextLength(Dictionary<long, long> len2counts, long totalRecordsNum)
         {
-            long rndItems = rnd.NextInt64(totalRecordsNum);
-            long totalItems = 0;
-            foreach (var pair in len2counts)
-            {
-                long length = pair.Value;
-                if (totalItems <= rndItems && totalItems + length >= rndItems)
-                {
-                    return pair.Key;
-                }
-                totalItems += length;
-            }
-
-            return -1;
+            long[] keys = len2counts.Keys.ToArray();
+            int rndIdx = rnd.Next(keys.Length);
+            return keys[rndIdx];
+
+            //long rndItems = rnd.NextInt64(totalRecordsNum);
+            //long totalItems = 0;
+            //foreach (var pair in len2counts)
+            //{
+            //    long length = pair.Value;
+            //    if (totalItems <= rndItems && totalItems + length >= rndItems)
+            //    {
+            //        return pair.Key;
+            //    }
+            //    totalItems += length;
+            //}
+
+            //return -1;
         }
 
         public void PrepareDataSet()
diff --git a/Seq2SeqSharp/Corpus/VisionTextCorpusBatch.cs b/Seq2SeqSharp/Corpus/VisionTextCorpusBatch.cs
index affe0cab..96a7075b 100644
--- a/Seq2SeqSharp/Corpus/VisionTextCorpusBatch.cs
+++ b/Seq2SeqSharp/Corpus/VisionTextCorpusBatch.cs
@@ -48,7 +48,7 @@ public class VisionTextCorpusBatch : IVisionSntPairBatch
 
         public int BatchSize => SrcBatchPaths.Count;
 
-        public int SrcTokenCount { get; set; } = 768;
+        public int SrcTokenCount { get; set; } = 256;
 
         public int TgtTokenCount { get; set; }
         public IPairBatch CloneSrcTokens()
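
The corpus change above swaps weighted bucket sampling for uniform sampling over distinct lengths: the old GetNextLength drew a random record index and walked the cumulative counts, so a length bucket was chosen with probability proportional to how many records it held, while the new code picks every distinct length with equal probability, so sparsely populated buckets (for example, very long sequences) come up as often as dense ones. Note that totalRecordsNum is now unused, although it stays in the signature. The VisionTextCorpusBatch hunk independently lowers the default SrcTokenCount from 768 to 256. A self-contained sketch of the two sampling behaviors (class and method names are illustrative, not part of the library; Random.NextInt64 requires .NET 6 or later):

    using System;
    using System.Collections.Generic;
    using System.Linq;

    // Contrast of the two sampling schemes in GetNextLength. Keys are
    // sequence lengths; values are record counts per length bucket.
    static class LengthSamplingSketch
    {
        static readonly Random rnd = new Random();

        // New behavior: uniform over distinct lengths, regardless of counts.
        public static long Uniform(Dictionary<long, long> len2counts)
        {
            long[] keys = len2counts.Keys.ToArray();
            return keys[rnd.Next(keys.Length)];
        }

        // Old behavior: uniform over records, so a bucket is chosen with
        // probability proportional to its record count.
        public static long WeightedByCount(Dictionary<long, long> len2counts, long totalRecordsNum)
        {
            long rndItems = rnd.NextInt64(totalRecordsNum);
            long totalItems = 0;
            foreach (var pair in len2counts)
            {
                if (totalItems <= rndItems && totalItems + pair.Value >= rndItems)
                {
                    return pair.Key;
                }
                totalItems += pair.Value;
            }
            return -1; // not reached when the counts sum to totalRecordsNum
        }
    }
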
diff --git a/Seq2SeqSharp/Optimizer/AdamOptimizer.cs b/Seq2SeqSharp/Optimizer/AdamOptimizer.cs
index 72d7e950..7f74e6fa 100644
--- a/Seq2SeqSharp/Optimizer/AdamOptimizer.cs
+++ b/Seq2SeqSharp/Optimizer/AdamOptimizer.cs
@@ -48,7 +48,7 @@ public AdamOptimizer(float clipval, float beta1 = 0.9f, float beta2 = 0.98f, boo
             m_checkTensorCorrupted = checkTensorCorrupted;
         }
 
-        public void UpdateWeights(List<IWeightTensor> model, int batchSize, float step_size, float regc, int iter)
+        public void UpdateWeights(List<IWeightTensor> model, int tokenSize, float step_size, float regc, int iter)
         {
             Dictionary<int, List<IWeightTensor>> id2Models = new Dictionary<int, List<IWeightTensor>>();
             Dictionary<string, IWeightTensor> name2tensor = new Dictionary<string, IWeightTensor>();
@@ -94,13 +94,13 @@ public void UpdateWeights(List<IWeightTensor> model, int batchSize, float step_s
                 foreach (IWeightTensor item in kv.Value)
                 {
                     WeightTensor m = item as WeightTensor;
-                    UpdateWeightsTensor(m, batchSize, step_size * m.LearningRateFactor, regc, iter);
+                    UpdateWeightsTensor(m, m.NeedGradientNorm ? tokenSize : 1, step_size * m.LearningRateFactor, regc, iter);
                 }
             });
         }
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private void UpdateWeightsTensor(WeightTensor m, int batchSize, float step_size, float regc, int iter)
+        private void UpdateWeightsTensor(WeightTensor m, int tokenSize, float step_size, float regc, int iter)
         {
             try
             {
@@ -113,7 +113,7 @@ private void UpdateWeightsTensor(WeightTensor m, int batchSize, float step_size,
 
                 Ops.Copy(t2, m_cacheName2M[m.Name]);
 
-                Ops.Adam(m.TWeight, m.TGradient, t1, t2, batchSize, step_size, m_clipval, regc, m_beta2, m_beta1, iter, m_smoothEps);
+                Ops.Adam(m.TWeight, m.TGradient, t1, t2, tokenSize, step_size, m_clipval, regc, m_beta2, m_beta1, iter, m_smoothEps);
 
                 Ops.Copy(m_cacheName2V[m.Name], t1);
                 t1.Dispose();
@@ -123,7 +123,7 @@ private void UpdateWeightsTensor(WeightTensor m, int batchSize, float step_size,
             }
             else
             {
-                Ops.Adam(m.TWeight, m.TGradient, m_cacheName2V[m.Name], m_cacheName2M[m.Name], batchSize, step_size, m_clipval, regc, m_beta2, m_beta1, iter, m_smoothEps);
+                Ops.Adam(m.TWeight, m.TGradient, m_cacheName2V[m.Name], m_cacheName2M[m.Name], tokenSize, step_size, m_clipval, regc, m_beta2, m_beta1, iter, m_smoothEps);
             }
         }
 
diff --git a/Seq2SeqSharp/Optimizer/IOptimizer.cs b/Seq2SeqSharp/Optimizer/IOptimizer.cs
index 49d93928..06d55f1b 100644
--- a/Seq2SeqSharp/Optimizer/IOptimizer.cs
+++ b/Seq2SeqSharp/Optimizer/IOptimizer.cs
@@ -7,6 +7,6 @@ namespace Seq2SeqSharp.Optimizer
 {
     public interface IOptimizer
     {
-        void UpdateWeights(List<IWeightTensor> model, int batchSize, float step_size, float regc, int iter);
+        void UpdateWeights(List<IWeightTensor> model, int tokenSize, float step_size, float regc, int iter);
     }
 }
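
With this renaming, the normalization denominator that reaches Ops.Adam is a token count rather than a sentence count, and it is applied selectively: tensors whose NeedGradientNorm flag is false (the embedding tables, see the WeightTensor and BaseSeq2SeqFramework hunks below) receive a divisor of 1 and keep their raw gradients. The real kernel lives in TensorSharp's Ops.Adam; the scalar sketch below only illustrates where the divisor plausibly enters, and its ordering (normalize, clip, update moments, bias-corrected step) is an assumption, not a copy of that kernel:

    using System;

    // Scalar illustration of a token-normalized Adam step (assumed ordering,
    // not the actual Ops.Adam kernel).
    static void AdamStepSketch(ref float w, float rawGrad, ref float m, ref float v,
                               int tokenSize, float stepSize, float clipval,
                               float beta1, float beta2, int iter, float eps)
    {
        float g = rawGrad / tokenSize;                  // per-token normalization
        g = Math.Clamp(g, -clipval, clipval);           // gradient clipping
        m = beta1 * m + (1 - beta1) * g;                // first moment estimate
        v = beta2 * v + (1 - beta2) * g * g;            // second moment estimate
        float mHat = m / (1 - MathF.Pow(beta1, iter));  // bias correction
        float vHat = v / (1 - MathF.Pow(beta2, iter));
        w -= stepSize * mHat / (MathF.Sqrt(vHat) + eps);
    }

Dividing by the token count keeps the update magnitude roughly independent of how many tokens are packed into one step, so the same learning-rate schedule behaves consistently across variable-sized batches.
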
diff --git a/Seq2SeqSharp/Optimizer/RMSPropOptimizer.cs b/Seq2SeqSharp/Optimizer/RMSPropOptimizer.cs
index f686353c..93a30675 100644
--- a/Seq2SeqSharp/Optimizer/RMSPropOptimizer.cs
+++ b/Seq2SeqSharp/Optimizer/RMSPropOptimizer.cs
@@ -27,7 +27,7 @@ public RMSPropOptimizer(float clipval, float decayRate = 0.999f)
             m_decayRate = decayRate;
         }
 
-        public void UpdateWeights(List<IWeightTensor> model, int batchSize, float step_size, float regc, int iter)
+        public void UpdateWeights(List<IWeightTensor> model, int tokenSize, float step_size, float regc, int iter)
        {
             Dictionary<int, List<IWeightTensor>> id2Models = new Dictionary<int, List<IWeightTensor>>();
             Dictionary<string, IWeightTensor> name2tensor = new Dictionary<string, IWeightTensor>();
@@ -65,17 +65,17 @@ public void UpdateWeights(List<IWeightTensor> model, int batchSize, float step_s
                 foreach (IWeightTensor item in kv.Value)
                 {
                     WeightTensor m = item as WeightTensor;
-                    UpdateWeightsTensor(m, batchSize, step_size, regc, iter);
+                    UpdateWeightsTensor(m, m.NeedGradientNorm ? tokenSize : 1, step_size, regc, iter);
                 }
             });
         }
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private void UpdateWeightsTensor(WeightTensor m, int batchSize, float step_size, float regc, int iter)
+        private void UpdateWeightsTensor(WeightTensor m, int tokenSize, float step_size, float regc, int iter)
         {
             try
             {
-                Ops.RMSProp(m.TWeight, m.TGradient, m_cacheName2V[m.Name], batchSize, step_size, m_clipval, regc, m_decayRate, m_smoothEps);
+                Ops.RMSProp(m.TWeight, m.TGradient, m_cacheName2V[m.Name], tokenSize, step_size, m_clipval, regc, m_decayRate, m_smoothEps);
             }
             catch (Exception err)
             {
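
RMSProp gets the same gating as Adam. A plausible rationale, inferred from the change rather than stated in the commit: a dense weight accumulates gradient from every token in the step, so dividing by tokenSize keeps its update scale stable, whereas an embedding table only receives gradient on the rows of tokens that actually occurred, and dividing those sparse rows by the full token count would shrink updates for rare tokens. The shared divisor selection, restated as a helper:

    // Divisor selection shared by AdamOptimizer and RMSPropOptimizer after
    // this change. With tokenSize = 4096, a dense weight's gradient is scaled
    // by 1/4096 before clipping, while an embedding tensor created with
    // needGradientNorm: false keeps its raw gradient.
    static int NormalizationDivisor(WeightTensor m, int tokenSize)
    {
        return m.NeedGradientNorm ? tokenSize : 1;
    }
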
diff --git a/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs b/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
index b6067d39..9f14db6b 100644
--- a/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
+++ b/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
@@ -350,7 +350,7 @@ protected T LoadModelRoutine(Func initializeParametersFunc,
                 Logger.WriteLine(Logger.Level.debug, $"Creating shared embeddings for both source side and target side. Shape = '({modelMetaData.SrcVocab.Count} ,{modelMetaData.EncoderEmbeddingDim})'");
 
                 srcEmbeddings = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { modelMetaData.SrcVocab.Count, modelMetaData.EncoderEmbeddingDim },
-                    raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SharedEmbeddings", isTrainable: isSrcEmbeddingTrainable, learningRateFactor: encoderStartLearningRateFactor, dtype: elementType), DeviceIds);
+                    raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SharedEmbeddings", isTrainable: isSrcEmbeddingTrainable, learningRateFactor: encoderStartLearningRateFactor, dtype: elementType, needGradientNorm: false), DeviceIds);
 
                 tgtEmbeddings = null;
             }
@@ -359,12 +359,12 @@ protected T LoadModelRoutine(Func initializeParametersFunc,
                 Logger.WriteLine(Logger.Level.debug, $"Creating embeddings for source side. Shape = '({modelMetaData.SrcVocab.Count} ,{modelMetaData.EncoderEmbeddingDim})'");
 
                 srcEmbeddings = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { modelMetaData.SrcVocab.Count, modelMetaData.EncoderEmbeddingDim },
-                    raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SrcEmbeddings", isTrainable: isSrcEmbeddingTrainable, learningRateFactor: encoderStartLearningRateFactor, dtype: elementType), DeviceIds);
+                    raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SrcEmbeddings", isTrainable: isSrcEmbeddingTrainable, learningRateFactor: encoderStartLearningRateFactor, dtype: elementType, needGradientNorm: false), DeviceIds);
 
                 Logger.WriteLine(Logger.Level.debug, $"Creating embeddings for target side. Shape = '({modelMetaData.TgtVocab.Count} ,{modelMetaData.DecoderEmbeddingDim})'");
 
                 tgtEmbeddings = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { modelMetaData.TgtVocab.Count, modelMetaData.DecoderEmbeddingDim },
-                    raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType), DeviceIds);
+                    raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType, needGradientNorm: false), DeviceIds);
             }
 
             return (srcEmbeddings, tgtEmbeddings);
@@ -375,7 +375,7 @@ internal MultiProcessorNetworkWrapper<IWeightTensor> CreateTgtEmbeddings(IModel
             Logger.WriteLine(Logger.Level.debug, $"Creating embeddings for target side. Shape = '({modelMetaData.TgtVocab.Count} ,{modelMetaData.DecoderEmbeddingDim})'");
 
             var tgtEmbeddings = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { modelMetaData.TgtVocab.Count, modelMetaData.DecoderEmbeddingDim },
-                raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType), DeviceIds);
+                raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType, needGradientNorm: false), DeviceIds);
 
             return tgtEmbeddings;
         }
@@ -494,7 +494,7 @@ internal void TrainOneEpoch(int ep, ICorpus trainCorpus, ICorpus
 
                         List<IWeightTensor> models = GetParametersFromDefaultDevice();
                         m_weightsUpdateCount++;
-                        solver.UpdateWeights(models, sWordCnt + tWordCnt, lr, m_regc, m_weightsUpdateCount);
+                        solver.UpdateWeights(models, Math.Max(sWordCnt, tWordCnt), lr, m_regc, m_weightsUpdateCount);
 
                         costInTotal += cost;
                         updatesInOneEpoch++;
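
Two things change in this file: the embedding tables are now constructed with needGradientNorm: false, so the optimizers leave their sparse gradients unnormalized, and the divisor handed to UpdateWeights switches from the sum of source and target token counts to their maximum. When the two sides are of comparable length the denominator roughly halves, making each update correspondingly larger at the same learning rate. A worked example with illustrative counts (not taken from the commit):

    using System;

    int sWordCnt = 3000, tWordCnt = 2500;
    int oldDivisor = sWordCnt + tWordCnt;            // 5500
    int newDivisor = Math.Max(sWordCnt, tWordCnt);   // 3000
    // For the same raw gradient and learning rate, each normalized update is
    // now 5500.0 / 3000.0, i.e. about 1.83x larger than before this change.
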
diff --git a/Seq2SeqSharp/Tools/WeightTensor.cs b/Seq2SeqSharp/Tools/WeightTensor.cs
index 2afb1272..32cc6e45 100644
--- a/Seq2SeqSharp/Tools/WeightTensor.cs
+++ b/Seq2SeqSharp/Tools/WeightTensor.cs
@@ -68,6 +68,9 @@ public int Columns
 
         private readonly DType m_elementType = DType.Float32;
 
+        public readonly bool NeedGradientNorm = true;
+
+
         public long ElementCount
         {
             get
@@ -158,7 +161,7 @@ public Tensor TGradient
 
         public DType ElementType => m_elementType;
 
-        public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainable = false, RandomInitType initType = RandomInitType.None, bool fanIn = false, bool fanOut = false, float learningRateFactor = 1.0f, IComputeGraph graphToBind = null, bool needGradient = true, DType dtype = DType.Float32)
+        public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainable = false, RandomInitType initType = RandomInitType.None, bool fanIn = false, bool fanOut = false, float learningRateFactor = 1.0f, IComputeGraph graphToBind = null, bool needGradient = true, DType dtype = DType.Float32, bool needGradientNorm = true)
         {
             Name = name;
             DeviceId = deviceId;
@@ -171,6 +174,7 @@ public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainab
             m_fanOut = fanOut;
             m_normType = initType;
             m_elementType= dtype;
+            NeedGradientNorm = needGradientNorm;
 
             if (graphToBind != null)
             {
@@ -205,7 +209,7 @@ public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainab
 
         }
 
-        public WeightTensor(long[] sizes, float c, int deviceId, string name = "", bool isTrainable = false, float learningRateFactor = 1.0f, bool needGradient = true, DType dtype = DType.Float32)
+        public WeightTensor(long[] sizes, float c, int deviceId, string name = "", bool isTrainable = false, float learningRateFactor = 1.0f, bool needGradient = true, DType dtype = DType.Float32, bool needGradientNorm = true)
         {
             Name = name;
             DeviceId = deviceId;
@@ -214,6 +218,7 @@ public WeightTensor(long[] sizes, float c, int deviceId, string name = "", bool
             LearningRateFactor = learningRateFactor;
             Sizes = sizes;
             m_elementType = dtype;
+            NeedGradientNorm= needGradientNorm;
 
             m_allocator = TensorAllocator.Allocator(DeviceId);
             var tensor = new Tensor(m_allocator, m_elementType, Sizes);
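
Finally, a usage sketch of the new constructor flag. The sizes, device id, and the dense tensor's name below are hypothetical, and constructing a WeightTensor assumes TensorAllocator has already been initialized for the target device:

    // An embedding table opts out of per-token gradient normalization;
    // a dense weight keeps the default (normalized) behavior.
    var embedding = new WeightTensor(new long[2] { 32000, 512 }, deviceId: 0,
        name: "SrcEmbeddings", isTrainable: true, initType: RandomInitType.Uniform,
        fanOut: true, needGradientNorm: false);

    var dense = new WeightTensor(new long[2] { 512, 2048 }, deviceId: 0,
        name: "FFN.W1", isTrainable: true, initType: RandomInitType.Uniform);
    // dense.NeedGradientNorm defaults to true, so the optimizer divides its
    // gradient by the token count of each update step.
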