From e1c034da59410402cf5f528372099301ed1804cd Mon Sep 17 00:00:00 2001
From: Zhongkai Fu <fuzhongkai@gmail.com>
Date: Thu, 16 Nov 2023 12:40:06 -0800
Subject: [PATCH] Add StartBatchId option for Seq2SeqConsole

Optimize weights lifecycle: backward closures now capture reference-counted
copies (CopyRef) of only the tensors they need and dispose them as soon as
their gradients have been applied, instead of unbinding whole weight tensors
from the compute graph. StartBatchId lets training skip every batch whose id
is below the given value, e.g. to resume an interrupted epoch.
---
 Seq2SeqSharp/Corpus/ParallelCorpus.cs      |  16 +-
 Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs       |   4 +-
 Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs |   3 +-
 Seq2SeqSharp/Tools/ComputeGraphTensor.cs   | 166 ++++++++++++++-------
 Seq2SeqSharp/Tools/WeightTensor.cs         |  24 +--
 Tools/Seq2SeqConsole/Program.cs            |   2 +-
 6 files changed, 143 insertions(+), 72 deletions(-)

diff --git a/Seq2SeqSharp/Corpus/ParallelCorpus.cs b/Seq2SeqSharp/Corpus/ParallelCorpus.cs
index 9e44e9c7..828dbc02 100644
--- a/Seq2SeqSharp/Corpus/ParallelCorpus.cs
+++ b/Seq2SeqSharp/Corpus/ParallelCorpus.cs
@@ -51,6 +51,7 @@ public interface ICorpus<out T> : IEnumerable<T>
 
         private string m_sortedIndexedDataSetFilePath = "";
         private int m_batchNumInTotal = 0;
+        private int m_startBatchId = 0;
 
         public (List<Dictionary<string, long>>, List<Dictionary<string, long>>) CountTokenFreqs()
         {
@@ -444,6 +445,17 @@ public IEnumerator<T> GetEnumerator()
                         string[] tgtLines = br.ReadString().Split("\n");
                         batchIdx++;
 
+                        if (batchIdx < m_startBatchId)
+                        {
+                            continue;
+                        }
+
+                        if (batchIdx % 10000 == 0)
+                        {
+                            Logger.WriteLine($"Processing batch '{batchIdx}'");
+                        }
+
+
                         T batch;
                         int currentTokenCountsInBatch = 0;
                         for (int i = 0; i < sizeInBatch; i++)
@@ -498,7 +510,7 @@ public ParallelCorpus()
 
         }
 
-        public ParallelCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null)
+        public ParallelCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null, int startBatchId = 0)
         {
             Logger.WriteLine($"Loading parallel corpus from '{corpusFilePath}' for source side '{srcLangName}' and target side '{tgtLangName}' MaxSrcSentLength = '{maxSrcSentLength}', MaxTgtSentLength = '{maxTgtSentLength}', aggregateSrcLengthForShuffle = '{shuffleEnums}', TooLongSequence = '{tooLongSequence}'");
             m_maxTokenSizePerBatch = maxTokenSizePerBatch;
@@ -546,7 +558,7 @@ public ParallelCorpus(string corpusFilePath, string srcLangName, string tgtLangN
                 m_srcFileList.Add(pair.Value);
                 m_tgtFileList.Add(tgtKey2FileName[pair.Key]);
             }
-
+            m_startBatchId = startBatchId;
         }
     }
 }
diff --git a/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs b/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs
index 8a0f3838..f6b01a81 100644
--- a/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs
+++ b/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs
@@ -18,8 +18,8 @@ namespace Seq2SeqSharp.Corpus
 
     public class Seq2SeqCorpus : ParallelCorpus<Seq2SeqCorpusBatch>
     {
-        public Seq2SeqCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null)
-            :base (corpusFilePath, srcLangName, tgtLangName, maxTokenSizePerBatch, maxSrcSentLength, maxTgtSentLength, shuffleEnums: shuffleEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath)
+        public Seq2SeqCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null, int startBatchId = 0)
+            :base (corpusFilePath, srcLangName, tgtLangName, maxTokenSizePerBatch, maxSrcSentLength, maxTgtSentLength, shuffleEnums: shuffleEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath, startBatchId: startBatchId)
         {
 
         }
diff --git a/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs b/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
index 99cbdc97..ca5abd51 100644
--- a/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
+++ b/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
@@ -391,7 +391,8 @@ public void Train(int maxTrainingEpoch, ICorpus<IPairBatch> trainCorpus, ICorpus
                 TrainOneEpoch(i, trainCorpus, validCorpusList, learningRate, optimizer, taskId2metrics, decodingOptions, RunForwardOnSingleDevice);
 
                 // send progress reporting in the form of a percentage value (0-100%)
-                Logger.WriteLine(Logger.Level.info, "", (int)(100 * (i + 1) / maxTrainingEpoch));
+                var finishedEpochPercent = (int)(100 * (i + 1) / maxTrainingEpoch);
+                Logger.WriteLine(Logger.Level.info, $"Finished Epoch Percent: {finishedEpochPercent}%", finishedEpochPercent);
             }
 
             SaveModel(createBackupPrevious: false, suffix: $".{m_weightsUpdateCount}");
diff --git a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs
index cc2df467..4240c917 100644
--- a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs
+++ b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs
@@ -119,17 +119,22 @@ public IWeightTensor Sigmoid(IWeightTensor w)
             if (m_needsBackprop)
             {
+                Tensor resTWeight = null;
+                if (m.NeedGradient)
+                {
+                    resTWeight = res.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     if (m.NeedGradient)
                     {
-                        m.AddSigmoidGradient(res);
+                        m.AddSigmoidGradient(resTWeight, res.TGradient);
+                        resTWeight.Dispose();
                     }
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                res.UnbindFromComputeGraph();
             }
 
             return res;
@@ -214,23 +219,32 @@ public IWeightTensor AddTanh(IWeightTensor w1, IWeightTensor w2)
             Ops.AddTanh(res.TWeight, m1.TWeight, m2.TWeight);
             if (m_needsBackprop)
             {
+                Tensor resTWeight = null;
+                if (m1.NeedGradient || m2.NeedGradient)
+                {
+                    resTWeight = res.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     if (m1.NeedGradient)
                     {
-                        m1.AddTanhGradient(res);
+                        m1.AddTanhGradient(resTWeight, res.TGradient);
                     }
 
                     if (m2.NeedGradient)
                     {
-                        m2.AddTanhGradient(res);
+                        m2.AddTanhGradient(resTWeight, res.TGradient);
+                    }
+
+                    if (m1.NeedGradient || m2.NeedGradient)
+                    {
+                        resTWeight.Dispose();
                     }
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                res.UnbindFromComputeGraph();
             }
 
             return res;
@@ -248,21 +262,32 @@ public IWeightTensor AddTanh(IWeightTensor w1, IWeightTensor w2, IWeightTensor w
             Ops.AddTanh3(res.TWeight, m1.TWeight, m2.TWeight, m3.TWeight);
             if (m_needsBackprop)
             {
+                Tensor resTWeight = null;
+                if (m1.NeedGradient || m2.NeedGradient || m3.NeedGradient)
+                {
+                    resTWeight = res.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     if (m1.NeedGradient)
                     {
-                        m1.AddTanhGradient(res);
+                        m1.AddTanhGradient(resTWeight, res.TGradient);
                     }
 
                     if (m2.NeedGradient)
                     {
-                        m2.AddTanhGradient(res);
+                        m2.AddTanhGradient(resTWeight, res.TGradient);
                     }
 
                     if (m3.NeedGradient)
                     {
-                        m3.AddTanhGradient(res);
+                        m3.AddTanhGradient(resTWeight, res.TGradient);
+                    }
+
+                    if (m1.NeedGradient || m2.NeedGradient || m3.NeedGradient)
+                    {
+                        resTWeight.Dispose();
                     }
 
                     res.Dispose();
@@ -466,55 +491,63 @@ public IWeightTensor EltMulMulAdd(IWeightTensor w1, IWeightTensor w2, IWeightTen
             Ops.MulMulAdd(res.TWeight, m1.TWeight, m2.TWeight, m3.TWeight, m4.TWeight);
             if (m_needsBackprop)
             {
+                Tensor m1TWeight = null;
+                Tensor m2TWeight = null;
+                Tensor m3TWeight = null;
+                Tensor m4TWeight = null;
+
+
+                if (m2.NeedGradient)
+                {
+                    m1TWeight = m1.TWeight.CopyRef();
+                }
+
+                if (m1.NeedGradient)
+                {
+                    m2TWeight = m2.TWeight.CopyRef();
+                }
+
+                if (m4.NeedGradient)
+                {
+                    m3TWeight = m3.TWeight.CopyRef();
+                }
+
+                if (m3.NeedGradient)
+                {
+                    m4TWeight = m4.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     res.ReleaseWeight();
 
                     if (m1.NeedGradient)
                     {
-                        m1.AddMulGradient(m2.TWeight, res.TGradient);
+                        m1.AddMulGradient(m2TWeight, res.TGradient);
+                        m2TWeight.Dispose();
                     }
 
                     if (m2.NeedGradient)
                     {
-                        m2.AddMulGradient(m1.TWeight, res.TGradient);
+                        m2.AddMulGradient(m1TWeight, res.TGradient);
+                        m1TWeight.Dispose();
                     }
 
                     if (m3.NeedGradient)
                     {
-                        m3.AddMulGradient(m4.TWeight, res.TGradient);
+                        m3.AddMulGradient(m4TWeight, res.TGradient);
+                        m4TWeight.Dispose();
                     }
 
                     if (m4.NeedGradient)
                     {
-                        m4.AddMulGradient(m3.TWeight, res.TGradient);
+                        m4.AddMulGradient(m3TWeight, res.TGradient);
+                        m3TWeight.Dispose();
                     }
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                // These tensors' weights will be used during back-propogation, so we unbind them from the computing graph
-
-                if (m2.NeedGradient)
-                {
-                    m1.UnbindFromComputeGraph();
-                }
-
-                if (m1.NeedGradient)
-                {
-                    m2.UnbindFromComputeGraph();
-                }
-
-                if (m4.NeedGradient)
-                {
-                    m3.UnbindFromComputeGraph();
-                }
-
-                if (m3.NeedGradient)
-                {
-                    m4.UnbindFromComputeGraph();
-                }
             }
 
 
@@ -531,26 +564,38 @@ public IWeightTensor EltMul(IWeightTensor w1, IWeightTensor w2)
             Ops.Mul(res.TWeight, m1.TWeight, m2.TWeight);
             if (m_needsBackprop)
             {
+                Tensor m1TWeight = null;
+                Tensor m2TWeight = null;
+
+                if (m2.NeedGradient)
+                {
+                    m1TWeight = m1.TWeight.CopyRef();
+                }
+
+                if (m1.NeedGradient)
+                {
+                    m2TWeight = m2.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     res.ReleaseWeight();
 
                     if (m1.NeedGradient)
                     {
-                        m1.AddMulGradient(m2.TWeight, res.TGradient);
+                        m1.AddMulGradient(m2TWeight, res.TGradient);
+                        m2TWeight.Dispose();
                     }
 
                     if (m2.NeedGradient)
                     {
-                        m2.AddMulGradient(m1.TWeight, res.TGradient);
+                        m2.AddMulGradient(m1TWeight, res.TGradient);
+                        m1TWeight.Dispose();
                     }
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                m1.UnbindFromComputeGraph();
-                m2.UnbindFromComputeGraph();
             }
 
             return res;
@@ -867,11 +912,18 @@ public IWeightTensor Tanh(IWeightTensor w)
             Ops.Tanh(res.TWeight, m.TWeight);
             if (m_needsBackprop)
             {
+                Tensor resTWeight = null;
+                if (m.NeedGradient)
+                {
+                    resTWeight = res.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     if (m.NeedGradient)
                     {
-                        m.AddTanhGradient(res);
+                        m.AddTanhGradient(resTWeight, res.TGradient);
+                        resTWeight.Dispose();
                     }
 
                     res.Dispose();
@@ -1517,6 +1569,12 @@ public IWeightTensor Softmax(IWeightTensor w, bool runGradients = true, bool inP
             Ops.Softmax(res.TWeight, t.TWeight);
             if (m_needsBackprop)
             {
+                Tensor resTWeight = null;
+                if (runGradients && t.NeedGradient)
+                {
+                    resTWeight = res.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     if (runGradients && t.NeedGradient)
@@ -1525,15 +1583,13 @@ void backward()
                         {
                             t.TGradient = res.TGradient.CopyRef();
                         }
-                        t.AddSoftmaxGradient(res, inPlace);
+                        t.AddSoftmaxGradient(resTWeight, res.TGradient, inPlace);
+                        resTWeight.Dispose();
                     }
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                res.UnbindFromComputeGraph();
-
             }
 
             return res;
@@ -2178,21 +2234,22 @@ public IWeightTensor LayerNorm(IWeightTensor src, IWeightTensor alpha, IWeightTe
             {
                 var srcTWeight = srcT.TWeight.CopyRef();
                 var resTWeight = res.TWeight.CopyRef();
+                var alphaTWeight = alphaT.TWeight.CopyRef();
+                var betaTWeight = betaT.TWeight.CopyRef();
 
                 void backward()
                 {
                     if (srcT.NeedGradient || alphaT.NeedGradient || betaT.NeedGradient)
                     {
-                        Ops.LayerNormGrad(srcT.TGradient, alphaT.TGradient, betaT.TGradient, res.TGradient, resTWeight, srcTWeight, alphaT.TWeight, betaT.TWeight, eps);
+                        Ops.LayerNormGrad(srcT.TGradient, alphaT.TGradient, betaT.TGradient, res.TGradient, resTWeight, srcTWeight, alphaTWeight, betaTWeight, eps);
                     }
 
                     srcTWeight.Dispose();
                     resTWeight.Dispose();
+                    alphaTWeight.Dispose();
+                    betaTWeight.Dispose();
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                alphaT.UnbindFromComputeGraph();
-                betaT.UnbindFromComputeGraph();
             }
 
             return res;
@@ -2213,21 +2270,22 @@ public IWeightTensor RMSNorm(IWeightTensor src, IWeightTensor alpha, IWeightTens
             {
                 var srcTWeight = srcT.TWeight.CopyRef();
                 var resTWeight = res.TWeight.CopyRef();
+                var alphaTWeight = alphaT.TWeight.CopyRef();
+                var betaTWeight = betaT.TWeight.CopyRef();
 
                 void backward()
                 {
                     if (srcT.NeedGradient)
                     {
-                        Ops.RMSNormGrad(srcT.TGradient, alphaT.TGradient, betaT.TGradient, res.TGradient, resTWeight, srcTWeight, alphaT.TWeight, betaT.TWeight, eps);
+                        Ops.RMSNormGrad(srcT.TGradient, alphaT.TGradient, betaT.TGradient, res.TGradient, resTWeight, srcTWeight, alphaTWeight, betaTWeight, eps);
                     }
 
                     srcTWeight.Dispose();
                     resTWeight.Dispose();
+                    alphaTWeight.Dispose();
+                    betaTWeight.Dispose();
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                alphaT.UnbindFromComputeGraph();
-                betaT.UnbindFromComputeGraph();
             }
 
             return res;
diff --git a/Seq2SeqSharp/Tools/WeightTensor.cs b/Seq2SeqSharp/Tools/WeightTensor.cs
index 0b8393da..dd87a876 100644
--- a/Seq2SeqSharp/Tools/WeightTensor.cs
+++ b/Seq2SeqSharp/Tools/WeightTensor.cs
@@ -368,19 +368,19 @@ public void PrintWeights()
 //            }
         }
 
-        public void AddSoftmaxGradient(WeightTensor src, bool inPlace = false)
+        public void AddSoftmaxGradient(Tensor srcTWeight, Tensor srcTGradient, bool inPlace = false)
         {
             if (m_TGradient == null)
             {
                 m_allocator = TensorAllocator.Allocator(DeviceId);
-                m_TGradient = new Tensor(m_allocator, src.TGradient.ElementType, Sizes);
-                Ops.SoftmaxGrad(m_TGradient, src.TGradient, src.TWeight, false);
+                m_TGradient = new Tensor(m_allocator, srcTGradient.ElementType, Sizes);
+                Ops.SoftmaxGrad(m_TGradient, srcTGradient, srcTWeight, false);
 
                 m_GradientSetName = "AddSoftmaxGradient";
             }
             else
             {
-                Ops.SoftmaxGrad(m_TGradient, src.TGradient, src.TWeight, !inPlace);
+                Ops.SoftmaxGrad(m_TGradient, srcTGradient, srcTWeight, !inPlace);
             }
         }
@@ -452,36 +452,36 @@ public void AddMulGradient(Tensor w, Tensor g, bool inPlace = false)
             }
         }
 
-        public void AddSigmoidGradient(WeightTensor src)
+        public void AddSigmoidGradient(Tensor srcTWeight, Tensor srcTGradient)
         {
             if (m_TGradient == null)
             {
                 m_allocator = TensorAllocator.Allocator(DeviceId);
-                m_TGradient = new Tensor(m_allocator, src.TGradient.ElementType, Sizes);
-                Ops.SigmoidD(m_TGradient, src.TWeight, src.TGradient);
+                m_TGradient = new Tensor(m_allocator, srcTGradient.ElementType, Sizes);
+                Ops.SigmoidD(m_TGradient, srcTWeight, srcTGradient);
 
                 m_GradientSetName = "AddSigmoidGradient";
             }
             else
             {
-                Ops.AddSigmoidD(m_TGradient, m_TGradient, src.TWeight, src.TGradient);
+                Ops.AddSigmoidD(m_TGradient, m_TGradient, srcTWeight, srcTGradient);
             }
         }
 
-        public void AddTanhGradient(WeightTensor src)
+        public void AddTanhGradient(Tensor srcTWeight, Tensor srcTGradient)
         {
             if (m_TGradient == null)
             {
                 m_allocator = TensorAllocator.Allocator(DeviceId);
-                m_TGradient = new Tensor(m_allocator, src.TGradient.ElementType, Sizes);
+                m_TGradient = new Tensor(m_allocator, srcTGradient.ElementType, Sizes);
 
-                Ops.TanhD(m_TGradient, src.TWeight, src.TGradient);
+                Ops.TanhD(m_TGradient, srcTWeight, srcTGradient);
 
                 m_GradientSetName = "AddTanhGradient";
             }
             else
             {
-                Ops.AddTanhD(m_TGradient, m_TGradient, src.TWeight, src.TGradient);
+                Ops.AddTanhD(m_TGradient, m_TGradient, srcTWeight, srcTGradient);
             }
         }
diff --git a/Tools/Seq2SeqConsole/Program.cs b/Tools/Seq2SeqConsole/Program.cs
index eb378c85..6d2e3558 100644
--- a/Tools/Seq2SeqConsole/Program.cs
+++ b/Tools/Seq2SeqConsole/Program.cs
@@ -64,7 +64,7 @@ private static void Main(string[] args)
             {
                 // Load train corpus
                 var trainCorpus = new Seq2SeqCorpus(corpusFilePath: opts.TrainCorpusPath, srcLangName: opts.SrcLang, tgtLangName: opts.TgtLang, maxTokenSizePerBatch: opts.MaxTokenSizePerBatch,
-                    maxSrcSentLength: opts.MaxSrcSentLength, maxTgtSentLength: opts.MaxTgtSentLength, shuffleEnums: opts.ShuffleType, tooLongSequence: opts.TooLongSequence, indexedFilePath: opts.IndexedCorpusPath);
+                    maxSrcSentLength: opts.MaxSrcSentLength, maxTgtSentLength: opts.MaxTgtSentLength, shuffleEnums: opts.ShuffleType, tooLongSequence: opts.TooLongSequence, indexedFilePath: opts.IndexedCorpusPath, startBatchId: opts.StartBatchId);
 
                 // Load valid corpus
                 var validCorpusList = new List<Seq2SeqCorpus>();
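
Usage note (reviewer commentary, not part of the patch): with StartBatchId
wired through to ParallelCorpus, an interrupted epoch can be resumed by
skipping the batches that were already processed. A minimal sketch; the paths
and sizes below are made up, and it assumes StartBatchId is exposed on the
console options class like the other corpus settings:

    // Hypothetical resume scenario: skip the first 30,000 batches of the
    // indexed data set before training continues. All values illustrative.
    var trainCorpus = new Seq2SeqCorpus(
        corpusFilePath: "./data/train",   // illustrative corpus folder
        srcLangName: "enu",
        tgtLangName: "chs",
        maxTokenSizePerBatch: 5120,
        startBatchId: 30000);             // batches with batchIdx < 30000 are skipped

Since the skip happens inside GetEnumerator() after the sorted/indexed data
set has been built, resuming mid-epoch presumably requires reusing the same
indexed data set (e.g. via indexedFilePath) so that batch ids line up across
runs.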
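
Design note on the weights-lifecycle half of the patch: every hunk in
ComputeGraphTensor.cs follows the same pattern. Instead of keeping whole
WeightTensor objects bound to the compute graph via UnbindFromComputeGraph()
until backward runs, each backward closure captures a reference-counted handle
(CopyRef) to exactly the tensor buffers it will read, and disposes those
handles as soon as the gradient has been accumulated, so the buffers can be
recycled earlier. Distilled from the Sigmoid hunk above, as a sketch of the
shape of the change rather than additional code to apply:

    // Forward pass has just computed res = sigmoid(m).
    if (m_needsBackprop)
    {
        // Capture a handle only if backward will actually read it.
        Tensor resTWeight = null;
        if (m.NeedGradient)
        {
            resTWeight = res.TWeight.CopyRef();   // +1 on the buffer's ref count
        }

        void backward()
        {
            if (m.NeedGradient)
            {
                // The sigmoid derivative needs the forward output, now passed
                // as a raw Tensor instead of the owning WeightTensor.
                m.AddSigmoidGradient(resTWeight, res.TGradient);
                resTWeight.Dispose();             // -1: buffer can be recycled now
            }
            res.Dispose();
        }
        m_backprop.Add(backward);
    }

The same reasoning explains the WeightTensor.cs signature changes:
AddSigmoidGradient, AddTanhGradient and AddSoftmaxGradient now take the raw
weight/gradient Tensors, so callers can hand them these short-lived handles.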