Commit

Updated training results
liebharc committed Sep 10, 2024
1 parent ba12ebe · commit 91ff1e0
Showing 2 changed files with 28 additions and 3 deletions.
19 changes: 18 additions & 1 deletion Training.md
@@ -13,6 +13,24 @@ an improvement or fix was done to the conversion.

Depending on the hardware you use, the training process itself takes 2-3 days.

## Run 101

Commit: ba12ebef4606948816a06f4a011248d07a6f06da
Date: 10 Sep 2024
SER: 9%
Validation result: 6.4

Training runs now also pick the last iteration rather than the one with the lowest validation loss.
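
As a rough sketch of this selection change (the records below are hypothetical, not homr's actual training code):

```python
# Hypothetical checkpoint records; homr's real training loop is not shown here.
checkpoints = [
    {"iteration": 1, "val_loss": 8.1},
    {"iteration": 2, "val_loss": 6.9},
    {"iteration": 3, "val_loss": 7.2},
]

# Old behavior: keep the checkpoint with the lowest validation loss.
best_by_val_loss = min(checkpoints, key=lambda c: c["val_loss"])  # iteration 2

# New behavior: keep the checkpoint from the last iteration.
last_iteration = checkpoints[-1]  # iteration 3
```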

## Run 100

Commit: e317d1ba4452798036d2b24a20f37061b8441bae
Date: 10 Sep 2024
SER: 14%
Validation result: 7.6

Increased model depth from 4 to 8.
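
Depth here presumably means the number of stacked transformer layers; as a generic PyTorch illustration (not homr's actual model definition):

```python
import torch.nn as nn

# Generic sketch only: TrOMR's real architecture and hyperparameters differ.
def make_decoder(depth: int) -> nn.TransformerDecoder:
    layer = nn.TransformerDecoderLayer(d_model=256, nhead=8, batch_first=True)
    return nn.TransformerDecoder(layer, num_layers=depth)

shallow = make_decoder(4)  # depth before this run
deep = make_decoder(8)     # depth used from run 100 on
```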

## Run 70

Commit: 11c1eeaf5760d617f09678e276866d31253a5ace
@@ -67,7 +85,6 @@ SER: 53%

First training run within the `homr` repo.


## Previous runs

These runs were performed inside the [Polyphonic-TrOMR](https://raw.githubusercontent.com/liebharc/Polyphonic-TrOMR/master/Training.md) repo.
12 changes: 10 additions & 2 deletions training/transformer/train.py
@@ -2,6 +2,7 @@
 import shutil
 import sys
 
+import safetensors
 import torch
 import torch._dynamo
 from transformers import Trainer, TrainingArguments  # type: ignore
@@ -150,8 +151,15 @@ def train_transformer(fp32: bool = False, pretrained: bool = False, resume: str
     if pretrained:
         eprint("Loading pretrained model")
         model = TrOMR(config)
-        tr_omr_pretrained = config.filepaths.checkpoint
-        model.load_state_dict(torch.load(tr_omr_pretrained), strict=False)
+        checkpoint_file_path = config.filepaths.checkpoint
+        if ".safetensors" in checkpoint_file_path:
+            tensors = {}
+            with safetensors.safe_open(checkpoint_file_path, framework="pt", device=0) as f:  # type: ignore
+                for k in f.keys():
+                    tensors[k] = f.get_tensor(k)
+            model.load_state_dict(tensors, strict=False)
+        else:
+            model.load_state_dict(torch.load(checkpoint_file_path), strict=False)
     else:
         model = TrOMR(config)

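For context, a minimal standalone round-trip with the safetensors API (the file name and model below are placeholders, not values from the homr configuration) shows where such a checkpoint comes from and how the new branch reads it back:

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Placeholder model; the real code loads homr's TrOMR instead.
model = torch.nn.Linear(4, 4)

# Producing a .safetensors checkpoint.
save_file(model.state_dict(), "checkpoint.safetensors")

# Reading it back, mirroring the new branch in train.py.
tensors = {}
with safe_open("checkpoint.safetensors", framework="pt", device="cpu") as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)
model.load_state_dict(tensors, strict=False)
```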
