[log] avoid redundant logging (#2493)
xingchensong authored Apr 18, 2024
1 parent 5b4e2f4 commit c415f6c
Showing 5 changed files with 45 additions and 18 deletions.
16 changes: 12 additions & 4 deletions wenet/paraformer/layers.py
@@ -290,8 +290,12 @@ def forward_layers_checkpointed(self, xs: torch.Tensor,
         for layer in self.encoders0:
             xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         for layer in self.encoders:
-            xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, chunk_masks,
-                                          pos_emb, mask_pad, use_reentrant=False)
+            xs, _, _, _ = ckpt.checkpoint(layer.__call__,
+                                          xs,
+                                          chunk_masks,
+                                          pos_emb,
+                                          mask_pad,
+                                          use_reentrant=False)
         return xs


@@ -480,8 +484,12 @@ def forward_layers_checkpointed(self, x: torch.Tensor,
             if i == 0:
                 x, _, _, _ = layer(x, tgt_mask, memory, memory_mask)
             else:
-                x, _, _, _ = ckpt.checkpoint(layer.__call__, x, tgt_mask,
-                                             memory, memory_mask, use_reentrant=False)
+                x, _, _, _ = ckpt.checkpoint(layer.__call__,
+                                             x,
+                                             tgt_mask,
+                                             memory,
+                                             memory_mask,
+                                             use_reentrant=False)
         for layer in self.decoders3:
             x = layer(x)
         return x
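
Both hunks above only re-wrap the arguments of ckpt.checkpoint; behavior is unchanged. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch (illustrative module and tensor names, not wenet code) of wrapping a layer the same way, with use_reentrant=False:

# Minimal sketch of activation checkpointing as used above (illustrative only).
import torch
import torch.utils.checkpoint as ckpt

layer = torch.nn.TransformerEncoderLayer(d_model=256, nhead=4, batch_first=True)
xs = torch.randn(2, 10, 256, requires_grad=True)

# Passing layer.__call__ keeps the normal nn.Module call path (hooks included);
# use_reentrant=False selects the non-reentrant checkpoint implementation, which
# recomputes activations during backward instead of storing them.
out = ckpt.checkpoint(layer.__call__, xs, use_reentrant=False)
out.sum().backward()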
14 changes: 11 additions & 3 deletions wenet/transformer/decoder.py
@@ -15,6 +15,7 @@
 """Decoder definition."""
 from typing import Dict, Tuple, List, Optional
 
+import os
 import torch
 import torch.utils.checkpoint as ckpt
 import logging
@@ -214,7 +215,11 @@ def forward_layers_checkpointed(self, x: torch.Tensor,
                                      memory_mask: torch.Tensor) -> torch.Tensor:
         for layer in self.decoders:
             x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
-                layer.__call__, x, tgt_mask, memory, memory_mask,
+                layer.__call__,
+                x,
+                tgt_mask,
+                memory,
+                memory_mask,
                 use_reentrant=False)
         return x

@@ -278,14 +283,17 @@ def forward_one_step(
     def tie_or_clone_weights(self, jit_mode: bool = True):
         """Tie or clone module weights (between word_emb and output_layer)
         depending of whether we are using TorchScript or not"""
+        rank = int(os.environ.get('RANK', 0))
         if not self.use_output_layer:
             return
         if jit_mode:
-            logging.info("clone emb.weight to output.weight")
+            if rank == 0:
+                logging.info("clone emb.weight to output.weight")
             self.output_layer.weight = torch.nn.Parameter(
                 self.embed[0].weight.clone())
         else:
-            logging.info("tie emb.weight with output.weight")
+            if rank == 0:
+                logging.info("tie emb.weight with output.weight")
             self.output_layer.weight = self.embed[0].weight
 
         if getattr(self.output_layer, "bias", None) is not None:
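
The rank guards above are the substance of this commit: under distributed training every worker executes tie_or_clone_weights, so an unguarded logging.info call is printed once per GPU. A hedged sketch of the same idea as a reusable helper (the helper name is illustrative, not a wenet API):

# Illustrative rank-0-only logging helper (assumption: the launcher exports RANK).
import logging
import os

def log_rank0(msg: str) -> None:
    # torchrun / torch.distributed.launch set RANK per worker; a plain
    # single-process run has no RANK and falls back to 0, so it still logs.
    if int(os.environ.get('RANK', 0)) == 0:
        logging.info(msg)

log_rank0("tie emb.weight with output.weight")  # emitted once, not once per process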
9 changes: 6 additions & 3 deletions wenet/transformer/encoder.py
@@ -190,9 +190,12 @@ def forward_layers_checkpointed(self, xs: torch.Tensor,
                                      pos_emb: torch.Tensor,
                                      mask_pad: torch.Tensor) -> torch.Tensor:
         for layer in self.encoders:
-            xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
-                                                    chunk_masks, pos_emb,
-                                                    mask_pad, use_reentrant=False)
+            xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__,
+                                                    xs,
+                                                    chunk_masks,
+                                                    pos_emb,
+                                                    mask_pad,
+                                                    use_reentrant=False)
         return xs
 
     def forward_chunk(
20 changes: 13 additions & 7 deletions wenet/utils/checkpoint.py
@@ -24,14 +24,17 @@
 
 
 def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
-    logging.info('Checkpoint: loading from checkpoint %s' % path)
+    rank = int(os.environ.get('RANK', 0))
+    logging.info('[Rank {}] Checkpoint: loading from checkpoint {}'.format(
+        rank, path))
     checkpoint = torch.load(path, map_location='cpu', mmap=True)
     missing_keys, unexpected_keys = model.load_state_dict(checkpoint,
                                                           strict=False)
-    for key in missing_keys:
-        logging.info("missing tensor: {}".format(key))
-    for key in unexpected_keys:
-        logging.info("unexpected tensor: {}".format(key))
+    if rank == 0:
+        for key in missing_keys:
+            logging.info("missing tensor: {}".format(key))
+        for key in unexpected_keys:
+            logging.info("unexpected tensor: {}".format(key))
     info_path = re.sub('.pt$', '.yaml', path)
     configs = {}
     if os.path.exists(info_path):
@@ -41,7 +44,9 @@ def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
 
 
 def save_state_dict_and_infos(state_dict, path: str, infos=None):
-    logging.info('Checkpoint: save to checkpoint %s' % path)
+    rank = int(os.environ.get('RANK', 0))
+    logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(
+        rank, path))
     torch.save(state_dict, path)
     info_path = re.sub('.pt$', '.yaml', path)
     if infos is None:
@@ -67,6 +72,7 @@ def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
 
 
 def filter_modules(model_state_dict, modules):
+    rank = int(os.environ.get('RANK', 0))
     new_mods = []
     incorrect_mods = []
     mods_model = model_state_dict.keys()
@@ -75,7 +81,7 @@ def filter_modules(model_state_dict, modules):
             new_mods += [mod]
         else:
             incorrect_mods += [mod]
-    if incorrect_mods:
+    if incorrect_mods and rank == 0:
         logging.warning(
             "module(s) %s don't match or (partially match) "
             "available modules in model.",
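
For context on the newly guarded loop in load_checkpoint: load_state_dict(..., strict=False) returns the keys it could not match, and the commit now reports them on rank 0 only. A small toy illustration (model and keys invented for the example, not wenet code):

# Toy example of a non-strict state_dict load (illustrative only).
import torch

model = torch.nn.Linear(4, 4)
state = {'weight': torch.zeros(4, 4), 'extra.bias': torch.zeros(4)}

missing, unexpected = model.load_state_dict(state, strict=False)
print(missing)     # ['bias']       -> expected by the model, absent from the checkpoint
print(unexpected)  # ['extra.bias'] -> present in the checkpoint, unknown to the model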
4 changes: 3 additions & 1 deletion wenet/utils/init_model.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import torch
 
 from wenet.finetune.lora.utils import mark_only_lora_as_trainable
@@ -180,7 +181,8 @@ def init_model(args, configs):
     if hasattr(args, 'only_optimize_lora') and args.only_optimize_lora:
         mark_only_lora_as_trainable(model, bias='lora_only')
 
-    print(configs)
+    if int(os.environ.get('RANK', 0)) == 0:
+        print(configs)
 
     # Tie emb.weight to decoder.output_layer.weight
     if model.decoder.tie_word_embedding:
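
The RANK variable read by all of these guards is set per process by the standard PyTorch launchers, so the guarded print(configs) runs exactly once per job. A hedged sketch of the launch-time behavior (script name and flags are examples, not taken from this repository):

# Under e.g. `torchrun --nproc_per_node=4 train.py ...` each worker sees
# RANK=0..3 in its environment; without a launcher RANK is unset and the
# guard falls back to 0, so single-GPU runs still print.
import os

if int(os.environ.get('RANK', 0)) == 0:
    print("only rank 0 prints the merged configs")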
