Commit baf398a

ensure same bs on all TPUs
markus583 committed Apr 23, 2024
1 parent: 37824a4
Showing 1 changed file with 18 additions and 13 deletions.
wtpsplit/train/train.py (31 changes: 18 additions & 13 deletions)
@@ -2,26 +2,26 @@
import math
import os
import random
+import shutil
import sys
import time
from collections import Counter, defaultdict
from dataclasses import dataclass
from functools import partial
from glob import glob
from typing import List, Optional
-import shutil
+
import datasets
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import transformers
from datasets import load_dataset
from datasets.download import DownloadConfig
from tokenizers import AddedToken
from torchinfo import summary
from tqdm.auto import tqdm
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
-import pdb

import wandb
from wtpsplit.models import (
@@ -32,10 +32,10 @@
SubwordXLMConfig,
SubwordXLMForTokenClassification,
)
-from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_pairwise, evaluate_sentence_kmers
+from wtpsplit.train.evaluate import evaluate_sentence, evaluate_sentence_kmers, evaluate_sentence_pairwise
from wtpsplit.train.trainer import Trainer
-from wtpsplit.utils import Constants, LabelArgs, corrupt_training, get_label_dict, get_subword_label_dict
from wtpsplit.train.utils import Model, cleanup_cache_files
+from wtpsplit.utils import Constants, LabelArgs, corrupt_training, get_label_dict, get_subword_label_dict

logger = logging.getLogger(__name__)

@@ -200,6 +200,13 @@ def main():
else:
(args, training_args, label_args) = parser.parse_args_into_dataclasses()
wandb_name = None
+    if xm.xrt_world_size() == 4:
+        # ensure same batch size on TPUv3 and TPUv4
+        training_args.per_device_train_batch_size *= 2
+    logger.warning(f"Per device train batch size: {training_args.per_device_train_batch_size}")
+    logger.warning(
+        f"Total train batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * xm.xrt_world_size()}"
+    )

setup_logging(training_args)
set_seed(training_args.seed)
@@ -242,7 +249,6 @@ def main():
if args.lookahead:
        assert args.lookahead % args.num_hidden_layers == 0

-
else:
tokenizer = None
config = LACanineConfig.from_pretrained(
@@ -271,8 +277,7 @@ def main():
        backbone = LACanineForTokenClassification.from_pretrained(
            args.model_name_or_path, ignore_mismatched_sizes=True, config=config
        )
-
-
+
model = Model(
backbone,
loss_margin=args.loss_margin,
@@ -649,12 +654,12 @@ def compute_metrics(trainer):
# avg_metrics[f"pairwise_average_{dataset_name}_f1"].append(info["f1"])
# avg_metrics[f"pairwise_average_{dataset_name}_f1_best"].append(info["f1_best"])
# avg_metrics[f"pairwise_average_{dataset_name}_threshold_best"].append(info["threshold_best"])
-            # if lang_code in ["zh", "ja", "my", "km"]:
-            # avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
-            # avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_acc"].append(avg_acc)
-            # else:
-            # avg_metrics[f"pairwise_average_whitespace_{dataset_name}_pr_auc"].append(score)
-            # avg_metrics[f"pairwise_average_whitespace_{dataset_name}_acc"].append(avg_acc)
+                # if lang_code in ["zh", "ja", "my", "km"]:
+                # avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
+                # avg_metrics[f"pairwise_average_nonwhitespace_{dataset_name}_acc"].append(avg_acc)
+                # else:
+                # avg_metrics[f"pairwise_average_whitespace_{dataset_name}_pr_auc"].append(score)
+                # avg_metrics[f"pairwise_average_whitespace_{dataset_name}_acc"].append(avg_acc)

for name, values in avg_metrics.items():
if len(values) > 1:
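Note on the batch-size change above: the effective global batch size is the product per_device_train_batch_size * gradient_accumulation_steps * xm.xrt_world_size(), exactly what the second warning logs. A TPUv3-8 slice exposes 8 XLA devices, while the TPUv4 slice this commit targets reports 4, so doubling the per-device batch when the world size is 4 keeps the global batch identical across both. A minimal sketch of that arithmetic (the helper name and concrete numbers below are illustrative assumptions, not part of wtpsplit):

def effective_batch_size(per_device_bs: int, grad_accum: int, world_size: int) -> int:
    # global batch = per-device batch * gradient-accumulation steps * device count
    return per_device_bs * grad_accum * world_size

# Doubling the per-device batch on a 4-device TPUv4 slice matches an 8-device TPUv3-8:
assert effective_batch_size(2 * 32, 1, 4) == effective_batch_size(32, 1, 8)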
