From e1c034da59410402cf5f528372099301ed1804cd Mon Sep 17 00:00:00 2001
From: Zhongkai Fu <fuzhongkai@gmail.com>
Date: Thu, 16 Nov 2023 12:40:06 -0800
Subject: [PATCH] Add StartBatchId option for Seq2SeqConsole Optimize weights
 live cycle

---
 Seq2SeqSharp/Corpus/ParallelCorpus.cs      |  16 +-
 Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs       |   4 +-
 Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs |   3 +-
 Seq2SeqSharp/Tools/ComputeGraphTensor.cs   | 166 ++++++++++++++-------
 Seq2SeqSharp/Tools/WeightTensor.cs         |  24 +--
 Tools/Seq2SeqConsole/Program.cs            |   2 +-
 6 files changed, 143 insertions(+), 72 deletions(-)

diff --git a/Seq2SeqSharp/Corpus/ParallelCorpus.cs b/Seq2SeqSharp/Corpus/ParallelCorpus.cs
index 9e44e9c7..828dbc02 100644
--- a/Seq2SeqSharp/Corpus/ParallelCorpus.cs
+++ b/Seq2SeqSharp/Corpus/ParallelCorpus.cs
@@ -51,6 +51,7 @@ public interface ICorpus<out T> : IEnumerable<T>
 
         private string m_sortedIndexedDataSetFilePath = "";
         private int m_batchNumInTotal = 0;
+        private int m_startBatchId = 0;
 
         public (List<Dictionary<string, long>>, List<Dictionary<string, long>>) CountTokenFreqs()
         {
@@ -444,6 +445,17 @@ public IEnumerator<T> GetEnumerator()
                         string[] tgtLines = br.ReadString().Split("\n");
                         batchIdx++;
 
+                        if (batchIdx < m_startBatchId)
+                        {
+                            continue;
+                        }
+
+                        if (batchIdx % 10000 == 0)
+                        {
+                            Logger.WriteLine($"Processing batch '{batchIdx}'");
+                        }
+
+
                         T batch;
                         int currentTokenCountsInBatch = 0;
                         for (int i = 0; i < sizeInBatch; i++)
@@ -498,7 +510,7 @@ public ParallelCorpus()
 
         }
 
-        public ParallelCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null)
+        public ParallelCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null, int startBatchId = 0)
         {
             Logger.WriteLine($"Loading parallel corpus from '{corpusFilePath}' for source side '{srcLangName}' and target side '{tgtLangName}' MaxSrcSentLength = '{maxSrcSentLength}',  MaxTgtSentLength = '{maxTgtSentLength}', aggregateSrcLengthForShuffle = '{shuffleEnums}', TooLongSequence = '{tooLongSequence}'");
             m_maxTokenSizePerBatch = maxTokenSizePerBatch;
@@ -546,7 +558,7 @@ public ParallelCorpus(string corpusFilePath, string srcLangName, string tgtLangN
                 m_srcFileList.Add(pair.Value);
                 m_tgtFileList.Add(tgtKey2FileName[pair.Key]);
             }
-
+            m_startBatchId = startBatchId;
         }
     }
 }
diff --git a/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs b/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs
index 8a0f3838..f6b01a81 100644
--- a/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs
+++ b/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs
@@ -18,8 +18,8 @@ namespace Seq2SeqSharp.Corpus
     public class Seq2SeqCorpus : ParallelCorpus<Seq2SeqCorpusBatch>
     {
 
-        public Seq2SeqCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null)
-            :base (corpusFilePath, srcLangName, tgtLangName, maxTokenSizePerBatch, maxSrcSentLength, maxTgtSentLength, shuffleEnums: shuffleEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath)
+        public Seq2SeqCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null, int startBatchId = 0)
+            :base (corpusFilePath, srcLangName, tgtLangName, maxTokenSizePerBatch, maxSrcSentLength, maxTgtSentLength, shuffleEnums: shuffleEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath, startBatchId: startBatchId)
         {
 
         }
