Skip to content

Commit

Permalink
Add StartBatchId option for Seq2SeqConsole
Browse files Browse the repository at this point in the history
Optimize weights life cycle
  • Loading branch information
zhongkaifu committed Nov 16, 2023
1 parent da969bb commit e1c034d
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 72 deletions.
16 changes: 14 additions & 2 deletions Seq2SeqSharp/Corpus/ParallelCorpus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public interface ICorpus<out T> : IEnumerable<T>

private string m_sortedIndexedDataSetFilePath = "";
private int m_batchNumInTotal = 0;
private int m_startBatchId = 0;

public (List<Dictionary<string, long>>, List<Dictionary<string, long>>) CountTokenFreqs()
{
Expand Down Expand Up @@ -444,6 +445,17 @@ public IEnumerator<T> GetEnumerator()
string[] tgtLines = br.ReadString().Split("\n");
batchIdx++;

if (batchIdx < m_startBatchId)
{
continue;
}

if (batchIdx % 10000 == 0)
{
Logger.WriteLine($"Processing batch '{batchIdx}'");
}


T batch;
int currentTokenCountsInBatch = 0;
for (int i = 0; i < sizeInBatch; i++)
Expand Down Expand Up @@ -498,7 +510,7 @@ public ParallelCorpus()

}

public ParallelCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null)
public ParallelCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null, int startBatchId = 0)
{
Logger.WriteLine($"Loading parallel corpus from '{corpusFilePath}' for source side '{srcLangName}' and target side '{tgtLangName}' MaxSrcSentLength = '{maxSrcSentLength}', MaxTgtSentLength = '{maxTgtSentLength}', aggregateSrcLengthForShuffle = '{shuffleEnums}', TooLongSequence = '{tooLongSequence}'");
m_maxTokenSizePerBatch = maxTokenSizePerBatch;
Expand Down Expand Up @@ -546,7 +558,7 @@ public ParallelCorpus(string corpusFilePath, string srcLangName, string tgtLangN
m_srcFileList.Add(pair.Value);
m_tgtFileList.Add(tgtKey2FileName[pair.Key]);
}

m_startBatchId = startBatchId;
}
}
}
4 changes: 2 additions & 2 deletions Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ namespace Seq2SeqSharp.Corpus
public class Seq2SeqCorpus : ParallelCorpus<Seq2SeqCorpusBatch>
{

public Seq2SeqCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null)
:base (corpusFilePath, srcLangName, tgtLangName, maxTokenSizePerBatch, maxSrcSentLength, maxTgtSentLength, shuffleEnums: shuffleEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath)
/// <summary>
/// Creates a sequence-to-sequence corpus by forwarding every argument, unchanged,
/// to the <see cref="ParallelCorpus{T}"/> base constructor. This constructor adds
/// no logic of its own.
/// </summary>
/// <param name="startBatchId">
/// Forwarded to the base constructor, which stores it and uses it during enumeration
/// to skip batches whose running index is below this value. The default of 0 matches
/// the previous signature, so existing callers are unaffected.
/// NOTE(review): the skip check sits inside the enumerator, so it presumably applies
/// on every enumeration (every epoch), not only the first - confirm this is intended.
/// </param>
public Seq2SeqCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null, int startBatchId = 0)
:base (corpusFilePath, srcLangName, tgtLangName, maxTokenSizePerBatch, maxSrcSentLength, maxTgtSentLength, shuffleEnums: shuffleEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath, startBatchId: startBatchId)
{

}
Expand Down
3 changes: 2 additions & 1 deletion Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ public void Train(int maxTrainingEpoch, ICorpus<IPairBatch> trainCorpus, ICorpus
TrainOneEpoch(i, trainCorpus, validCorpusList, learningRate, optimizer, taskId2metrics, decodingOptions, RunForwardOnSingleDevice);

// send progress reporting in the form of a percentage value (0-100%)
Logger.WriteLine(Logger.Level.info, "", (int)(100 * (i + 1) / maxTrainingEpoch));
var finishedEpochPercent = (int)(100 * (i + 1) / maxTrainingEpoch);
Logger.WriteLine(Logger.Level.info, $"Finished Epoch Percent: {finishedEpochPercent}%", finishedEpochPercent);
}

SaveModel(createBackupPrevious: false, suffix: $".{m_weightsUpdateCount}");
Expand Down
Loading

0 comments on commit e1c034d

Please sign in to comment.