From 5faf982a5b905ce4a47fcfa151de105c39e9090e Mon Sep 17 00:00:00 2001 From: Zhongkai Fu Date: Mon, 4 Nov 2024 20:17:37 -0800 Subject: [PATCH] 1. Allow to rewrite settings in command line 2. Use KV cache for test by default 3. Support password protected indexed data set in zip format --- AdvUtils/Arg/ArgParser.cs | 79 +++++++---- Seq2SeqSharp/Applications/Options.cs | 7 +- Seq2SeqSharp/Corpus/MonoCorpus.cs | 132 ++++++++++-------- Seq2SeqSharp/Corpus/ParallelCorpus.cs | 127 ++++++++++------- Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs | 5 +- Seq2SeqSharp/Corpus/SeqCorpus.cs | 4 +- Seq2SeqSharp/Seq2SeqSharp.csproj | 1 + Seq2SeqSharp/Utils/Misc.cs | 2 +- Seq2SeqSharp/Utils/ZipDecompressor.cs | 115 +++++++++++++++ Tools/GPTConsole/Program.cs | 8 +- Tools/Seq2SeqConsole/Program.cs | 9 +- .../Properties/launchSettings.json | 2 +- 12 files changed, 342 insertions(+), 149 deletions(-) create mode 100644 Seq2SeqSharp/Utils/ZipDecompressor.cs diff --git a/AdvUtils/Arg/ArgParser.cs b/AdvUtils/Arg/ArgParser.cs index bbcca764..6e74dd29 100644 --- a/AdvUtils/Arg/ArgParser.cs +++ b/AdvUtils/Arg/ArgParser.cs @@ -10,24 +10,48 @@ public class ArgParser object m_o; List m_arrayArgs; - public ArgParser(string[] args, object o) - { - m_o = o; + public static void UpdateFieldValue(object obj, string fieldName, string newValue) + { + // Get the Type of the object + Type objType = obj.GetType(); + + // Get the FieldInfo for the specified field name + FieldInfo fieldInfo = objType.GetField(fieldName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance); + + // If the field is found, set the new value + if (fieldInfo != null) + { + fieldInfo.SetValue(obj, newValue); + } + else + { + Console.WriteLine($"Field '{fieldName}' not found in the class."); + } + } + + public ArgParser(string[] args, object o) + { + m_o = o; m_arrayArgs = new List(); - Type typeArgAttr = typeof(Arg); - Type t = o.GetType(); - foreach (FieldInfo fi in t.GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) - { - foreach (Arg arg in fi.GetCustomAttributes(typeArgAttr, true)) - { - m_arrayArgs.Add(new ArgField(o, fi, arg)); - } - } - - try - { - for (int i = 0; i < args.Length; i++) - { + Type typeArgAttr = typeof(Arg); + Type t = o.GetType(); + foreach (FieldInfo fi in t.GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) + { + foreach (Arg arg in fi.GetCustomAttributes(typeArgAttr, true)) + { + m_arrayArgs.Add(new ArgField(o, fi, arg)); + } + } + + RewriteSettings(args, o); + } + + public void RewriteSettings(string[] args, object o) + { + try + { + for (int i = 0; i < args.Length; i++) + { if (args[i].StartsWith("-")) { string strArgName = args[i].Substring(1); @@ -41,19 +65,22 @@ public ArgParser(string[] args, object o) intarg.Set(strArgValue); + Console.WriteLine($"Rewrite field '{strArgName}' value."); + UpdateFieldValue(o, strArgName, strArgValue); + i++; } - } + } - foreach (ArgField a in m_arrayArgs) - a.Validate(); - } - catch (Exception err) - { + foreach (ArgField a in m_arrayArgs) + a.Validate(); + } + catch (Exception err) + { Console.Error.WriteLine(err.Message); - Usage(); - } - } + Usage(); + } + } ArgField? GetArgByName(string name) { diff --git a/Seq2SeqSharp/Applications/Options.cs b/Seq2SeqSharp/Applications/Options.cs index 31ee636e..70916891 100644 --- a/Seq2SeqSharp/Applications/Options.cs +++ b/Seq2SeqSharp/Applications/Options.cs @@ -238,6 +238,9 @@ public class Options [Range(0, 9999999)] public int StartBatchId = 0; + [Arg("Zip password for dataset. The defulat value is empty", nameof(DataPassword))] + public string DataPassword = ""; + [Arg("The max degress of parallelism in task. Default is 1", nameof(TaskParallelism))] [Range(1, 999)] public int TaskParallelism = 1; @@ -307,8 +310,8 @@ public class Options [Range(-1, 9999999)] public int RandomSeed = -1; - [Arg("Use KV Cache in test mode. The default value is false", nameof(UseKVCache))] - public bool UseKVCache = false; + [Arg("Use KV Cache in test mode. The default value is true", nameof(UseKVCache))] + public bool UseKVCache = true; [Arg("Initial loss Scaling when AMP is enabled. Default is 1 which is disabled.", nameof(InitLossScaling))] [Range(1, 65000)] diff --git a/Seq2SeqSharp/Corpus/MonoCorpus.cs b/Seq2SeqSharp/Corpus/MonoCorpus.cs index 681fdc1d..7750f7e0 100644 --- a/Seq2SeqSharp/Corpus/MonoCorpus.cs +++ b/Seq2SeqSharp/Corpus/MonoCorpus.cs @@ -49,6 +49,8 @@ public class IndexData private int m_batchNumInTotal = 0; private int m_startBatchId = 0; + private string m_dataPassword = String.Empty; + public List> CountTokenFreqs() { List> td = new List>(); @@ -201,15 +203,6 @@ private IndexData[] BuildIndex() indexDatas[i].len2offsets= len2offsets; indexDatas[i].filePath = binaryDataSetFilePath; }); - - - - - - - - - Logger.WriteLine(Logger.Level.debug, $"Shuffled '{corpusSize}' sentence pairs."); if (tooLongTgtSntCnt > 0) @@ -369,74 +362,98 @@ public IEnumerator GetEnumerator() int batchIdx = 0; int currentBatchPercent = 0; + MemoryMappedFile mmf = null; + MemoryMappedViewStream mms = null; + ZipDecompressor decompressor = null; + if (m_indexedDataSetFilePath.ToLower().EndsWith(".zip")) + { + Logger.WriteLine($"The data set is a zip archive."); + decompressor = new ZipDecompressor(m_indexedDataSetFilePath, m_dataPassword); + mms = decompressor.GetMemoryMappedViewStream(); + } + else + { + mmf = MemoryMappedFile.CreateFromFile(m_indexedDataSetFilePath); + mms = mmf.CreateViewStream(); + } - using (MemoryMappedFile mmf = MemoryMappedFile.CreateFromFile(m_indexedDataSetFilePath)) - using (MemoryMappedViewStream mms = mmf.CreateViewStream()) + using (BinaryReader br = new BinaryReader(mms)) { - using (BinaryReader br = new BinaryReader(mms)) + while (true) { - while (true) + int sizeInBatch = br.ReadInt32(); + if (sizeInBatch < 0) { - int sizeInBatch = br.ReadInt32(); - if (sizeInBatch < 0) - { - break; - } - - List outputs = new List(); - string[] tgtLines = br.ReadString().Split("\n"); - batchIdx++; - - if (batchIdx < m_startBatchId) - { - continue; - } + break; + } - if (batchIdx % 10000 == 0) - { - Logger.WriteLine($"Processing batch '{batchIdx}'"); - } + List outputs = new List(); + string[] tgtLines = br.ReadString().Split("\n"); + batchIdx++; - T batch; - int currentTokenCountsInBatch = 0; - for (int i = 0; i < sizeInBatch; i++) - { - var tgtLine = tgtLines[i]; + if (batchIdx < m_startBatchId) + { + continue; + } - if (m_batchNumInTotal > 0) - { - if ((100 * batchIdx / m_batchNumInTotal) > currentBatchPercent) - { - Logger.WriteLine($"Processing batch '{batchIdx}/{m_batchNumInTotal}'."); // The '{i}th' record in this batch is: Target = '{tgtLine}'"); - currentBatchPercent++; - } - } + if (batchIdx % 10000 == 0) + { + Logger.WriteLine($"Processing batch '{batchIdx}'"); + } - IPair sntPair = new SntPair(tgtLine, tgtLine); - currentTokenCountsInBatch += sntPair.GetTgtTokenCount(); - outputs.Add(sntPair); + T batch; + int currentTokenCountsInBatch = 0; + for (int i = 0; i < sizeInBatch; i++) + { + var tgtLine = tgtLines[i]; - if (currentTokenCountsInBatch >= m_maxTokenSizePerBatch) + if (m_batchNumInTotal > 0) + { + if ((100 * batchIdx / m_batchNumInTotal) > currentBatchPercent) { - batch = new T(); - batch.CreateBatch(outputs); - yield return batch; - - outputs = new List(); - currentTokenCountsInBatch = 0; + Logger.WriteLine($"Processing batch '{batchIdx}/{m_batchNumInTotal}'."); // The '{i}th' record in this batch is: Target = '{tgtLine}'"); + currentBatchPercent++; } } - if (outputs.Count > 0) + IPair sntPair = new SntPair(tgtLine, tgtLine); + currentTokenCountsInBatch += sntPair.GetTgtTokenCount(); + outputs.Add(sntPair); + + if (currentTokenCountsInBatch >= m_maxTokenSizePerBatch) { batch = new T(); batch.CreateBatch(outputs); yield return batch; + + outputs = new List(); + currentTokenCountsInBatch = 0; } } + + if (outputs.Count > 0) + { + batch = new T(); + batch.CreateBatch(outputs); + yield return batch; + } } } + if (mms != null) + { + mms.Dispose(); + } + if (mmf != null) + { + mmf.Dispose(); + } + + if (decompressor != null) + { + decompressor.Dispose(); + } + File.Delete(m_indexedDataSetFilePath); } @@ -450,9 +467,9 @@ public MonoCorpus() } - public MonoCorpus(string corpusFilePath, string tgtLangName, int maxTokenSizePerBatch, int maxTgtSentLength = 32, PaddingEnums paddingEnums = PaddingEnums.AllowPadding, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = "", int startBatchId = 0) + public MonoCorpus(string corpusFilePath, string tgtLangName, int maxTokenSizePerBatch, int maxTgtSentLength = 32, PaddingEnums paddingEnums = PaddingEnums.AllowPadding, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = "", int startBatchId = 0, string dataPassword = "") { - Logger.WriteLine($"Loading mono corpus from '{corpusFilePath}' Files search pattern '*.{tgtLangName}.snt' MaxTgtSentLength = '{maxTgtSentLength}', Token Padding Type = '{paddingEnums}', TooLongSequence = '{tooLongSequence}'"); + Logger.WriteLine($"Loading mono corpus from '{corpusFilePath}' Files search pattern '*.{tgtLangName}.snt' MaxTgtSentLength = '{maxTgtSentLength}', Token Padding Type = '{paddingEnums}', TooLongSequence = '{tooLongSequence}', Encrypted data set = '{!String.IsNullOrEmpty(dataPassword)}'"); m_maxTokenSizePerBatch = maxTokenSizePerBatch; m_maxTgtTokenSize = maxTgtSentLength; m_tooLongSequence = tooLongSequence; @@ -460,6 +477,7 @@ public MonoCorpus(string corpusFilePath, string tgtLangName, int maxTokenSizePer CorpusName = corpusFilePath; m_indexedDataSetFilePath = indexedFilePath; m_startBatchId = startBatchId; + m_dataPassword = dataPassword; m_tgtFileList = new List(); string[] files = Directory.GetFiles(corpusFilePath, $"*.{tgtLangName}.snt", SearchOption.TopDirectoryOnly); diff --git a/Seq2SeqSharp/Corpus/ParallelCorpus.cs b/Seq2SeqSharp/Corpus/ParallelCorpus.cs index 554a572d..b7852213 100644 --- a/Seq2SeqSharp/Corpus/ParallelCorpus.cs +++ b/Seq2SeqSharp/Corpus/ParallelCorpus.cs @@ -52,6 +52,7 @@ public interface ICorpus : IEnumerable private string m_sortedIndexedDataSetFilePath = ""; private int m_batchNumInTotal = 0; private int m_startBatchId = 0; + private string m_dataPassword = String.Empty; public (List>, List>) CountTokenFreqs() { @@ -442,78 +443,101 @@ public IEnumerator GetEnumerator() int batchIdx = 0; int currentBatchPercent = 0; + MemoryMappedFile mmf = null; + MemoryMappedViewStream mms = null; + ZipDecompressor decompressor = null; + if (m_sortedIndexedDataSetFilePath.ToLower().EndsWith(".zip")) + { + Logger.WriteLine($"The data set is a zip archive."); + decompressor = new ZipDecompressor(m_sortedIndexedDataSetFilePath, m_dataPassword); + mms = decompressor.GetMemoryMappedViewStream(); + } + else + { + mmf = MemoryMappedFile.CreateFromFile(m_sortedIndexedDataSetFilePath); + mms = mmf.CreateViewStream(); + } - using (MemoryMappedFile mmf = MemoryMappedFile.CreateFromFile(m_sortedIndexedDataSetFilePath)) - using (MemoryMappedViewStream mms = mmf.CreateViewStream()) + using (BinaryReader br = new BinaryReader(mms)) { - using (BinaryReader br = new BinaryReader(mms)) + while (true) { - while (true) + int sizeInBatch = br.ReadInt32(); + if (sizeInBatch < 0) { - int sizeInBatch = br.ReadInt32(); - if (sizeInBatch < 0) - { - break; - } + break; + } - List outputs = new List(); + List outputs = new List(); - string[] srcLines = br.ReadString().Split("\n"); - string[] tgtLines = br.ReadString().Split("\n"); - batchIdx++; + string[] srcLines = br.ReadString().Split("\n"); + string[] tgtLines = br.ReadString().Split("\n"); + batchIdx++; - if (batchIdx < m_startBatchId) - { - continue; - } + if (batchIdx < m_startBatchId) + { + continue; + } - if (batchIdx % 10000 == 0) - { - Logger.WriteLine(Logger.Level.debug, $"Processing batch '{batchIdx}'"); - } + if (batchIdx % 10000 == 0) + { + Logger.WriteLine(Logger.Level.debug, $"Processing batch '{batchIdx}'"); + } + T batch; + int currentTokenCountsInBatch = 0; + for (int i = 0; i < sizeInBatch; i++) + { + var srcLine = srcLines[i]; + var tgtLine = tgtLines[i]; - T batch; - int currentTokenCountsInBatch = 0; - for (int i = 0; i < sizeInBatch; i++) + if (m_batchNumInTotal > 0) { - var srcLine = srcLines[i]; - var tgtLine = tgtLines[i]; - - if (m_batchNumInTotal > 0) + if ((100 * batchIdx / m_batchNumInTotal) > currentBatchPercent) { - if ((100 * batchIdx / m_batchNumInTotal) > currentBatchPercent) - { - Logger.WriteLine($"Processing batch '{batchIdx}/{m_batchNumInTotal}'."); // The '{i}th' record in this batch is: Source = '{srcLine}' Target = '{tgtLine}'"); - currentBatchPercent++; - } - } - - IPair sntPair = new SntPair(srcLine, tgtLine); - currentTokenCountsInBatch += (sntPair.GetTgtTokenCount() + sntPair.GetSrcTokenCount()); - outputs.Add(sntPair); - - if (currentTokenCountsInBatch >= m_maxTokenSizePerBatch) - { - batch = new T(); - batch.CreateBatch(outputs); - yield return batch; - - outputs = new List(); - currentTokenCountsInBatch = 0; + Logger.WriteLine($"Processing batch '{batchIdx}/{m_batchNumInTotal}'."); // The '{i}th' record in this batch is: Source = '{srcLine}' Target = '{tgtLine}'"); + currentBatchPercent++; } } - if (outputs.Count > 0) + IPair sntPair = new SntPair(srcLine, tgtLine); + currentTokenCountsInBatch += (sntPair.GetTgtTokenCount() + sntPair.GetSrcTokenCount()); + outputs.Add(sntPair); + + if (currentTokenCountsInBatch >= m_maxTokenSizePerBatch) { batch = new T(); batch.CreateBatch(outputs); yield return batch; + + outputs = new List(); + currentTokenCountsInBatch = 0; } } + + if (outputs.Count > 0) + { + batch = new T(); + batch.CreateBatch(outputs); + yield return batch; + } } } + if (mms != null) + { + mms.Dispose(); + } + if (mmf != null) + { + mmf.Dispose(); + } + + if (decompressor != null) + { + decompressor.Dispose(); + } + File.Delete(m_sortedIndexedDataSetFilePath); } @@ -527,18 +551,17 @@ public ParallelCorpus() } - public ParallelCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, PaddingEnums paddingEnums = PaddingEnums.AllowPadding, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null, int startBatchId = 0) + public ParallelCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, PaddingEnums paddingEnums = PaddingEnums.AllowPadding, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null, int startBatchId = 0, string dataPassword = "") { - Logger.WriteLine($"Loading parallel corpus from '{corpusFilePath}' for source side '{srcLangName}' and target side '{tgtLangName}' MaxSrcSentLength = '{maxSrcSentLength}', MaxTgtSentLength = '{maxTgtSentLength}', Token Paading Type = '{paddingEnums}', TooLongSequence = '{tooLongSequence}'"); + Logger.WriteLine($"Loading parallel corpus from '{corpusFilePath}' for source side '{srcLangName}' and target side '{tgtLangName}' MaxSrcSentLength = '{maxSrcSentLength}', MaxTgtSentLength = '{maxTgtSentLength}', Token Paading Type = '{paddingEnums}', TooLongSequence = '{tooLongSequence}', Encrypted data set = '{!String.IsNullOrEmpty(dataPassword)}'"); m_maxTokenSizePerBatch = maxTokenSizePerBatch; m_maxSrcTokenSize = maxSrcSentLength; m_maxTgtTokenSize = maxTgtSentLength; - m_tooLongSequence = tooLongSequence; - m_paddingEnums = paddingEnums; CorpusName = corpusFilePath; m_sortedIndexedDataSetFilePath = indexedFilePath; + m_dataPassword = dataPassword; m_srcFileList = new List(); m_tgtFileList = new List(); diff --git a/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs b/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs index 14deaff4..f46f6272 100644 --- a/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs +++ b/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs @@ -12,14 +12,13 @@ using Seq2SeqSharp.Utils; using System; - namespace Seq2SeqSharp.Corpus { public class Seq2SeqCorpus : ParallelCorpus { - public Seq2SeqCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, PaddingEnums paddingEnums = PaddingEnums.AllowPadding, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null, int startBatchId = 0) - :base (corpusFilePath, srcLangName, tgtLangName, maxTokenSizePerBatch, maxSrcSentLength, maxTgtSentLength, paddingEnums: paddingEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath, startBatchId: startBatchId) + public Seq2SeqCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, PaddingEnums paddingEnums = PaddingEnums.AllowPadding, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null, int startBatchId = 0, string dataPassword = "") + :base (corpusFilePath, srcLangName, tgtLangName, maxTokenSizePerBatch, maxSrcSentLength, maxTgtSentLength, paddingEnums: paddingEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath, startBatchId: startBatchId, dataPassword: dataPassword) { } diff --git a/Seq2SeqSharp/Corpus/SeqCorpus.cs b/Seq2SeqSharp/Corpus/SeqCorpus.cs index e1eeb0dd..23db6cbe 100644 --- a/Seq2SeqSharp/Corpus/SeqCorpus.cs +++ b/Seq2SeqSharp/Corpus/SeqCorpus.cs @@ -18,8 +18,8 @@ namespace Seq2SeqSharp.Corpus public class SeqCorpus : MonoCorpus { - public SeqCorpus(string corpusFilePath, string tgtLangName, int maxTokenSizePerBatch, int maxTgtSentLength = 32, PaddingEnums paddingEnums = PaddingEnums.AllowPadding, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = "", int startBatchId = 0) - : base(corpusFilePath, tgtLangName, maxTokenSizePerBatch, maxTgtSentLength, paddingEnums: paddingEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath, startBatchId: startBatchId) + public SeqCorpus(string corpusFilePath, string tgtLangName, int maxTokenSizePerBatch, int maxTgtSentLength = 32, PaddingEnums paddingEnums = PaddingEnums.AllowPadding, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = "", int startBatchId = 0, string dataPassword = "") + : base(corpusFilePath, tgtLangName, maxTokenSizePerBatch, maxTgtSentLength, paddingEnums: paddingEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath, startBatchId: startBatchId, dataPassword: dataPassword) { } diff --git a/Seq2SeqSharp/Seq2SeqSharp.csproj b/Seq2SeqSharp/Seq2SeqSharp.csproj index 08bfb043..608beca8 100644 --- a/Seq2SeqSharp/Seq2SeqSharp.csproj +++ b/Seq2SeqSharp/Seq2SeqSharp.csproj @@ -47,6 +47,7 @@ + diff --git a/Seq2SeqSharp/Utils/Misc.cs b/Seq2SeqSharp/Utils/Misc.cs index a7898043..04e85279 100644 --- a/Seq2SeqSharp/Utils/Misc.cs +++ b/Seq2SeqSharp/Utils/Misc.cs @@ -79,7 +79,7 @@ public static int GetDeviceCount(bool GPU = true) } public static class Misc - { + { public static void AppendNewBatch(List> inputBatchs, string line, int maxTokenLength) { List tokens = line.Trim().Split(' ').ToList(); diff --git a/Seq2SeqSharp/Utils/ZipDecompressor.cs b/Seq2SeqSharp/Utils/ZipDecompressor.cs new file mode 100644 index 00000000..b6bc55e9 --- /dev/null +++ b/Seq2SeqSharp/Utils/ZipDecompressor.cs @@ -0,0 +1,115 @@ +using ICSharpCode.SharpZipLib.Zip; +using System; +using System.Collections.Generic; +using System.IO.MemoryMappedFiles; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Seq2SeqSharp.Utils +{ + public class ZipDecompressor : IDisposable + { + private readonly string _tempFilePath; + private MemoryMappedFile _memoryMappedFile; + + public ZipDecompressor(string zipFilePath, string password) + { + _tempFilePath = Path.GetTempFileName(); + + // Register event handlers for process exit and unhandled exception + AppDomain.CurrentDomain.ProcessExit += OnProcessExit; + AppDomain.CurrentDomain.UnhandledException += OnUnhandledException; + + // Decompress the ZIP file to the temp file + DecompressZipToTempFile(zipFilePath, password); + + // Create a memory-mapped file from the decompressed temp file + _memoryMappedFile = MemoryMappedFile.CreateFromFile(_tempFilePath, FileMode.Open, null); + } + + public MemoryMappedViewStream GetMemoryMappedViewStream() + { + return _memoryMappedFile?.CreateViewStream(); + } + + private void DecompressZipToTempFile(string zipFilePath, string password) + { + using (FileStream fs = File.OpenRead(zipFilePath)) + using (ZipInputStream zipStream = new ZipInputStream(fs)) + { + zipStream.Password = password; + + using (FileStream tempFileStream = new FileStream(_tempFilePath, FileMode.Create, FileAccess.Write, FileShare.Delete, 4096)) + { + byte[] buffer = new byte[8192]; + ZipEntry entry; + while ((entry = zipStream.GetNextEntry()) != null) + { + int size; + while ((size = zipStream.Read(buffer, 0, buffer.Length)) > 0) + { + tempFileStream.Write(buffer, 0, size); + } + } + } + } + } + + private void OnProcessExit(object sender, EventArgs e) + { + // Cleanup when the process exits + Dispose(); + } + + private void OnUnhandledException(object sender, UnhandledExceptionEventArgs e) + { + // Cleanup when an unhandled exception occurs + Dispose(); + } + + public void Dispose() + { + // Cleanup memory-mapped file and delete the temp file + _memoryMappedFile?.Dispose(); + _memoryMappedFile = null; + + if (File.Exists(_tempFilePath)) + { + try + { + File.Delete(_tempFilePath); + Console.WriteLine("Temporary file deleted successfully."); + } + catch (Exception ex) + { + Console.WriteLine("Failed to delete temporary file: " + ex.Message); + } + } + + // Unregister event handlers to avoid memory leaks + AppDomain.CurrentDomain.ProcessExit -= OnProcessExit; + AppDomain.CurrentDomain.UnhandledException -= OnUnhandledException; + } + + //static void Main(string[] args) + //{ + // string zipFilePath = "path/to/your/password-protected.zip"; + // string password = "your_password"; + + // using (var decompressor = new ZipDecompressor(zipFilePath, password)) + // { + // using (MemoryMappedViewStream memoryMappedViewStream = decompressor.GetMemoryMappedViewStream()) + // { + // // Use the memory-mapped view stream + // Console.WriteLine("Decompression completed, stream length: " + memoryMappedViewStream.Length); + + // // Perform operations with the stream here... + // } + // } + + // // The temp file will be automatically deleted when the ZipDecompressor is disposed. + //} + } +} diff --git a/Tools/GPTConsole/Program.cs b/Tools/GPTConsole/Program.cs index 7cb74225..61b7b4e2 100644 --- a/Tools/GPTConsole/Program.cs +++ b/Tools/GPTConsole/Program.cs @@ -48,11 +48,15 @@ private static void Main(string[] args) { Console.WriteLine($"Loading config file from '{opts.ConfigFilePath}'"); opts = JsonConvert.DeserializeObject(File.ReadAllText(opts.ConfigFilePath)); + argParser.RewriteSettings(args, opts); } Logger.Initialize(opts.LogDestination, opts.LogLevel, $"{nameof(GPTConsole)}_{opts.Task}_{Utils.GetTimeStamp(DateTime.Now)}.log"); - ShowOptions(args, opts); + if ((opts.LogLevel & Logger.Level.debug) == Logger.Level.debug) + { + ShowOptions(args, opts); + } DecodingOptions decodingOptions = opts.CreateDecodingOptions(); GPT ss = null; @@ -60,7 +64,7 @@ private static void Main(string[] args) { // Load train corpus var trainCorpus = new SeqCorpus(corpusFilePath: opts.TrainCorpusPath, tgtLangName: opts.TgtLang, maxTokenSizePerBatch: opts.MaxTokenSizePerBatch, - maxTgtSentLength: opts.MaxTgtSentLength, paddingEnums: opts.PaddingType, tooLongSequence: opts.TooLongSequence, indexedFilePath: opts.IndexedCorpusPath, startBatchId: opts.StartBatchId); + maxTgtSentLength: opts.MaxTgtSentLength, paddingEnums: opts.PaddingType, tooLongSequence: opts.TooLongSequence, indexedFilePath: opts.IndexedCorpusPath, startBatchId: opts.StartBatchId, dataPassword: opts.DataPassword); // Create learning rate ILearningRate learningRate = null; diff --git a/Tools/Seq2SeqConsole/Program.cs b/Tools/Seq2SeqConsole/Program.cs index 6694a363..809e0e26 100644 --- a/Tools/Seq2SeqConsole/Program.cs +++ b/Tools/Seq2SeqConsole/Program.cs @@ -54,6 +54,7 @@ private static void Main(string[] args) try { opts = JsonConvert.DeserializeObject(File.ReadAllText(opts.ConfigFilePath)); + argParser.RewriteSettings(args, opts); } catch(Exception ex) { @@ -72,7 +73,8 @@ private static void Main(string[] args) { // Load train corpus var trainCorpus = new Seq2SeqCorpus(corpusFilePath: opts.TrainCorpusPath, srcLangName: opts.SrcLang, tgtLangName: opts.TgtLang, maxTokenSizePerBatch: opts.MaxTokenSizePerBatch, - maxSrcSentLength: opts.MaxSrcSentLength, maxTgtSentLength: opts.MaxTgtSentLength, paddingEnums: opts.PaddingType, tooLongSequence: opts.TooLongSequence, indexedFilePath: opts.IndexedCorpusPath, startBatchId: opts.StartBatchId); + maxSrcSentLength: opts.MaxSrcSentLength, maxTgtSentLength: opts.MaxTgtSentLength, paddingEnums: opts.PaddingType, tooLongSequence: opts.TooLongSequence, indexedFilePath: opts.IndexedCorpusPath, + startBatchId: opts.StartBatchId, dataPassword: opts.DataPassword); // Load valid corpus var validCorpusList = new List(); @@ -81,7 +83,8 @@ private static void Main(string[] args) string[] validCorpusPathList = opts.ValidCorpusPaths.Split(';'); foreach (var validCorpusPath in validCorpusPathList) { - validCorpusList.Add(new Seq2SeqCorpus(validCorpusPath, opts.SrcLang, opts.TgtLang, opts.ValMaxTokenSizePerBatch, opts.MaxValidSrcSentLength, opts.MaxValidTgtSentLength, paddingEnums: opts.PaddingType, tooLongSequence: opts.TooLongSequence)); + validCorpusList.Add(new Seq2SeqCorpus(validCorpusPath, opts.SrcLang, opts.TgtLang, opts.ValMaxTokenSizePerBatch, opts.MaxValidSrcSentLength, opts.MaxValidTgtSentLength, + paddingEnums: opts.PaddingType, tooLongSequence: opts.TooLongSequence, dataPassword: opts.DataPassword)); } } @@ -160,7 +163,7 @@ private static void Main(string[] args) List metrics = CreateMetrics(); // Load valid corpus - Seq2SeqCorpus validCorpus = new Seq2SeqCorpus(opts.ValidCorpusPaths, opts.SrcLang, opts.TgtLang, opts.ValMaxTokenSizePerBatch, opts.MaxValidSrcSentLength, opts.MaxValidTgtSentLength, paddingEnums: opts.PaddingType, tooLongSequence: opts.TooLongSequence); + Seq2SeqCorpus validCorpus = new Seq2SeqCorpus(opts.ValidCorpusPaths, opts.SrcLang, opts.TgtLang, opts.ValMaxTokenSizePerBatch, opts.MaxValidSrcSentLength, opts.MaxValidTgtSentLength, paddingEnums: opts.PaddingType, tooLongSequence: opts.TooLongSequence, dataPassword: opts.DataPassword); ss = new Seq2Seq(opts); ss.EvaluationWatcher += Ss_EvaluationWatcher; diff --git a/Tools/Seq2SeqConsole/Properties/launchSettings.json b/Tools/Seq2SeqConsole/Properties/launchSettings.json index a6e5e40f..058d307e 100644 --- a/Tools/Seq2SeqConsole/Properties/launchSettings.json +++ b/Tools/Seq2SeqConsole/Properties/launchSettings.json @@ -2,7 +2,7 @@ "profiles": { "Seq2SeqConsole": { "commandName": "Project", - "commandLineArgs": "-ConfigFilePath train_opts.json", + "commandLineArgs": "-ConfigFilePath train_opts.json -DataPassword abc", "workingDirectory": "C:\\Works\\Workspace\\MT\\ENU_CHS", "environmentVariables": { "MKL_ENABLE_INSTRUCTIONS": "AVX2"