diff --git a/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs b/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
index 99cbdc97..ca5abd51 100644
--- a/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
+++ b/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
@@ -391,7 +391,8 @@ public void Train(int maxTrainingEpoch, ICorpus<IPairBatch> trainCorpus, ICorpus
                 TrainOneEpoch(i, trainCorpus, validCorpusList, learningRate, optimizer, taskId2metrics, decodingOptions, RunForwardOnSingleDevice);
 
                 // send progress reporting in the form of a percentage value (0-100%)
-                Logger.WriteLine(Logger.Level.info, "", (int)(100 * (i + 1) / maxTrainingEpoch));
+                var finishedEpochPercent = (int)(100 * (i + 1) / maxTrainingEpoch);
+                Logger.WriteLine(Logger.Level.info, $"Finished Epoch Percent: {finishedEpochPercent}%", finishedEpochPercent);
             }
 
             SaveModel(createBackupPrevious: false, suffix: $".{m_weightsUpdateCount}");
diff --git a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs
index cc2df467..4240c917 100644
--- a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs
+++ b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs
@@ -119,17 +119,22 @@ public IWeightTensor Sigmoid(IWeightTensor w)
 
             if (m_needsBackprop)
             {
+                Tensor resTWeight = null;
+                if (m.NeedGradient)
+                {
+                    resTWeight = res.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     if (m.NeedGradient)
                     {
-                        m.AddSigmoidGradient(res);
+                        m.AddSigmoidGradient(resTWeight, res.TGradient);
+                        resTWeight.Dispose();
                     }
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                res.UnbindFromComputeGraph();
             }
 
             return res;
@@ -214,23 +219,32 @@ public IWeightTensor AddTanh(IWeightTensor w1, IWeightTensor w2)
             Ops.AddTanh(res.TWeight, m1.TWeight, m2.TWeight);
             if (m_needsBackprop)
             {
+                Tensor resTWeight = null;
+                if (m1.NeedGradient || m2.NeedGradient)
+                {
+                    resTWeight = res.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     if (m1.NeedGradient)
                     {
-                        m1.AddTanhGradient(res);
+                        m1.AddTanhGradient(resTWeight, res.TGradient);
                     }
 
                     if (m2.NeedGradient)
                     {
-                        m2.AddTanhGradient(res);
+                        m2.AddTanhGradient(resTWeight, res.TGradient);
+                    }
+
+                    if (m1.NeedGradient || m2.NeedGradient)
+                    {
+                        resTWeight.Dispose();
                     }
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                res.UnbindFromComputeGraph();
             }
 
             return res;
@@ -248,21 +262,32 @@ public IWeightTensor AddTanh(IWeightTensor w1, IWeightTensor w2, IWeightTensor w
             Ops.AddTanh3(res.TWeight, m1.TWeight, m2.TWeight, m3.TWeight);
             if (m_needsBackprop)
             {
+                Tensor resTWeight = null;
+                if (m1.NeedGradient || m2.NeedGradient || m3.NeedGradient)
+                {
+                    resTWeight = res.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     if (m1.NeedGradient)
                     {
-                        m1.AddTanhGradient(res);
+                        m1.AddTanhGradient(resTWeight, res.TGradient);
                     }
 
                     if (m2.NeedGradient)
                     {
-                        m2.AddTanhGradient(res);
+                        m2.AddTanhGradient(resTWeight, res.TGradient);
                     }
 
                     if (m3.NeedGradient)
                     {
-                        m3.AddTanhGradient(res);
+                        m3.AddTanhGradient(resTWeight, res.TGradient);
+                    }
+
+                    if (m1.NeedGradient || m2.NeedGradient || m3.NeedGradient)
+                    {
+                        resTWeight.Dispose();
                     }
 
                     res.Dispose();
@@ -466,55 +491,63 @@ public IWeightTensor EltMulMulAdd(IWeightTensor w1, IWeightTensor w2, IWeightTen
             Ops.MulMulAdd(res.TWeight, m1.TWeight, m2.TWeight, m3.TWeight, m4.TWeight);
             if (m_needsBackprop)
             {
+                Tensor m1TWeight = null;
+                Tensor m2TWeight = null;
+                Tensor m3TWeight = null;
+                Tensor m4TWeight = null;
+
+
+                if (m2.NeedGradient)
+                {
+                    m1TWeight = m1.TWeight.CopyRef();
+                }
+
+                if (m1.NeedGradient)
+                {
+                    m2TWeight = m2.TWeight.CopyRef();
+                }
+
+                if (m4.NeedGradient)
+                {
+                    m3TWeight = m3.TWeight.CopyRef();
+                }
+
+                if (m3.NeedGradient)
+                {
+                    m4TWeight = m4.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     res.ReleaseWeight();
 
                     if (m1.NeedGradient)
                     {
-                        m1.AddMulGradient(m2.TWeight, res.TGradient);
+                        m1.AddMulGradient(m2TWeight, res.TGradient);
+                        m2TWeight.Dispose();
                     }
 
                     if (m2.NeedGradient)
                     {
-                        m2.AddMulGradient(m1.TWeight, res.TGradient);
+                        m2.AddMulGradient(m1TWeight, res.TGradient);
+                        m1TWeight.Dispose();
                     }
 
                     if (m3.NeedGradient)
                     {
-                        m3.AddMulGradient(m4.TWeight, res.TGradient);
+                        m3.AddMulGradient(m4TWeight, res.TGradient);
+                        m4TWeight.Dispose();
                     }
 
                     if (m4.NeedGradient)
                     {
-                        m4.AddMulGradient(m3.TWeight, res.TGradient);
+                        m4.AddMulGradient(m3TWeight, res.TGradient);
+                        m3TWeight.Dispose();
                     }
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                // These tensors' weights will be used during back-propogation, so we unbind them from the computing graph
-
-                if (m2.NeedGradient)
-                {
-                    m1.UnbindFromComputeGraph();
-                }
-
-                if (m1.NeedGradient)
-                {
-                    m2.UnbindFromComputeGraph();
-                }
-
-                if (m4.NeedGradient)
-                {
-                    m3.UnbindFromComputeGraph();
-                }
-
-                if (m3.NeedGradient)
-                {
-                    m4.UnbindFromComputeGraph();
-                }
             }
 
 
@@ -531,26 +564,38 @@ public IWeightTensor EltMul(IWeightTensor w1, IWeightTensor w2)
             Ops.Mul(res.TWeight, m1.TWeight, m2.TWeight);
             if (m_needsBackprop)
             {
+                Tensor m1TWeight = null;
+                Tensor m2TWeight = null;
+
+                if (m2.NeedGradient)
+                {
+                    m1TWeight = m1.TWeight.CopyRef();
+                }
+
+                if (m1.NeedGradient)
+                {
+                    m2TWeight = m2.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     res.ReleaseWeight();
 
                     if (m1.NeedGradient)
                     {
-                        m1.AddMulGradient(m2.TWeight, res.TGradient);
+                        m1.AddMulGradient(m2TWeight, res.TGradient);
+                        m2TWeight.Dispose();
                     }
 
                     if (m2.NeedGradient)
                     {
-                        m2.AddMulGradient(m1.TWeight, res.TGradient);
+                        m2.AddMulGradient(m1TWeight, res.TGradient);
+                        m1TWeight.Dispose();
                     }
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                m1.UnbindFromComputeGraph();
-                m2.UnbindFromComputeGraph();
             }
 
             return res;
@@ -867,11 +912,18 @@ public IWeightTensor Tanh(IWeightTensor w)
             Ops.Tanh(res.TWeight, m.TWeight);
             if (m_needsBackprop)
             {
+                Tensor resTWeight = null;
+                if (m.NeedGradient)
+                {
+                    resTWeight = res.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     if (m.NeedGradient)
                     {
-                        m.AddTanhGradient(res);
+                        m.AddTanhGradient(resTWeight, res.TGradient);
+                        resTWeight.Dispose();
                     }
 
                     res.Dispose();
@@ -1517,6 +1569,12 @@ public IWeightTensor Softmax(IWeightTensor w, bool runGradients = true, bool inP
             Ops.Softmax(res.TWeight, t.TWeight);
             if (m_needsBackprop)
             {
+                Tensor resTWeight = null;
+                if (runGradients && t.NeedGradient)
+                {
+                    resTWeight = res.TWeight.CopyRef();
+                }
+
                 void backward()
                 {
                     if (runGradients && t.NeedGradient)
@@ -1525,15 +1583,13 @@ void backward()
                         {
                             t.TGradient = res.TGradient.CopyRef();
                         }
-                        t.AddSoftmaxGradient(res, inPlace);
+                        t.AddSoftmaxGradient(resTWeight, res.TGradient, inPlace);
+                        resTWeight.Dispose();
                     }
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                res.UnbindFromComputeGraph();
-
             }
 
             return res;
@@ -2178,21 +2234,22 @@ public IWeightTensor LayerNorm(IWeightTensor src, IWeightTensor alpha, IWeightTe
             {
                 var srcTWeight = srcT.TWeight.CopyRef();
                 var resTWeight = res.TWeight.CopyRef();
+                var alphaTWeight = alphaT.TWeight.CopyRef();
+                var betaTWeight = betaT.TWeight.CopyRef();
                 void backward()
                 {
                     if (srcT.NeedGradient || alphaT.NeedGradient || betaT.NeedGradient)
                     {
-                        Ops.LayerNormGrad(srcT.TGradient, alphaT.TGradient, betaT.TGradient, res.TGradient, resTWeight, srcTWeight, alphaT.TWeight, betaT.TWeight, eps);
+                        Ops.LayerNormGrad(srcT.TGradient, alphaT.TGradient, betaT.TGradient, res.TGradient, resTWeight, srcTWeight, alphaTWeight, betaTWeight, eps);
                     }
                     srcTWeight.Dispose();
                     resTWeight.Dispose();
+                    alphaTWeight.Dispose();
+                    betaTWeight.Dispose();
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                alphaT.UnbindFromComputeGraph();
-                betaT.UnbindFromComputeGraph();
             }
 
             return res;
@@ -2213,21 +2270,22 @@ public IWeightTensor RMSNorm(IWeightTensor src, IWeightTensor alpha, IWeightTens
             {
                 var srcTWeight = srcT.TWeight.CopyRef();
                 var resTWeight = res.TWeight.CopyRef();
+                var alphaTWeight = alphaT.TWeight.CopyRef();
+                var betaTWeight = betaT.TWeight.CopyRef();
                 void backward()
                 {
                     if (srcT.NeedGradient)
                     {
-                        Ops.RMSNormGrad(srcT.TGradient, alphaT.TGradient, betaT.TGradient, res.TGradient, resTWeight, srcTWeight, alphaT.TWeight, betaT.TWeight, eps);
+                        Ops.RMSNormGrad(srcT.TGradient, alphaT.TGradient, betaT.TGradient, res.TGradient, resTWeight, srcTWeight, alphaTWeight, betaTWeight, eps);
                     }
                     srcTWeight.Dispose();
                     resTWeight.Dispose();
+                    alphaTWeight.Dispose();
+                    betaTWeight.Dispose();
 
                     res.Dispose();
                 }
                 m_backprop.Add(backward);
-
-                alphaT.UnbindFromComputeGraph();
-                betaT.UnbindFromComputeGraph();
             }
 
             return res;
diff --git a/Seq2SeqSharp/Tools/WeightTensor.cs b/Seq2SeqSharp/Tools/WeightTensor.cs
index 0b8393da..dd87a876 100644
--- a/Seq2SeqSharp/Tools/WeightTensor.cs
+++ b/Seq2SeqSharp/Tools/WeightTensor.cs
@@ -368,19 +368,19 @@ public void PrintWeights()
          //   }
         }
 
-        public void AddSoftmaxGradient(WeightTensor src, bool inPlace = false)
+        public void AddSoftmaxGradient(Tensor srcTWeight, Tensor srcTGradient, bool inPlace = false)
         {
             if (m_TGradient == null)
             {
                 m_allocator = TensorAllocator.Allocator(DeviceId);
-                m_TGradient = new Tensor(m_allocator, src.TGradient.ElementType, Sizes);
-                Ops.SoftmaxGrad(m_TGradient, src.TGradient, src.TWeight, false);
+                m_TGradient = new Tensor(m_allocator, srcTGradient.ElementType, Sizes);
+                Ops.SoftmaxGrad(m_TGradient, srcTGradient, srcTWeight, false);
 
                 m_GradientSetName = "AddSoftmaxGradient";
             }
             else
             {
-                Ops.SoftmaxGrad(m_TGradient, src.TGradient, src.TWeight, !inPlace);
+                Ops.SoftmaxGrad(m_TGradient, srcTGradient, srcTWeight, !inPlace);
             }
         }
 
@@ -452,36 +452,36 @@ public void AddMulGradient(Tensor w, Tensor g, bool inPlace = false)
             }
         }
 
-        public void AddSigmoidGradient(WeightTensor src)
+        public void AddSigmoidGradient(Tensor srcTWeight, Tensor srcTGradient)
         {
             if (m_TGradient == null)
             {
                 m_allocator = TensorAllocator.Allocator(DeviceId);
-                m_TGradient = new Tensor(m_allocator, src.TGradient.ElementType, Sizes);
-                Ops.SigmoidD(m_TGradient, src.TWeight, src.TGradient);
+                m_TGradient = new Tensor(m_allocator, srcTGradient.ElementType, Sizes);
+                Ops.SigmoidD(m_TGradient, srcTWeight, srcTGradient);
 
                 m_GradientSetName = "AddSigmoidGradient";
             }
             else
             {
-                Ops.AddSigmoidD(m_TGradient, m_TGradient, src.TWeight, src.TGradient);
+                Ops.AddSigmoidD(m_TGradient, m_TGradient, srcTWeight, srcTGradient);
             }
         }
 
-        public void AddTanhGradient(WeightTensor src)
+        public void AddTanhGradient(Tensor srcTWeight, Tensor srcTGradient)
         {
             if (m_TGradient == null)
             {
                 m_allocator = TensorAllocator.Allocator(DeviceId);
-                m_TGradient = new Tensor(m_allocator, src.TGradient.ElementType, Sizes);
+                m_TGradient = new Tensor(m_allocator, srcTGradient.ElementType, Sizes);
 
-                Ops.TanhD(m_TGradient, src.TWeight, src.TGradient);
+                Ops.TanhD(m_TGradient, srcTWeight, srcTGradient);
 
                 m_GradientSetName = "AddTanhGradient";
             }
             else
             {
-                Ops.AddTanhD(m_TGradient, m_TGradient, src.TWeight, src.TGradient);
+                Ops.AddTanhD(m_TGradient, m_TGradient, srcTWeight, srcTGradient);
             }
         }
 
diff --git a/Tools/Seq2SeqConsole/Program.cs b/Tools/Seq2SeqConsole/Program.cs
index eb378c85..6d2e3558 100644
--- a/Tools/Seq2SeqConsole/Program.cs
+++ b/Tools/Seq2SeqConsole/Program.cs
@@ -64,7 +64,7 @@ private static void Main(string[] args)
                 {
                     // Load train corpus
                     var trainCorpus = new Seq2SeqCorpus(corpusFilePath: opts.TrainCorpusPath, srcLangName: opts.SrcLang, tgtLangName: opts.TgtLang, maxTokenSizePerBatch: opts.MaxTokenSizePerBatch,
-                        maxSrcSentLength: opts.MaxSrcSentLength, maxTgtSentLength: opts.MaxTgtSentLength, shuffleEnums: opts.ShuffleType, tooLongSequence: opts.TooLongSequence, indexedFilePath: opts.IndexedCorpusPath);
+                        maxSrcSentLength: opts.MaxSrcSentLength, maxTgtSentLength: opts.MaxTgtSentLength, shuffleEnums: opts.ShuffleType, tooLongSequence: opts.TooLongSequence, indexedFilePath: opts.IndexedCorpusPath, startBatchId: opts.StartBatchId);
 
                     // Load valid corpus
                     var validCorpusList = new List<Seq2SeqCorpus>();