Skip to content

Commit

Permalink
1. Allow to rewrite settings in command line
Browse files Browse the repository at this point in the history
2. Use KV cache for test by default
3. Support password protected indexed data set in zip format
  • Loading branch information
zhongkaifu committed Nov 5, 2024
1 parent 113dd85 commit 5faf982
Show file tree
Hide file tree
Showing 12 changed files with 342 additions and 149 deletions.
79 changes: 53 additions & 26 deletions AdvUtils/Arg/ArgParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,48 @@ public class ArgParser
object m_o;
List<ArgField> m_arrayArgs;

public ArgParser(string[] args, object o)
{
m_o = o;
public static void UpdateFieldValue(object obj, string fieldName, string newValue)
{
// Get the Type of the object
Type objType = obj.GetType();

// Get the FieldInfo for the specified field name
FieldInfo fieldInfo = objType.GetField(fieldName, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance);

Check warning on line 19 in AdvUtils/Arg/ArgParser.cs

View workflow job for this annotation

GitHub Actions / build

Converting null literal or possible null value to non-nullable type.

// If the field is found, set the new value
if (fieldInfo != null)
{
fieldInfo.SetValue(obj, newValue);
}
else
{
Console.WriteLine($"Field '{fieldName}' not found in the class.");
}
}

public ArgParser(string[] args, object o)
{
m_o = o;
m_arrayArgs = new List<ArgField>();
Type typeArgAttr = typeof(Arg);
Type t = o.GetType();
foreach (FieldInfo fi in t.GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance))
{
foreach (Arg arg in fi.GetCustomAttributes(typeArgAttr, true))
{
m_arrayArgs.Add(new ArgField(o, fi, arg));
}
}

try
{
for (int i = 0; i < args.Length; i++)
{
Type typeArgAttr = typeof(Arg);
Type t = o.GetType();
foreach (FieldInfo fi in t.GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance))
{
foreach (Arg arg in fi.GetCustomAttributes(typeArgAttr, true))
{
m_arrayArgs.Add(new ArgField(o, fi, arg));
}
}

RewriteSettings(args, o);
}

public void RewriteSettings(string[] args, object o)
{
try
{
for (int i = 0; i < args.Length; i++)
{
if (args[i].StartsWith("-"))
{
string strArgName = args[i].Substring(1);
Expand All @@ -41,19 +65,22 @@ public ArgParser(string[] args, object o)

intarg.Set(strArgValue);

Console.WriteLine($"Rewrite field '{strArgName}' value.");
UpdateFieldValue(o, strArgName, strArgValue);

i++;
}
}
}

foreach (ArgField a in m_arrayArgs)
a.Validate();
}
catch (Exception err)
{
foreach (ArgField a in m_arrayArgs)
a.Validate();
}
catch (Exception err)
{
Console.Error.WriteLine(err.Message);
Usage();
}
}
Usage();
}
}

ArgField? GetArgByName(string name)
{
Expand Down
7 changes: 5 additions & 2 deletions Seq2SeqSharp/Applications/Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ public class Options
[Range(0, 9999999)]
public int StartBatchId = 0;

[Arg("Zip password for dataset. The defulat value is empty", nameof(DataPassword))]
public string DataPassword = "";

[Arg("The max degress of parallelism in task. Default is 1", nameof(TaskParallelism))]
[Range(1, 999)]
public int TaskParallelism = 1;
Expand Down Expand Up @@ -307,8 +310,8 @@ public class Options
[Range(-1, 9999999)]
public int RandomSeed = -1;

[Arg("Use KV Cache in test mode. The default value is false", nameof(UseKVCache))]
public bool UseKVCache = false;
[Arg("Use KV Cache in test mode. The default value is true", nameof(UseKVCache))]
public bool UseKVCache = true;

[Arg("Initial loss Scaling when AMP is enabled. Default is 1 which is disabled.", nameof(InitLossScaling))]
[Range(1, 65000)]
Expand Down
132 changes: 75 additions & 57 deletions Seq2SeqSharp/Corpus/MonoCorpus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public class IndexData
private int m_batchNumInTotal = 0;
private int m_startBatchId = 0;

private string m_dataPassword = String.Empty;

public List<Dictionary<string, long>> CountTokenFreqs()
{
List<Dictionary<string, long>> td = new List<Dictionary<string, long>>();
Expand Down Expand Up @@ -201,15 +203,6 @@ private IndexData[] BuildIndex()
indexDatas[i].len2offsets= len2offsets;
indexDatas[i].filePath = binaryDataSetFilePath;
});









Logger.WriteLine(Logger.Level.debug, $"Shuffled '{corpusSize}' sentence pairs.");

if (tooLongTgtSntCnt > 0)
Expand Down Expand Up @@ -369,74 +362,98 @@ public IEnumerator<T> GetEnumerator()

int batchIdx = 0;
int currentBatchPercent = 0;
MemoryMappedFile mmf = null;
MemoryMappedViewStream mms = null;
ZipDecompressor decompressor = null;
if (m_indexedDataSetFilePath.ToLower().EndsWith(".zip"))
{
Logger.WriteLine($"The data set is a zip archive.");
decompressor = new ZipDecompressor(m_indexedDataSetFilePath, m_dataPassword);
mms = decompressor.GetMemoryMappedViewStream();
}
else
{
mmf = MemoryMappedFile.CreateFromFile(m_indexedDataSetFilePath);
mms = mmf.CreateViewStream();
}

