Skip to content

Commit

Permalink
Add hangfire implementation and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Enkidu93 committed Jan 6, 2025
1 parent 6da6f88 commit bbc3248
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ await _client.UpdateBuildStatusAsync(

public async Task InsertInferencesAsync(
string engineId,
Stream pretranslationsStream,
Stream inferenceStream,
CancellationToken cancellationToken = default
)
{
Expand All @@ -117,7 +117,7 @@ await _outboxService.EnqueueMessageAsync(
ServalTranslationPlatformOutboxConstants.InsertInferences,
engineId,
engineId,
pretranslationsStream,
inferenceStream,
cancellationToken: cancellationToken
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ public class StatisticalTrainBuildJob(
IRepository<WordAlignmentEngine> engines,
IDataAccessContext dataAccessContext,
IBuildJobService<WordAlignmentEngine> buildJobService,
ILogger<StatisticalTrainBuildJob> logger
ILogger<StatisticalTrainBuildJob> logger,
ISharedFileService sharedFileService,
IWordAlignmentModelFactory wordAlignmentModelFactory
)
: HangfireBuildJob<WordAlignmentEngine>(
platformServices.First(ps => ps.EngineGroup == EngineGroup.WordAlignment),
Expand All @@ -15,14 +17,179 @@ ILogger<StatisticalTrainBuildJob> logger
logger
)
{
protected override Task DoWorkAsync(
private static readonly JsonWriterOptions WordAlignmentWriterOptions = new() { Indented = true };
private static readonly JsonSerializerOptions JsonSerializerOptions =
new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase };
private const int BatchSize = 128;

private readonly ISharedFileService _sharedFileService = sharedFileService;
private readonly IWordAlignmentModelFactory _wordAlignmentFactory = wordAlignmentModelFactory;

protected override async Task DoWorkAsync(
string engineId,
string buildId,
object? data,
string? buildOptions,
CancellationToken cancellationToken
)
{
throw new NotImplementedException();
using TempDirectory tempDir = new(buildId);
string corpusDir = Path.Combine(tempDir.Path, "corpus");
await DownloadDataAsync(buildId, corpusDir, cancellationToken);

// assemble corpus
ITextCorpus sourceCorpus = new TextFileTextCorpus(Path.Combine(corpusDir, "train.src.txt"));
ITextCorpus targetCorpus = new TextFileTextCorpus(Path.Combine(corpusDir, "train.trg.txt"));
IParallelTextCorpus parallelCorpus = sourceCorpus.AlignRows(targetCorpus);

// train word alignment model
string engineDir = Path.Combine(tempDir.Path, "engine");
int trainCount = await TrainAsync(buildId, engineDir, parallelCorpus, cancellationToken);

cancellationToken.ThrowIfCancellationRequested();

await GenerateWordAlignmentsAsync(buildId, engineDir, cancellationToken);

bool canceling = !await BuildJobService.StartBuildJobAsync(
BuildJobRunnerType.Hangfire,
EngineType.Statistical,
engineId,
buildId,
BuildStage.Postprocess,
buildOptions: buildOptions,
data: (trainCount, 0.0),
cancellationToken: cancellationToken
);
if (canceling)
throw new OperationCanceledException();
}

protected override async Task CleanupAsync(
string engineId,
string buildId,
object? data,
JobCompletionStatus completionStatus
)
{
if (completionStatus is JobCompletionStatus.Canceled)
{
try
{
await _sharedFileService.DeleteAsync($"builds/{buildId}/");
}
catch (Exception e)
{
Logger.LogWarning(e, "Unable to to delete job data for build {BuildId}.", buildId);
}
}
}

private async Task DownloadDataAsync(string buildId, string corpusDir, CancellationToken cancellationToken)
{
Directory.CreateDirectory(corpusDir);
await using Stream srcText = await _sharedFileService.OpenReadAsync(
$"builds/{buildId}/train.src.txt",
cancellationToken
);
await using FileStream srcFileStream = File.Create(Path.Combine(corpusDir, "train.src.txt"));
await srcText.CopyToAsync(srcFileStream, cancellationToken);

await using Stream tgtText = await _sharedFileService.OpenReadAsync(
$"builds/{buildId}/train.trg.txt",
cancellationToken
);
await using FileStream tgtFileStream = File.Create(Path.Combine(corpusDir, "train.trg.txt"));
await tgtText.CopyToAsync(tgtFileStream, cancellationToken);
}

private async Task<int> TrainAsync(
string buildId,
string engineDir,
IParallelTextCorpus parallelCorpus,
CancellationToken cancellationToken
)
{
_wordAlignmentFactory.InitNew(engineDir);
LatinWordTokenizer tokenizer = new();
using ITrainer wordAlignmentTrainer = _wordAlignmentFactory.CreateTrainer(engineDir, tokenizer, parallelCorpus);
cancellationToken.ThrowIfCancellationRequested();

var progress = new BuildProgress(PlatformService, buildId);
await wordAlignmentTrainer.TrainAsync(progress, cancellationToken);

int trainCorpusSize = wordAlignmentTrainer.Stats.TrainCorpusSize;

cancellationToken.ThrowIfCancellationRequested();

await wordAlignmentTrainer.SaveAsync(cancellationToken);

await using Stream engineStream = await _sharedFileService.OpenWriteAsync(
$"builds/{buildId}/model.tar.gz",
cancellationToken
);
await _wordAlignmentFactory.SaveEngineToAsync(engineDir, engineStream, cancellationToken);
return trainCorpusSize;
}

private async Task GenerateWordAlignmentsAsync(
string buildId,
string engineDir,
CancellationToken cancellationToken
)
{
await using Stream sourceStream = await _sharedFileService.OpenReadAsync(
$"builds/{buildId}/word_alignments.inputs.json",
cancellationToken
);

IAsyncEnumerable<Models.WordAlignment> wordAlignments = JsonSerializer
.DeserializeAsyncEnumerable<Models.WordAlignment>(sourceStream, JsonSerializerOptions, cancellationToken)
.OfType<Models.WordAlignment>();

await using Stream targetStream = await _sharedFileService.OpenWriteAsync(
$"builds/{buildId}/word_alignments.outputs.json",
cancellationToken
);
await using Utf8JsonWriter targetWriter = new(targetStream, WordAlignmentWriterOptions);

LatinWordTokenizer tokenizer = new();
LatinWordDetokenizer detokenizer = new();
using IWordAlignmentModel wordAlignmentModel = _wordAlignmentFactory.Create(engineDir);
await foreach (IReadOnlyList<Models.WordAlignment> batch in BatchAsync(wordAlignments))
{
(IReadOnlyList<string> Source, IReadOnlyList<string> Target)[] segments = batch
.Select(p => (p.SourceTokens, p.TargetTokens))
.ToArray();
IReadOnlyList<WordAlignmentMatrix> results = wordAlignmentModel.AlignBatch(segments);
foreach ((Models.WordAlignment wordAlignment, WordAlignmentMatrix result) in batch.Zip(results))
{
JsonSerializer.Serialize(
targetWriter,
wordAlignment with
{
Alignment = result.ToAlignedWordPairs().ToList()
},
JsonSerializerOptions
);
}
}
}

public static async IAsyncEnumerable<IReadOnlyList<Models.WordAlignment>> BatchAsync(
IAsyncEnumerable<Models.WordAlignment> wordAlignments
)
{
List<Models.WordAlignment> batch = [];
await foreach (Models.WordAlignment item in wordAlignments)
{
batch.Add(item);
if (batch.Count == BatchSize)
{
yield return batch;
batch = [];
}
}
if (batch.Count > 0)
yield return batch;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public async Task CreateAsync()
env.WordAlignmentModelFactory.Received().InitNew(engineDir);
}

// [TestCase(BuildJobRunnerType.Hangfire)] //TODO Implement hangfire?
[TestCase(BuildJobRunnerType.Hangfire)]
[TestCase(BuildJobRunnerType.ClearML)]
public async Task StartBuildAsync(BuildJobRunnerType trainJobRunnerType)
{
Expand Down Expand Up @@ -78,7 +78,7 @@ await env.Service.StartBuildAsync(
env.WordAlignmentModel.Received().Dispose();
}

// [TestCase(BuildJobRunnerType.Hangfire)] //TODO implement Hangfire?
[TestCase(BuildJobRunnerType.Hangfire)]
[TestCase(BuildJobRunnerType.ClearML)]
public async Task CancelBuildAsync_Building(BuildJobRunnerType trainJobRunnerType)
{
Expand All @@ -97,15 +97,15 @@ public async Task CancelBuildAsync_Building(BuildJobRunnerType trainJobRunnerTyp
Assert.That(engine.CurrentBuild, Is.Null);
}

// [TestCase(BuildJobRunnerType.Hangfire)] //TODO implement Hangfire?
[TestCase(BuildJobRunnerType.Hangfire)]
[TestCase(BuildJobRunnerType.ClearML)]
public void CancelBuildAsync_NotBuilding(BuildJobRunnerType trainJobRunnerType)
{
using var env = new TestEnvironment(trainJobRunnerType);
Assert.ThrowsAsync<InvalidOperationException>(() => env.Service.CancelBuildAsync(EngineId1));
}

// [TestCase(BuildJobRunnerType.Hangfire)] //TODO implement Hangfire?
[TestCase(BuildJobRunnerType.Hangfire)]
[TestCase(BuildJobRunnerType.ClearML)]
public async Task DeleteAsync_WhileBuilding(BuildJobRunnerType trainJobRunnerType)
{
Expand Down Expand Up @@ -170,6 +170,7 @@ public TestEnvironment(BuildJobRunnerType trainJobRunnerType = BuildJobRunnerTyp
PlatformService.EngineGroup.Returns(EngineGroup.WordAlignment);
WordAlignmentModel = Substitute.For<IWordAlignmentModel>();
WordAlignmentBatchTrainer = Substitute.For<ITrainer>();
WordAlignmentBatchTrainer.Stats.Returns(new TrainStats { TrainCorpusSize = 0 });
WordAlignmentModelFactory = CreateWordAlignmentModelFactory();
_lockFactory = new DistributedReaderWriterLockFactory(
new OptionsWrapper<ServiceOptions>(new ServiceOptions { ServiceId = "host" }),
Expand All @@ -188,7 +189,7 @@ public TestEnvironment(BuildJobRunnerType trainJobRunnerType = BuildJobRunnerTyp
[
new ClearMLBuildQueue()
{
EngineType = EngineType.Statistical.ToString().ToString(),
EngineType = EngineType.Statistical.ToString(),
ModelType = "thot",
DockerImage = "default",
Queue = "default"
Expand Down Expand Up @@ -488,7 +489,9 @@ public override object ActivateJob(Type jobType)
_env.Engines,
new MemoryDataAccessContext(),
_env.BuildJobService,
Substitute.For<ILogger<StatisticalTrainBuildJob>>()
Substitute.For<ILogger<StatisticalTrainBuildJob>>(),
_env.SharedFileService,
_env.WordAlignmentModelFactory
);
}
return base.ActivateJob(jobType);
Expand Down

0 comments on commit bbc3248

Please sign in to comment.