Skip to content

Commit

Permalink
Passing statistical engine service tests - first pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Enkidu93 committed Jan 3, 2025
1 parent 23c3e22 commit 6da6f88
Show file tree
Hide file tree
Showing 5 changed files with 512 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public class StatisticalPostprocessBuildJob(
ILogger<StatisticalPostprocessBuildJob> logger,
ISharedFileService sharedFileService,
IDistributedReaderWriterLockFactory lockFactory,
ISmtModelFactory smtModelFactory,
IWordAlignmentModelFactory wordAlignmentModelFactory,
IOptionsMonitor<BuildJobOptions> buildOptions,
IOptionsMonitor<WordAlignmentEngineOptions> engineOptions
)
Expand All @@ -22,7 +22,7 @@ IOptionsMonitor<WordAlignmentEngineOptions> engineOptions
buildOptions
)
{
private readonly ISmtModelFactory _smtModelFactory = smtModelFactory;
private readonly IWordAlignmentModelFactory _wordAlignmentModelFactory = wordAlignmentModelFactory;
private readonly IOptionsMonitor<WordAlignmentEngineOptions> _engineOptions = engineOptions;
private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory;

Expand All @@ -38,7 +38,7 @@ CancellationToken cancellationToken

await using (
Stream wordAlignmentStream = await SharedFileService.OpenReadAsync(
$"builds/{buildId}/word_alignment_outputs.json",
$"builds/{buildId}/word_alignments.outputs.json",
cancellationToken
)
)
Expand Down Expand Up @@ -74,7 +74,7 @@ protected override async Task<int> SaveModelAsync(string engineId, string buildI
Stream engineStream = await SharedFileService.OpenReadAsync($"builds/{buildId}/model.tar.gz", ct)
)
{
await _smtModelFactory.UpdateEngineFromAsync(
await _wordAlignmentModelFactory.UpdateEngineFromAsync(
Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId),
engineStream,
ct
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public void Remove(string engineId)

public async Task CommitAsync(
IDistributedReaderWriterLockFactory lockFactory,
IRepository<TranslationEngine> engines,
IRepository<WordAlignmentEngine> engines,
TimeSpan inactiveTimeout,
CancellationToken cancellationToken = default
)
Expand All @@ -40,10 +40,10 @@ public async Task CommitAsync(
await @lock.WriterLockAsync(
async ct =>
{
TranslationEngine? engine = await engines.GetAsync(state.EngineId, ct);
if (engine is not null && !(engine.CollectTrainSegmentPairs ?? false))
WordAlignmentEngine? engine = await engines.GetAsync(state.EngineId, ct);
if (engine is not null)
// there is no way to cancel this call
state.Commit(engine.BuildRevision, inactiveTimeout);
state.Commit(engine!.BuildRevision, inactiveTimeout);
},
_options.CurrentValue.EngineCommitTimeout,
cancellationToken: cancellationToken
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ CancellationToken cancellationToken
new(await SharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken));

await using Stream inferenceStream = await SharedFileService.OpenWriteAsync(
$"builds/{buildId}/word_alignment_inputs.json",
$"builds/{buildId}/word_alignments.inputs.json",
cancellationToken
);
await using Utf8JsonWriter inferenceWriter = new(inferenceStream, InferenceWriterOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public async Task CancelBuildAsync_Building(BuildJobRunnerType trainJobRunnerTyp
await env.WaitForTrainingToStartAsync();
TranslationEngine engine = env.Engines.Get(EngineId1);
Assert.That(engine.CurrentBuild, Is.Not.Null);
Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active));
Assert.That(engine.CurrentBuild!.JobState, Is.EqualTo(BuildJobState.Active));
await env.Service.CancelBuildAsync(EngineId1);
await env.WaitForBuildToFinishAsync();
_ = env.SmtBatchTrainer.DidNotReceive().SaveAsync();
Expand All @@ -122,12 +122,12 @@ public async Task StartBuildAsync_RestartUnfinishedBuild()
await env.WaitForTrainingToStartAsync();
TranslationEngine engine = env.Engines.Get(EngineId1);
Assert.That(engine.CurrentBuild, Is.Not.Null);
Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active));
Assert.That(engine.CurrentBuild!.JobState, Is.EqualTo(BuildJobState.Active));
env.StopServer();
await env.WaitForBuildToRestartAsync();
engine = env.Engines.Get(EngineId1);
Assert.That(engine.CurrentBuild, Is.Not.Null);
Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Pending));
Assert.That(engine.CurrentBuild!.JobState, Is.EqualTo(BuildJobState.Pending));
_ = env.PlatformService.Received().BuildRestartingAsync(BuildId1);
env.SmtBatchTrainer.ClearSubstitute(ClearOptions.CallActions);
env.StartServer();
Expand All @@ -147,7 +147,7 @@ public async Task DeleteAsync_WhileBuilding(BuildJobRunnerType trainJobRunnerTyp
await env.WaitForTrainingToStartAsync();
TranslationEngine engine = env.Engines.Get(EngineId1);
Assert.That(engine.CurrentBuild, Is.Not.Null);
Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active));
Assert.That(engine.CurrentBuild!.JobState, Is.EqualTo(BuildJobState.Active));
await env.Service.DeleteAsync(EngineId1);
await env.WaitForBuildToFinishAsync();
await env.WaitForAllHangfireJobsToFinishAsync();
Expand All @@ -167,7 +167,7 @@ public async Task TrainSegmentPairAsync(BuildJobRunnerType trainJobRunnerType)
await env.WaitForBuildToStartAsync();
TranslationEngine engine = env.Engines.Get(EngineId1);
Assert.That(engine.CurrentBuild, Is.Not.Null);
Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active));
Assert.That(engine.CurrentBuild!.JobState, Is.EqualTo(BuildJobState.Active));
await env.Service.TrainSegmentPairAsync(EngineId1, "esto es una prueba.", "this is a test.", true);
env.StopTraining();
await env.WaitForBuildToFinishAsync();
Expand Down
Loading

0 comments on commit 6da6f88

Please sign in to comment.