diff --git a/wenet/paraformer/layers.py b/wenet/paraformer/layers.py
index b1cd362a0..d17280d8a 100644
--- a/wenet/paraformer/layers.py
+++ b/wenet/paraformer/layers.py
@@ -290,8 +290,12 @@ def forward_layers_checkpointed(self, xs: torch.Tensor,
         for layer in self.encoders0:
             xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         for layer in self.encoders:
-            xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, chunk_masks,
-                                          pos_emb, mask_pad, use_reentrant=False)
+            xs, _, _, _ = ckpt.checkpoint(layer.__call__,
+                                          xs,
+                                          chunk_masks,
+                                          pos_emb,
+                                          mask_pad,
+                                          use_reentrant=False)
         return xs
@@ -480,8 +484,12 @@ def forward_layers_checkpointed(self, x: torch.Tensor,
             if i == 0:
                 x, _, _, _ = layer(x, tgt_mask, memory, memory_mask)
             else:
-                x, _, _, _ = ckpt.checkpoint(layer.__call__, x, tgt_mask,
-                                             memory, memory_mask, use_reentrant=False)
+                x, _, _, _ = ckpt.checkpoint(layer.__call__,
+                                             x,
+                                             tgt_mask,
+                                             memory,
+                                             memory_mask,
+                                             use_reentrant=False)
         for layer in self.decoders3:
             x = layer(x)
         return x
diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index 00df599f2..ba31edffc 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -15,6 +15,7 @@
 """Decoder definition."""
 from typing import Dict, Tuple, List, Optional

+import os
 import torch
 import torch.utils.checkpoint as ckpt
 import logging
@@ -214,7 +215,11 @@ def forward_layers_checkpointed(self, x: torch.Tensor,
                                     memory_mask: torch.Tensor) -> torch.Tensor:
         for layer in self.decoders:
             x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
-                layer.__call__, x, tgt_mask, memory, memory_mask,
+                layer.__call__,
+                x,
+                tgt_mask,
+                memory,
+                memory_mask,
                 use_reentrant=False)
         return x
@@ -278,14 +283,17 @@ def forward_one_step(

     def tie_or_clone_weights(self, jit_mode: bool = True):
         """Tie or clone module weights (between word_emb and output_layer)
            depending of whether we are using TorchScript or not"""
+        rank = int(os.environ.get('RANK', 0))
         if not self.use_output_layer:
             return
         if jit_mode:
-            logging.info("clone emb.weight to output.weight")
+            if rank == 0:
+                logging.info("clone emb.weight to output.weight")
             self.output_layer.weight = torch.nn.Parameter(
                 self.embed[0].weight.clone())
         else:
-            logging.info("tie emb.weight with output.weight")
+            if rank == 0:
+                logging.info("tie emb.weight with output.weight")
             self.output_layer.weight = self.embed[0].weight
         if getattr(self.output_layer, "bias", None) is not None:
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index 83ea684fe..9cfd260ea 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -190,9 +190,12 @@ def forward_layers_checkpointed(self, xs: torch.Tensor,
                                     pos_emb: torch.Tensor,
                                     mask_pad: torch.Tensor) -> torch.Tensor:
         for layer in self.encoders:
-            xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
-                                                    chunk_masks, pos_emb,
-                                                    mask_pad, use_reentrant=False)
+            xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__,
+                                                    xs,
+                                                    chunk_masks,
+                                                    pos_emb,
+                                                    mask_pad,
+                                                    use_reentrant=False)
         return xs

     def forward_chunk(
diff --git a/wenet/utils/checkpoint.py b/wenet/utils/checkpoint.py
index 9071a9bec..8a2dfba61 100644
--- a/wenet/utils/checkpoint.py
+++ b/wenet/utils/checkpoint.py
@@ -24,14 +24,17 @@


 def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
-    logging.info('Checkpoint: loading from checkpoint %s' % path)
+    rank = int(os.environ.get('RANK', 0))
+    logging.info('[Rank {}] Checkpoint: loading from checkpoint {}'.format(
+        rank, path))
     checkpoint = torch.load(path, map_location='cpu', mmap=True)
     missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
-    for key in missing_keys:
-        logging.info("missing tensor: {}".format(key))
-    for key in unexpected_keys:
-        logging.info("unexpected tensor: {}".format(key))
+    if rank == 0:
+        for key in missing_keys:
+            logging.info("missing tensor: {}".format(key))
+        for key in unexpected_keys:
+            logging.info("unexpected tensor: {}".format(key))
     info_path = re.sub('.pt$', '.yaml', path)
     configs = {}
     if os.path.exists(info_path):
@@ -41,7 +44,9 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict:


 def save_state_dict_and_infos(state_dict, path: str, infos=None):
-    logging.info('Checkpoint: save to checkpoint %s' % path)
+    rank = int(os.environ.get('RANK', 0))
+    logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(
+        rank, path))
     torch.save(state_dict, path)
     info_path = re.sub('.pt$', '.yaml', path)
     if infos is None:
@@ -67,6 +72,7 @@ def save_checkpoint(model: torch.nn.Module, path: str, infos=None):


 def filter_modules(model_state_dict, modules):
+    rank = int(os.environ.get('RANK', 0))
     new_mods = []
     incorrect_mods = []
     mods_model = model_state_dict.keys()
@@ -75,7 +81,7 @@ def filter_modules(model_state_dict, modules):
             new_mods += [mod]
         else:
             incorrect_mods += [mod]
-    if incorrect_mods:
+    if incorrect_mods and rank == 0:
         logging.warning(
             "module(s) %s don't match or (partially match) "
             "available modules in model.",
diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py
index b278b6dd9..309b4d67f 100644
--- a/wenet/utils/init_model.py
+++ b/wenet/utils/init_model.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import torch

 from wenet.finetune.lora.utils import mark_only_lora_as_trainable
@@ -180,7 +181,8 @@ def init_model(args, configs):

     if hasattr(args, 'only_optimize_lora') and args.only_optimize_lora:
         mark_only_lora_as_trainable(model, bias='lora_only')
-    print(configs)
+    if int(os.environ.get('RANK', 0)) == 0:
+        print(configs)

     # Tie emb.weight to decoder.output_layer.weight
     if model.decoder.tie_word_embedding:
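Aside from the argument-per-line reformatting of the ckpt.checkpoint calls, the functional changes above follow one pattern: read the RANK environment variable exported by torchrun/torch.distributed launchers and either tag a log message with the rank or emit it only on rank 0, so a multi-process DDP job does not print the same line once per process. The sketch below illustrates that pattern in isolation; it assumes a torchrun-style environment, and the helper names (get_rank, rank0_info, load_checkpoint_logged) and the example path are illustrative, not part of this patch.

import logging
import os

logging.basicConfig(level=logging.INFO)


def get_rank() -> int:
    # torchrun / torch.distributed launchers export RANK; default to 0 so the
    # same code also works in single-process (non-distributed) runs.
    return int(os.environ.get('RANK', 0))


def rank0_info(msg: str, *args):
    # Emit the message only on rank 0 to avoid N identical log lines.
    if get_rank() == 0:
        logging.info(msg, *args)


def load_checkpoint_logged(path: str):
    # Rank-tagged variant: every rank logs, but the rank is explicit,
    # mirroring the '[Rank {}] Checkpoint: ...' messages in the patch.
    logging.info('[Rank %d] Checkpoint: loading from checkpoint %s',
                 get_rank(), path)
    # torch.load(...) / model.load_state_dict(...) would follow here.


if __name__ == '__main__':
    rank0_info('printed once per job, not once per process')
    load_checkpoint_logged('exp/model.pt')  # hypothetical path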