Skip to content

Commit

Permalink
Merge pull request #93 from twodawg/SeqMedicalExample
Browse files Browse the repository at this point in the history
Seq medical example
  • Loading branch information
zhongkaifu authored Nov 23, 2024
2 parents 86ef3ec + 5d61c4f commit 7604e8c
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,5 @@ Temporary Items
/Tools/SeqDictMatchConsole/bin/Release/net8.0
/Tests/Seq2SeqSharp.Tests/obj/Debug/net8.0
/Tools/SeqDictMatchConsole/bin/Debug/net8.0
/Tools/SeqMedical/obj
/Tools/SeqMedical/bin
9 changes: 9 additions & 0 deletions Seq2SeqSharp.sln
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "PythonPackage", "PythonPack
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ImgSeqConsole", "Tools\ImgSeqConsole\ImgSeqConsole.csproj", "{D5B59E92-8BFF-4B30-844B-E95E67D5A68B}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SeqMedical", "Tools\SeqMedical\SeqMedical.csproj", "{21C7B227-5C67-4AB4-828F-A2E142415076}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -186,6 +188,12 @@ Global
{D5B59E92-8BFF-4B30-844B-E95E67D5A68B}.Release|Any CPU.Build.0 = Release|Any CPU
{D5B59E92-8BFF-4B30-844B-E95E67D5A68B}.ReleaseCpuOnly|Any CPU.ActiveCfg = Release|Any CPU
{D5B59E92-8BFF-4B30-844B-E95E67D5A68B}.ReleaseCpuOnly|Any CPU.Build.0 = Release|Any CPU
{21C7B227-5C67-4AB4-828F-A2E142415076}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{21C7B227-5C67-4AB4-828F-A2E142415076}.Debug|Any CPU.Build.0 = Debug|Any CPU
{21C7B227-5C67-4AB4-828F-A2E142415076}.Release|Any CPU.ActiveCfg = Release|Any CPU
{21C7B227-5C67-4AB4-828F-A2E142415076}.Release|Any CPU.Build.0 = Release|Any CPU
{21C7B227-5C67-4AB4-828F-A2E142415076}.ReleaseCpuOnly|Any CPU.ActiveCfg = Release|Any CPU
{21C7B227-5C67-4AB4-828F-A2E142415076}.ReleaseCpuOnly|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand All @@ -204,6 +212,7 @@ Global
{DFEE8ACE-4935-40D1-8B9B-1E9F7FFC6FAE} = {C15C991E-2657-4CF3-A976-84334A25DBD2}
{4DBA9DDA-569C-4F31-9C98-84837D2F8148} = {C2DFE174-7167-41D4-A8D2-EC8DC54AA71E}
{D5B59E92-8BFF-4B30-844B-E95E67D5A68B} = {C2DFE174-7167-41D4-A8D2-EC8DC54AA71E}
{21C7B227-5C67-4AB4-828F-A2E142415076} = {C2DFE174-7167-41D4-A8D2-EC8DC54AA71E}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {CAE1535E-6AF4-4CD0-8E90-EBACD99D865A}
Expand Down
152 changes: 152 additions & 0 deletions Tools/SeqMedical/Program.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
using AdvUtils;
using Seq2SeqSharp;
using Seq2SeqSharp.Applications;
using Seq2SeqSharp.Corpus;
using Seq2SeqSharp.Enums;
using Seq2SeqSharp.LearningRate;
using Seq2SeqSharp.Metrics;
using Seq2SeqSharp.Optimizer;
using Seq2SeqSharp.Utils;
using System.Diagnostics;

// Training sentences (free-text symptom descriptions). Entry i of this list is paired with
// entry i of `labels` below, so the two lists must always have the same length and order.
// NOTE(review): "fever and chills" appears twice and is paired with two different labels
// ("flu" and "endocarditis") — confirm that this ambiguity is intended.
List<string> inputSentences =
[
    "Chest pain with numbness",
    "Pressure or tightness in the chest",
    "Excessive sweating",
    "fever and chills",
    "headache and body pain",
    "cough and nausea",
    "blood in urine",
    "fever and chills",
    "heart murmur",
    "bumps on hands",
    "nausea",
];

