
Commit

Add job serialisation and save options. Save job to model dir when running
harrykeightley committed Oct 18, 2023
1 parent 7fe6d35 commit 3ab1288
Showing 3 changed files with 78 additions and 9 deletions.
41 changes: 37 additions & 4 deletions elpis/models/job.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import json
from copy import copy
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
@@ -11,9 +13,6 @@ def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata)


DEFAULT_METRICS = ("wer", "cer")


@dataclass
class ModelArguments:
"""
@@ -110,6 +109,10 @@ class ModelArguments:
},
)

def to_dict(self) -> Dict[str, Any]:
result = dict(self.__dict__)
return result


@dataclass
class DataArguments:
@@ -219,7 +222,7 @@ class DataArguments:
metadata={"help": "Whether the target text should be lower cased."},
)
eval_metrics: List[str] = list_field( # type: ignore
default=DEFAULT_METRICS,
default=["wer", "cer"],
metadata={
"help": "A list of metrics the model should be evaluated on. E.g. `('wer', 'cer')`"
},
@@ -299,6 +302,10 @@ class DataArguments:
},
)

def to_dict(self) -> Dict[str, Any]:
result = dict(self.__dict__)
return result


@dataclass
class Job:
@@ -334,6 +341,13 @@ def from_json(cls, file: Path) -> Job:
model_args=model_args, data_args=data_args, training_args=training_args
)

def save(self, path: Path, overwrite=True) -> None:
if not overwrite and path.is_file():
return

with open(path, "w") as out_file:
json.dump(self.to_dict(), out_file)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Job:
(
@@ -344,3 +358,22 @@ def from_dict(cls, data: Dict[str, Any]) -> Job:
return cls(
model_args=model_args, data_args=data_args, training_args=training_args
)

def to_dict(self) -> Dict[str, Any]:
return (
self.training_args.to_dict()
| self.data_args.to_dict()
| self.model_args.to_dict()
)

def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Job):
return False

job = __value

return (
self.training_args.to_dict() == job.training_args.to_dict()
and self.model_args == job.model_args
and self.data_args == job.data_args
)
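Taken together, the new to_dict, save, from_json and __eq__ members give Job a complete file round trip. A minimal usage sketch, assembled from this diff and the test fixture below; the model, dataset and path values are illustrative only:

    from pathlib import Path

    from transformers import TrainingArguments

    from elpis.models.job import DataArguments, Job, ModelArguments

    job = Job(
        model_args=ModelArguments("facebook/wav2vec2-base"),
        data_args=DataArguments(
            dataset_name_or_path="mozilla-foundation/common_voice_11_0",
            dataset_config_name="gn",
        ),
        training_args=TrainingArguments(output_dir="model"),
    )

    # Serialise to disk. With overwrite=False, save() returns without
    # touching an existing file rather than clobbering it.
    job.save(Path("job.json"))
    job.save(Path("job.json"), overwrite=False)  # no-op once the file exists

    # Restore and compare via the new __eq__.
    assert Job.from_json(Path("job.json")) == job

Since save() serialises through to_dict and json.dump, every field held by the three argument dataclasses must be JSON-serialisable for the round trip to succeed.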
11 changes: 6 additions & 5 deletions elpis/trainer/trainer.py
@@ -1,6 +1,6 @@
import warnings
from contextlib import nullcontext
from functools import partial, reduce
from functools import reduce
from pathlib import Path
from typing import Any, Iterable, Optional

@@ -48,6 +48,7 @@ def run_job(
cache_dir = job.model_args.cache_dir
Path(output_dir).mkdir(exist_ok=True, parents=True)

job.save(Path(output_dir) / "job.json")
set_seed(job.training_args.seed)

logger.info("Preparing Datasets...")
@@ -82,7 +83,7 @@
config.save_pretrained(output_dir) # type: ignore

try:
processor = AutoProcessor.from_pretrained(job.training_args.output_dir)
processor = AutoProcessor.from_pretrained(output_dir)
except (OSError, KeyError):
warnings.warn(
"Loading a processor from a feature extractor config that does not"
@@ -91,7 +92,7 @@
" `'processor_class': 'Wav2Vec2Processor'`",
FutureWarning,
)
processor = Wav2Vec2Processor.from_pretrained(job.training_args.output_dir)
processor = Wav2Vec2Processor.from_pretrained(output_dir)

data_collator = DataCollatorCTCWithPadding(processor=processor) # type: ignore

@@ -299,7 +300,7 @@ def last_checkpoint(job: Job) -> Optional[str]:
return checkpoint


def train(job: Job, trainer: Trainer, dataset: DatasetDict):
def train(job: Job, trainer: Trainer, dataset: DatasetDict | IterableDatasetDict):
if not job.training_args.do_train:
logger.info("Skipping training: `job.training_args.do_train` is false.")
return
@@ -327,7 +328,7 @@ def train(job: Job, trainer: Trainer, dataset: DatasetDict):
trainer.save_state()


def evaluate(job: Job, trainer: Trainer, dataset: DatasetDict):
def evaluate(job: Job, trainer: Trainer, dataset: DatasetDict | IterableDatasetDict):
if not job.training_args.do_eval:
logger.info("Skipping eval: `job.training_args.do_eval` is false.")
return
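Because run_job now writes the job into the model directory before training starts, a finished run carries its own configuration next to its artifacts. A sketch of restoring a run from that directory, assuming the job.json written above; the directory name is illustrative:

    from pathlib import Path

    from elpis.models.job import Job

    # run_job saves <output_dir>/job.json right after creating the
    # directory, alongside the config and processor files it writes later.
    output_dir = Path("model")
    restored = Job.from_json(output_dir / "job.json")

The related switch from job.training_args.output_dir to the local output_dir variable also keeps the processor loading pointed at the same directory the config was just saved to.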
35 changes: 35 additions & 0 deletions tests/models/test_job.py
@@ -0,0 +1,35 @@
from pathlib import Path

import pytest
from transformers import TrainingArguments

from elpis.models.job import DataArguments, Job, ModelArguments


@pytest.fixture
def job(tmp_path: Path):
model_dir = tmp_path / "model"

return Job(
model_args=ModelArguments(
"facebook/wav2vec2-base",
),
data_args=DataArguments(
dataset_name_or_path="mozilla-foundation/common_voice_11_0",
dataset_config_name="gn",
),
training_args=TrainingArguments(output_dir=str(model_dir)),
)


def test_save_job(tmp_path: Path, job: Job):
file = tmp_path / "job.json"
job.save(file)

assert file.is_file()
assert Job.from_json(file) == job


def test_job_serialization(job: Job):
data = job.to_dict()
assert Job.from_dict(data) == job

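test_job_serialization round-trips through the flat dict that Job.to_dict builds with Python 3.9's dict-union operator. That merge is lossless only while the three argument dataclasses share no field names, because later operands shadow earlier ones; a tiny illustration of the operator's semantics, using throwaway dicts:

    # Dict union is applied left to right: keys in a later operand
    # override earlier ones, so model_args would win any collision
    # inside Job.to_dict.
    left = {"seed": 42, "do_train": True}
    right = {"seed": 7}
    assert (left | right) == {"seed": 7, "do_train": True}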