using (MemoryMappedFile mmf = MemoryMappedFile.CreateFromFile(m_indexedDataSetFilePath))
using (MemoryMappedViewStream mms = mmf.CreateViewStream())
using (BinaryReader br = new BinaryReader(mms))
{
using (BinaryReader br = new BinaryReader(mms))
while (true)
{
while (true)
int sizeInBatch = br.ReadInt32();
if (sizeInBatch < 0)
{
int sizeInBatch = br.ReadInt32();
if (sizeInBatch < 0)
{
break;
}

List<IPair> outputs = new List<IPair>();
string[] tgtLines = br.ReadString().Split("\n");
batchIdx++;

if (batchIdx < m_startBatchId)
{
continue;
}
break;
}

if (batchIdx % 10000 == 0)
{
Logger.WriteLine($"Processing batch '{batchIdx}'");
}
List<IPair> outputs = new List<IPair>();
string[] tgtLines = br.ReadString().Split("\n");
batchIdx++;

T batch;
int currentTokenCountsInBatch = 0;
for (int i = 0; i < sizeInBatch; i++)
{
var tgtLine = tgtLines[i];
if (batchIdx < m_startBatchId)
{
continue;
}

if (m_batchNumInTotal > 0)
{
if ((100 * batchIdx / m_batchNumInTotal) > currentBatchPercent)
{
Logger.WriteLine($"Processing batch '{batchIdx}/{m_batchNumInTotal}'."); // The '{i}th' record in this batch is: Target = '{tgtLine}'");
currentBatchPercent++;
}
}
if (batchIdx % 10000 == 0)
{
Logger.WriteLine($"Processing batch '{batchIdx}'");
}

IPair sntPair = new SntPair(tgtLine, tgtLine);
currentTokenCountsInBatch += sntPair.GetTgtTokenCount();
outputs.Add(sntPair);
T batch;
int currentTokenCountsInBatch = 0;
for (int i = 0; i < sizeInBatch; i++)
{
var tgtLine = tgtLines[i];

if (currentTokenCountsInBatch >= m_maxTokenSizePerBatch)
if (m_batchNumInTotal > 0)
{
if ((100 * batchIdx / m_batchNumInTotal) > currentBatchPercent)
{
batch = new T();
batch.CreateBatch(outputs);
yield return batch;

outputs = new List<IPair>();
currentTokenCountsInBatch = 0;
Logger.WriteLine($"Processing batch '{batchIdx}/{m_batchNumInTotal}'."); // The '{i}th' record in this batch is: Target = '{tgtLine}'");
currentBatchPercent++;
}
}

if (outputs.Count > 0)
IPair sntPair = new SntPair(tgtLine, tgtLine);
currentTokenCountsInBatch += sntPair.GetTgtTokenCount();
outputs.Add(sntPair);

if (currentTokenCountsInBatch >= m_maxTokenSizePerBatch)
{
batch = new T();
batch.CreateBatch(outputs);
yield return batch;

outputs = new List<IPair>();
currentTokenCountsInBatch = 0;
}
}

if (outputs.Count > 0)
{
batch = new T();
batch.CreateBatch(outputs);
yield return batch;
}
}
}

if (mms != null)
{
mms.Dispose();
}
if (mmf != null)
{
mmf.Dispose();
}

if (decompressor != null)
{
decompressor.Dispose();
}

File.Delete(m_indexedDataSetFilePath);
}

Expand All @@ -450,16 +467,17 @@ public MonoCorpus()

}

public MonoCorpus(string corpusFilePath, string tgtLangName, int maxTokenSizePerBatch, int maxTgtSentLength = 32, PaddingEnums paddingEnums = PaddingEnums.AllowPadding, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = "", int startBatchId = 0)
public MonoCorpus(string corpusFilePath, string tgtLangName, int maxTokenSizePerBatch, int maxTgtSentLength = 32, PaddingEnums paddingEnums = PaddingEnums.AllowPadding, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = "", int startBatchId = 0, string dataPassword = "")
{
Logger.WriteLine($"Loading mono corpus from '{corpusFilePath}' Files search pattern '*.{tgtLangName}.snt' MaxTgtSentLength = '{maxTgtSentLength}', Token Padding Type = '{paddingEnums}', TooLongSequence = '{tooLongSequence}'");
Logger.WriteLine($"Loading mono corpus from '{corpusFilePath}' Files search pattern '*.{tgtLangName}.snt' MaxTgtSentLength = '{maxTgtSentLength}', Token Padding Type = '{paddingEnums}', TooLongSequence = '{tooLongSequence}', Encrypted data set = '{!String.IsNullOrEmpty(dataPassword)}'");
m_maxTokenSizePerBatch = maxTokenSizePerBatch;
m_maxTgtTokenSize = maxTgtSentLength;
m_tooLongSequence = tooLongSequence;
m_paddingEnums = paddingEnums;
CorpusName = corpusFilePath;
m_indexedDataSetFilePath = indexedFilePath;
m_startBatchId = startBatchId;
m_dataPassword = dataPassword;

m_tgtFileList = new List<string>();
string[] files = Directory.GetFiles(corpusFilePath, $"*.{tgtLangName}.snt", SearchOption.TopDirectoryOnly);
Expand Down
Loading

0 comments on commit 5faf982

Please sign in to comment.