// Classification label for each sentence above (same index). A label cannot contain spaces,
// so multi-word labels are joined with underscores.
List<string> labels =
[
    "heart_attack",
    "heart_attack",
    "heart_attack",
    "flu",
    "flu",
    "flu",
    "endocarditis",
    "endocarditis",
    "endocarditis",
    "endocarditis",
    "endocarditis",
];

// Set to true to (re)train the model from `inputSentences`/`labels` (required whenever they
// change). Leave false to load the previously trained model and classify symptoms typed at
// the console, one prompt per loop iteration.
var training = false;

// Working directory for the corpus, the model and the scratch input/output files.
// Defaults to the original hard-coded location; set the SEQMEDICAL_ROOT environment
// variable to point the tool at another directory.
var rootDir = Environment.GetEnvironmentVariable("SEQMEDICAL_ROOT") is { Length: > 0 } rootOverride
    ? rootOverride
    : "D:\\Temp\\";

do
{
    string? prompt = null;
    if (!training)
    {
        Console.WriteLine("What are your symptoms?");
        prompt = Console.ReadLine();
        if (prompt is null)
        {
            // Console.ReadLine returns null once standard input is closed (Ctrl+Z / Ctrl+D or
            // exhausted redirected input). Without this exit the loop would spin forever.
            break;
        }
    }

    // Define model parameters
    var opts = new SeqClassificationOptions
    {
        // Training writes to model.bin; test mode loads the file named model.bin.100.
        ModelFilePath = Path.Combine(rootDir, training ? "model.bin" : "model.bin.100"),
        SrcLang = "Src",
        TgtLang = "Labels",
        ProcessorType = ProcessorTypeEnums.CPU,
        TrainCorpusPath = Path.Combine(rootDir, "train"),
        ValidCorpusPaths = Path.Combine(rootDir, "valid"),
        LogDestination = Logger.Destination.Console,
        TaskParallelism = 4,

        MaxEpochNum = 100,
        EncoderLayerDepth = 6, // Increases the model size

        Task = training ? ModeEnums.Train : ModeEnums.Test,
        InputTestFile = Path.Combine(rootDir, "input.txt"),
        OutputFile = Path.Combine(rootDir, "output.txt"),
        // Keep the interactive prompt quiet: only errors are logged outside of training.
        LogLevel = training ? Logger.Level.info : Logger.Level.err,
    };
    Logger.Initialize(opts.LogDestination, opts.LogLevel, $"{opts.Task}_{Utils.GetTimeStamp(DateTime.Now)}.log");

    DecodingOptions decodingOptions = opts.CreateDecodingOptions();

    if (opts.Task == ModeEnums.Train)
    {
        // Sentences and labels are paired by index; fail with a clear message instead of an
        // index-out-of-range error (or a silently truncated label file) when they drift apart.
        if (inputSentences.Count != labels.Count)
        {
            throw new InvalidOperationException(
                $"Training data mismatch: {inputSentences.Count} sentences but {labels.Count} labels.");
        }

        // Write the sentences and their labels as parallel line-aligned files
        // (train.src.snt / train.labels.snt) into the training corpus directory.
        Directory.CreateDirectory(opts.TrainCorpusPath);
        File.WriteAllLines(Path.Combine(opts.TrainCorpusPath, "train.src.snt"), inputSentences);
        File.WriteAllLines(Path.Combine(opts.TrainCorpusPath, "train.labels.snt"), labels);

        // The validation corpus is a copy of the training data.
        Directory.CreateDirectory(opts.ValidCorpusPaths);
        File.WriteAllLines(Path.Combine(opts.ValidCorpusPaths, "valid.src.snt"), inputSentences);
        File.WriteAllLines(Path.Combine(opts.ValidCorpusPaths, "valid.labels.snt"), labels);

        // Prepare data for Seq2SeqSharp
        // Load train corpus
        var trainCorpus = new SeqClassificationMultiTasksCorpus(corpusFilePath: opts.TrainCorpusPath,
            srcLangName: opts.SrcLang, tgtLangName: opts.TgtLang, maxTokenSizePerBatch: opts.MaxTokenSizePerBatch,
            maxSentLength: opts.MaxSentLength, paddingEnums: opts.PaddingType, tooLongSequence: opts.TooLongSequence);

        // Valid corpus
        List<SeqClassificationMultiTasksCorpus> validCorpusList =
        [
            new SeqClassificationMultiTasksCorpus(opts.ValidCorpusPaths,
                srcLangName: opts.SrcLang, tgtLangName: opts.TgtLang, opts.ValMaxTokenSizePerBatch,
                opts.MaxSentLength, paddingEnums: opts.PaddingType, tooLongSequence: opts.TooLongSequence),
        ];

        // Create optimizer
        IOptimizer optimizer = Misc.CreateOptimizer(opts);

        var (srcVocab, tgtVocab) = trainCorpus.BuildVocabs(opts.SrcVocabSize, opts.TgtVocabSize);

        // Create metrics: a single task (id 0) scored with a multi-label F-score over the label vocabulary.
        var taskId2metrics = new Dictionary<int, List<IMetric>>
        {
            [0] = [new MultiLabelsFscoreMetric("", tgtVocab.GetAllTokens(keepBuildInTokens: false))],
        };

        // Create learning rate
        ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount,
            opts.LearningRateStepDownFactor, opts.UpdateNumToStepDownLearningRate);

        // Train the model
        var ss = new SeqClassification(opts, srcVocab, tgtVocab);

        // Add event handlers for monitoring
        ss.StatusUpdateWatcher += Misc.Ss_StatusUpdateWatcher;
        ss.EvaluationWatcher += Ss_EvaluationWatcher;

        ss.Train(opts.MaxEpochNum, trainCorpus, validCorpusList.ToArray(),
            learningRate, taskId2metrics: taskId2metrics, optimizer, decodingOptions);

        Console.WriteLine("Training complete!");
    }
    else if (opts.Task == ModeEnums.Test)
    {
        // The test API is file based: the prompt goes in through InputTestFile and the
        // prediction comes back through OutputFile.
        File.WriteAllText(opts.InputTestFile, prompt);

        if (File.Exists(opts.OutputFile))
        {
            Logger.WriteLine(Logger.Level.info, ConsoleColor.Yellow, $"Output file '{opts.OutputFile}' exist. Delete it.");
            File.Delete(opts.OutputFile);
        }

        // Test trained model
        // NOTE(review): the model is re-created for every prompt; creating it once outside the
        // loop would likely be faster — confirm that SeqClassification can be reused across Test calls.
        var ss = new SeqClassification(opts);
        Stopwatch stopwatch = Stopwatch.StartNew();

        ss.Test<SeqClassificationMultiTasksCorpusBatch>(opts.InputTestFile, opts.OutputFile, opts.BatchSize, decodingOptions, opts.SrcSentencePieceModelPath, opts.TgtSentencePieceModelPath);

        stopwatch.Stop();

        Logger.WriteLine($"Test mode execution time elapsed: '{stopwatch.Elapsed}'");

        if (File.Exists(opts.OutputFile))
        {
            Console.WriteLine(File.ReadAllText(opts.OutputFile));
        }
        else
        {
            // Keep the prompt loop alive instead of crashing with FileNotFoundException.
            Console.Error.WriteLine($"No output file was produced at '{opts.OutputFile}'.");
        }
    }

    // Forwards evaluation messages raised during training to the logger, in the colour
    // supplied by the event. Events of any other argument type are ignored.
    void Ss_EvaluationWatcher(object sender, EventArgs e)
    {
        if (e is EvaluationEventArg ep)
        {
            Logger.WriteLine(Logger.Level.info, ep.Color, ep.Message);
        }
    }
}
while (!training); // Training runs once; test mode keeps prompting until input ends.

14 changes: 14 additions & 0 deletions Tools/SeqMedical/SeqMedical.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
<!-- SDK-style project for the SeqMedical example tool (see Program.cs in this directory). -->
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<!-- Built as a console executable. -->
<OutputType>Exe</OutputType>
<TargetFramework>net8.0</TargetFramework>
<!-- Program.cs relies on implicit global usings (it has no explicit using for System.IO or System.Linq). -->
<ImplicitUsings>enable</ImplicitUsings>
<!-- Nullable reference type analysis is on for this project. -->
<Nullable>enable</Nullable>
</PropertyGroup>

<ItemGroup>
<!-- The tool builds against the Seq2SeqSharp library project in this repository. -->
<ProjectReference Include="..\..\Seq2SeqSharp\Seq2SeqSharp.csproj" />
</ItemGroup>

</Project>

0 comments on commit 7604e8c

Please sign in to comment.