diff --git a/wenet/bin/train.py b/wenet/bin/train.py
index a8a18fdcf..60d6374f4 100644
--- a/wenet/bin/train.py
+++ b/wenet/bin/train.py
@@ -24,6 +24,7 @@
 import torch.distributed as dist
 
 from torch.distributed.elastic.multiprocessing.errors import record
 
+from wenet.utils.common import lrs_to_str
 from wenet.utils.executor import Executor
 from wenet.utils.config import override_config
@@ -117,8 +118,7 @@ def main():
 
     # Get executor
     tag = configs["init_infos"].get("tag", "init")
-    executor = Executor(global_step=configs["init_infos"].get('step', -1) +
-                        int("step_" in tag))
+    executor = Executor(global_step=configs["init_infos"].get('step', -1))
 
     # Init scaler, used for pytorch amp mixed precision training
     scaler = init_scaler(args)
@@ -134,9 +134,9 @@ def main():
     for epoch in range(start_epoch, end_epoch):
         configs['epoch'] = epoch
 
-        lr = optimizer.param_groups[0]['lr']
-        logging.info('Epoch {} TRAIN info lr {} rank {}'.format(
-            epoch, lr, rank))
+        lrs = [group['lr'] for group in optimizer.param_groups]
+        logging.info('Epoch {} Step {} TRAIN info lr {} rank {}'.format(
+            epoch, executor.step, lrs_to_str(lrs), rank))
 
         dist.barrier(
         )  # NOTE(xcsong): Ensure all ranks start Train at the same time.
@@ -150,19 +150,16 @@ def main():
         dist.barrier(
         )  # NOTE(xcsong): Ensure all ranks start CV at the same time.
         loss_dict = executor.cv(model, cv_data_loader, configs)
-
-        lr = optimizer.param_groups[0]['lr']
-        logging.info('Epoch {} CV info lr {} cv_loss {} rank {} acc {}'.format(
-            epoch, lr, loss_dict["loss"], rank, loss_dict["acc"]))
         info_dict = {
             'epoch': epoch,
-            'lr': lr,
+            'lrs': [group['lr'] for group in optimizer.param_groups],
             'step': executor.step,
             'save_time':
             datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
             'tag': "epoch_{}".format(epoch),
             'loss_dict': loss_dict,
             **configs
         }
+        # epoch cv: tensorboard && log
         log_per_epoch(writer, info_dict=info_dict)
         save_model(model, info_dict=info_dict)
diff --git a/wenet/utils/checkpoint.py b/wenet/utils/checkpoint.py
index d60582716..30f06d46b 100644
--- a/wenet/utils/checkpoint.py
+++ b/wenet/utils/checkpoint.py
@@ -41,6 +41,7 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
 
 
 def save_state_dict_and_infos(state_dict, path: str, infos=None):
+    logging.info('Checkpoint: save to checkpoint %s' % path)
     torch.save(state_dict, path)
     info_path = re.sub('.pt$', '.yaml', path)
     if infos is None:
@@ -56,7 +57,6 @@ def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
     Args:
         infos (dict or None): any info you want to save.
     '''
-    logging.info('Checkpoint: save to checkpoint %s' % path)
     if isinstance(model, torch.nn.DataParallel):
         state_dict = model.module.state_dict()
     elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
diff --git a/wenet/utils/common.py b/wenet/utils/common.py
index 282d73179..72afbfd05 100644
--- a/wenet/utils/common.py
+++ b/wenet/utils/common.py
@@ -321,6 +321,19 @@ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     return mask
 
 
+def get_nested_attribute(obj, attr_path):
+    if isinstance(obj, torch.nn.parallel.DistributedDataParallel):
+        obj = obj.module
+    attributes = attr_path.split('.')
+    for attr in attributes:
+        obj = getattr(obj, attr)
+    return obj
+
+
+def lrs_to_str(lrs: List):
+    return " ".join(["{:.4e}".format(lr) for lr in lrs])
+
+
 class StepTimer:
     """Utility class for measuring steps/second."""
 
diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py
index 11bc24706..45f2739e2 100644
--- a/wenet/utils/executor.py
+++ b/wenet/utils/executor.py
@@ -31,7 +31,7 @@ class Executor:
 
     def __init__(self, global_step: int = 0):
-        self.step = global_step
+        self.step = global_step + 1
         self.train_step_timer = None
         self.cv_step_timer = None
 
@@ -85,9 +85,12 @@ def train(self, model, optimizer, scheduler, train_data_loader,
                 info_dict = update_parameter_and_lr(model, optimizer,
                                                     scheduler, scaler,
                                                     info_dict)
+                # write training: tensorboard && log
+                log_per_step(writer, info_dict, timer=self.train_step_timer)
                 save_interval = info_dict.get('save_interval', sys.maxsize)
-                if self.step % save_interval == 0 and self.step != 0 \
-                        and (batch_idx + 1) % info_dict["accum_grad"] == 0:
+                if (self.step +
+                        1) % save_interval == 0 and self.step != 0 and (
+                            batch_idx + 1) % info_dict["accum_grad"] == 0:
                     import torch.distributed as dist
                     # Ensure all ranks start CV at the same time in step mode
                     dist.barrier()
@@ -100,13 +103,14 @@ def train(self, model, optimizer, scheduler, train_data_loader,
                         loss_dict,
                         "save_time":
                         datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
-                        "lr":
-                        optimizer.param_groups[0]['lr']
+                        "lrs":
+                        [group['lr'] for group in optimizer.param_groups]
                     })
                     save_model(model, info_dict)
+                    # write final cv: tensorboard
+                    log_per_step(writer, info_dict)
                     # Ensure all ranks start Train at the same time in step mode
                     dist.barrier()
-                log_per_step(writer, info_dict, timer=self.train_step_timer)
                 self.step += 1 if (batch_idx +
                                    1) % info_dict["accum_grad"] == 0 else 0
@@ -143,7 +147,7 @@ def cv(self, model, cv_data_loader, configs):
                         loss_value = loss_value.item()
                     loss_dict[loss_name] = loss_dict.get(loss_name, 0) + \
                         loss_value * num_utts
-
+            # write cv: log
             log_per_step(writer=None,
                          info_dict=info_dict,
                          timer=self.cv_step_timer)
diff --git a/wenet/utils/scheduler.py b/wenet/utils/scheduler.py
index 6a78bb6c7..170e4fd1d 100644
--- a/wenet/utils/scheduler.py
+++ b/wenet/utils/scheduler.py
@@ -15,7 +15,7 @@
 # Modified from ESPnet(https://github.com/espnet/espnet)
 #               NeMo(https://github.com/NVIDIA/NeMo)
 
-from typing import Union
+from typing import List, Union
 
 import math
 import warnings
@@ -43,11 +43,10 @@ class WarmupLR(_LRScheduler):
     def __init__(
         self,
         optimizer: torch.optim.Optimizer,
-        warmup_steps: Union[int, float] = 25000,
+        warmup_steps: Union[int, float, List[Union[int, float]]] = 25000,
         last_epoch: int = -1,
     ):
         self.warmup_steps = warmup_steps
-
         # __init__() must be invoked before setting field
         # because step() is also invoked in __init__()
         super().__init__(optimizer, last_epoch)
@@ -57,14 +56,21 @@ def __repr__(self):
 
     def get_lr(self):
         step_num = self.last_epoch + 1
-        if self.warmup_steps == 0:
-            return [lr * step_num**-0.5 for lr in self.base_lrs]
-        else:
-            return [
-                lr * self.warmup_steps**0.5 *
-                min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
-                for lr in self.base_lrs
-            ]
+        warmup_steps = self.warmup_steps
+        if not isinstance(warmup_steps, List):
+            warmup_steps = [self.warmup_steps] * len(self.base_lrs)
+
+        def initlr_fn(lr):
+            return lr * step_num**-0.5
+
+        def warmuplr_fn(lr, warmup_step):
+            return lr * warmup_step**0.5 * min(step_num**-0.5,
+                                               step_num * warmup_step**-1.5)
+
+        return [
+            initlr_fn(lr) if warmup_steps[i] == 0 else warmuplr_fn(
+                lr, warmup_steps[i]) for (i, lr) in enumerate(self.base_lrs)
+        ]
 
     def set_step(self, step: int):
         self.last_epoch = step
diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py
index 06bb71b02..1d20a8980 100644
--- a/wenet/utils/train_utils.py
+++ b/wenet/utils/train_utils.py
@@ -15,7 +15,7 @@
 
 from contextlib import nullcontext
 import copy
-from typing import Optional
+from typing import List, Optional
 
 import deepspeed
 import json
@@ -41,10 +41,10 @@
     convert_zero_checkpoint_to_fp32_state_dict)
 from wenet.dataset.dataset import Dataset
 from wenet.utils.checkpoint import save_checkpoint
+from wenet.utils.common import StepTimer, get_nested_attribute, lrs_to_str
 from wenet.utils.fsdp_utils import (check_gradient_checkpoint, fsdp_save_model,
                                     apply_fsdp_checkpointing,
                                     wenet_fsdp_wrap_policy)
-from wenet.utils.common import StepTimer
 from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
 from wenet.utils.ctc_utils import get_blank_id
@@ -439,10 +439,38 @@ def wrap_cuda_model(args, model, configs=None):
 
 
 def init_optimizer_and_scheduler(args, configs, model):
+    groups = []
+    lr = configs['optim_conf'].get('lr')
+    if isinstance(lr, List):
+        assert configs['scheduler'] == 'warmuplr'
+        modules_m = configs['optim_conf']['modules']
+        assert isinstance(modules_m, List)
+        assert len(modules_m) + 1 == len(lr)
+        special_param_ids = set()
+        rest_params = []
+        for (i, m_str) in enumerate(modules_m):
+            sub_module = get_nested_attribute(model, m_str)
+            subs_params = []
+            for _, sub_params in sub_module.named_parameters():
+                subs_params.append(sub_params)
+                special_param_ids.add(id(sub_params))
+            groups.append({'params': subs_params, 'lr': lr[i]})
+        # other model's parameters
+        for _, param in model.named_parameters():
+            if id(param) not in special_param_ids:
+                rest_params.append(param)
+        groups.append({'params': rest_params, 'lr': lr[-1]})
+
+    params = groups if len(groups) > 0 else model.parameters()
+    optim_conf = copy.deepcopy(configs['optim_conf'])
+    if 'modules' in optim_conf:
+        del optim_conf['modules']
+    if isinstance(lr, List):
+        optim_conf['lr'] = lr[-1]
     if configs['optim'] == 'adam':
-        optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
+        optimizer = optim.Adam(params, **optim_conf)
     elif configs['optim'] == 'adamw':
-        optimizer = optim.AdamW(model.parameters(), **configs['optim_conf'])
+        optimizer = optim.AdamW(params, **optim_conf)
     else:
         raise ValueError("unknown optimizer: " + configs['optim'])
@@ -704,7 +732,7 @@ def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
         scheduler.step()
         grad_norm = grad_norm.item()
 
-    info_dict["lr"] = optimizer.param_groups[0]['lr']
+    info_dict["lrs"] = [group['lr'] for group in optimizer.param_groups]
     info_dict["grad_norm"] = grad_norm
 
     return info_dict
@@ -719,28 +747,36 @@ def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
     train_engine = info_dict.get("train_engine", "torch_ddp")
     accum_grad = info_dict.get('accum_grad', 1) if tag != "CV" else 1
     log_interval = info_dict.get('log_interval', 10)
-    lr = info_dict.get("lr", 0.0)
+    lrs = info_dict.get("lrs", [0.0])
     is_gradient_accumulation_boundary = info_dict.get(
         "is_gradient_accumulation_boundary", False)
 
     rank = int(os.environ.get('RANK', 0))
-
+    # TRAIN Tensorboard
     if tag == "TRAIN" and rank == 0 and writer is not None:
         if (train_engine == "deepspeed" and is_gradient_accumulation_boundary
             ) or (train_engine in ["torch_ddp", "torch_fsdp"] and
                   (batch_idx + 1) % accum_grad == 0):
             writer.add_scalar('train/train_loss',
-                              loss_dict['loss'] * accum_grad, step + 1)
-            writer.add_scalar('train/grad_norm', info_dict['grad_norm'],
-                              step + 1)
+                              loss_dict['loss'] * accum_grad, step)
+            writer.add_scalar('train/grad_norm', info_dict['grad_norm'], step)
             for name, value in loss_dict.items():
                 if name != 'loss' and value is not None:
-                    writer.add_scalar('train/{}'.format(name), value, step + 1)
+                    writer.add_scalar('train/{}'.format(name), value, step)
+            # lr
+            for i, lr in enumerate(lrs):
+                writer.add_scalar('train/lr_{}'.format(i), lr, step)
+    # CV Tensorboard
     elif "step_" in tag and rank == 0 and writer is not None:
-        writer.add_scalar('global_step/lr', lr, step + 1)
         for name, value in loss_dict.items():
-            writer.add_scalar('global_step/{}'.format(name), value, step + 1)
-
+            writer.add_scalar('cv/{}'.format(name), value, step)
+        logging.info(
+            'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
+                epoch, step + 1, lrs_to_str(lrs), loss_dict["loss"], rank,
+                loss_dict["acc"]))
+        return
+
+    # TRAIN & CV, Shell log (stdout)
     if (batch_idx + 1) % log_interval == 0:
         log_str = '{} | '.format(tag)
         if timer is not None:
@@ -757,16 +793,25 @@ def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
             if name != 'loss' and value is not None:
                 log_str += '{} {:.6f} '.format(name, value)
         if tag == "TRAIN":
-            log_str += 'lr {:.8f} grad_norm {:.6f} rank {}'.format(
-                lr, info_dict['grad_norm'], rank)
+            log_str += 'lr {} grad_norm {:.6f} rank {}'.format(
+                lrs_to_str(lrs), info_dict['grad_norm'], rank)
         logging.debug(log_str)
 
 
 def log_per_epoch(writer, info_dict):
     epoch = info_dict["epoch"]
     loss_dict = info_dict["loss_dict"]
+    lrs = info_dict['lrs']
+    rank = int(os.environ.get('RANK', 0))
+    step = info_dict["step"]
+    logging.info(
+        'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
+            epoch, step, lrs_to_str(lrs), loss_dict["loss"], rank,
+            loss_dict["acc"]))
+
     if int(os.environ.get('RANK', 0)) == 0:
-        writer.add_scalar('epoch/lr', info_dict["lr"], epoch)
+        for i, lr in enumerate(info_dict["lrs"]):
+            writer.add_scalar('epoch/lr_{}'.format(i), lr, epoch)
         for name, value in loss_dict.items():
            writer.add_scalar('epoch/{}'.format(name), value, epoch)