diff --git a/kazu/training/config.py b/kazu/training/config.py
index dd9d6e11..8d0209fe 100644
--- a/kazu/training/config.py
+++ b/kazu/training/config.py
@@ -45,6 +45,8 @@ class TrainingConfig:
     architecture: str = "bert"
     #: fraction of epoch to complete before evaluations begin
     epoch_completion_fraction_before_evals: float = 0.75
+    #: The random seed to use
+    seed: int = 42
 
 
 @dataclass
diff --git a/kazu/training/modelling_utils.py b/kazu/training/modelling_utils.py
index 5e7347b6..de44310f 100644
--- a/kazu/training/modelling_utils.py
+++ b/kazu/training/modelling_utils.py
@@ -1,7 +1,6 @@
 import copy
 import json
 import logging
-import random
 from collections.abc import Iterable
 from pathlib import Path
 from typing import Any, Optional, Union
@@ -102,26 +101,24 @@ def get_gold_ents_for_side_by_side_view(self, docs: list[Document]) -> list[list
         return result
 
     def update(
-        self, test_docs: list[Document], global_step: Union[int, str], has_gs: bool = True
+        self, docs: list[Document], global_step: Union[int, str], has_gs: bool = True
     ) -> None:
         ls_manager = LabelStudioManager(
             headers=self.ls_manager.headers,
             project_name=f"{self.ls_manager.project_name}_test_{global_step}",
         )
-        ls_manager.delete_project_if_exists()
         ls_manager.create_linking_project()
-        docs_subset = random.sample(test_docs, min([len(test_docs), 100]))
-        if not docs_subset:
+        if not docs:
             logger.info("no results to represent yet")
             return
 
         if has_gs:
-            side_by_side = self.get_gold_ents_for_side_by_side_view(docs_subset)
+            side_by_side = self.get_gold_ents_for_side_by_side_view(docs)
             ls_manager.update_view(self.view, side_by_side)
             ls_manager.update_tasks(side_by_side)
         else:
-            ls_manager.update_view(self.view, docs_subset)
-            ls_manager.update_tasks(docs_subset)
+            ls_manager.update_view(self.view, docs)
+            ls_manager.update_tasks(docs)
 
 
 def create_wrapper(cfg: DictConfig, label_list: list[str]) -> Optional[LSManagerViewWrapper]:
diff --git a/kazu/training/train_multilabel_ner.py b/kazu/training/train_multilabel_ner.py
index 715fb8cf..46c2fc61 100644
--- a/kazu/training/train_multilabel_ner.py
+++ b/kazu/training/train_multilabel_ner.py
@@ -4,6 +4,7 @@
 import logging
 import math
 import pickle
+import random
 import shutil
 import tempfile
 from collections import defaultdict
@@ -337,6 +338,7 @@ def __init__(
         self.label_list = label_list
         self.pretrained_model_name_or_path = pretrained_model_name_or_path
         self.keys_to_use = _select_keys_to_use(self.training_config.architecture)
+        random.seed(training_config.seed)
 
     def _write_to_tensorboard(
         self, global_step: int, main_tag: str, tag_scalar_dict: dict[str, NumericMetric]
@@ -360,7 +362,8 @@ def evaluate_model(
         model_test_docs = self._process_docs(model)
 
         if self.ls_wrapper:
-            self.ls_wrapper.update(model_test_docs, global_step)
+            sample_test_docs = random.sample(model_test_docs, min([len(model_test_docs), 100]))
+            self.ls_wrapper.update(sample_test_docs, global_step)
 
         all_results, tensorboad_loggables = calculate_metrics(
             epoch_loss, model_test_docs, self.label_list
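
Behavioural note on the change above: the 100-document subsample sent to Label Studio is now drawn in `Trainer.evaluate_model` with Python's `random` module seeded once from the new `TrainingConfig.seed` field, rather than inside `LSManagerViewWrapper.update`, so the subset is reproducible across runs. A minimal sketch of that pattern, using hypothetical `MiniConfig` and `sample_docs` stand-ins rather than the real kazu classes:

```python
import random
from dataclasses import dataclass


@dataclass
class MiniConfig:
    # Mirrors the new TrainingConfig.seed field added in this diff.
    seed: int = 42


def sample_docs(docs: list[str], cap: int = 100) -> list[str]:
    # Same pattern as the new evaluate_model code: keep at most `cap` docs.
    return random.sample(docs, min(len(docs), cap))


config = MiniConfig()
docs = [f"doc_{i}" for i in range(500)]

random.seed(config.seed)  # seeded once, as Trainer.__init__ now does
first = sample_docs(docs)

random.seed(config.seed)  # re-seeding reproduces the exact same subset
second = sample_docs(docs)

assert first == second
```

Re-seeding before each draw is only for illustration here; in the trainer the seed is set once in `__init__`, so successive evaluation steps draw different, but run-to-run reproducible, subsets.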