From 6e2941ee8efb4230946be98ffba32c27b5e808f2 Mon Sep 17 00:00:00 2001
From: xingchensong
Date: Fri, 8 Dec 2023 19:55:14 +0800
Subject: [PATCH] refactor(yaml): try to pass unittest

---
 wenet/ctl_model/asr_model_ctl.py |  14 +++-
 wenet/k2/model.py                |  24 +++++--
 wenet/transducer/transducer.py   |  14 +++-
 wenet/utils/init_model.py        |  13 ++--
 wenet/utils/module_utils.py      | 120 ------------------------------
 5 files changed, 47 insertions(+), 138 deletions(-)
 delete mode 100644 wenet/utils/module_utils.py

diff --git a/wenet/ctl_model/asr_model_ctl.py b/wenet/ctl_model/asr_model_ctl.py
index de0d4d7259..c6f02ecc33 100644
--- a/wenet/ctl_model/asr_model_ctl.py
+++ b/wenet/ctl_model/asr_model_ctl.py
@@ -48,11 +48,19 @@ def __init__(
         logit_temp: float = 0.1,
         n_negatives: int = 0,
         ctl_weight: float = 1,
+        special_tokens: dict = None,
     ):
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
-        super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
-                         ignore_id, reverse_weight, lsm_weight,
-                         length_normalized_loss)
+        super().__init__(vocab_size,
+                         encoder,
+                         decoder,
+                         ctc,
+                         ctc_weight,
+                         ignore_id,
+                         reverse_weight,
+                         lsm_weight,
+                         length_normalized_loss,
+                         special_tokens=special_tokens)
 
         # For CTL Loss
         self.n_negatives = n_negatives
diff --git a/wenet/k2/model.py b/wenet/k2/model.py
index 271e450da2..bbc580cdc3 100644
--- a/wenet/k2/model.py
+++ b/wenet/k2/model.py
@@ -38,19 +38,29 @@ def __init__(
         lsm_weight: float = 0.0,
         length_normalized_loss: bool = False,
         lfmmi_dir: str = '',
+        special_tokens: dict = None,
     ):
-        super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
-                         ignore_id, reverse_weight, lsm_weight,
-                         length_normalized_loss)
+        super().__init__(vocab_size,
+                         encoder,
+                         decoder,
+                         ctc,
+                         ctc_weight,
+                         ignore_id,
+                         reverse_weight,
+                         lsm_weight,
+                         length_normalized_loss,
+                         special_tokens=special_tokens)
         self.lfmmi_dir = lfmmi_dir
         if self.lfmmi_dir != '':
             self.load_lfmmi_resource()
 
     @torch.jit.ignore(drop=True)
-    def _forward_ctc(self, encoder_out: torch.Tensor,
-                     encoder_mask: torch.Tensor, text: torch.Tensor,
-                     text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask, text)
+    def _forward_ctc(
+            self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        loss_ctc, ctc_probs = self._calc_lfmmi_loss(encoder_out, encoder_mask,
+                                                    text)
         return loss_ctc, ctc_probs
 
     @torch.jit.ignore(drop=True)
diff --git a/wenet/transducer/transducer.py b/wenet/transducer/transducer.py
index 95bd961e25..ed1e58a35d 100644
--- a/wenet/transducer/transducer.py
+++ b/wenet/transducer/transducer.py
@@ -41,11 +41,19 @@ def __init__(
         warmup_steps: float = 25000,
         lm_only_scale: float = 0.25,
         am_only_scale: float = 0.0,
+        special_tokens: dict = None,
     ) -> None:
         assert attention_weight + ctc_weight + transducer_weight == 1.0
-        super().__init__(vocab_size, encoder, attention_decoder, ctc,
-                         ctc_weight, ignore_id, reverse_weight, lsm_weight,
-                         length_normalized_loss)
+        super().__init__(vocab_size,
+                         encoder,
+                         attention_decoder,
+                         ctc,
+                         ctc_weight,
+                         ignore_id,
+                         reverse_weight,
+                         lsm_weight,
+                         length_normalized_loss,
+                         special_tokens=special_tokens)
 
         self.blank = blank
         self.transducer_weight = transducer_weight
diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py
index 3ed52df74f..3a50eb05a5 100644
--- a/wenet/utils/init_model.py
+++ b/wenet/utils/init_model.py
@@ -122,6 +122,7 @@ def init_model(args, configs):
             attention_decoder=decoder,
             joint=joint,
             ctc=ctc,
+            special_tokens=configs['tokenizer_conf']['special_tokens'],
             **configs['model_conf'])
     elif configs['model'] == 'paraformer':
         """ NOTE(Mddct): support fintune paraformer, if there is a need for
@@ -134,11 +135,13 @@ def init_model(args, configs):
         print(configs)
         return model, configs
     else:
-        model = WENET_MODEL_CLASSES[configs['model']](vocab_size=vocab_size,
-                                                      encoder=encoder,
-                                                      decoder=decoder,
-                                                      ctc=ctc,
-                                                      **configs['model_conf'])
+        model = WENET_MODEL_CLASSES[configs['model']](
+            vocab_size=vocab_size,
+            encoder=encoder,
+            decoder=decoder,
+            ctc=ctc,
+            special_tokens=configs['tokenizer_conf']['special_tokens'],
+            **configs['model_conf'])
 
     # If specify checkpoint, load some info from checkpoint
     if args.checkpoint is not None:
diff --git a/wenet/utils/module_utils.py b/wenet/utils/module_utils.py
deleted file mode 100644
index 2e7053a605..0000000000
--- a/wenet/utils/module_utils.py
+++ /dev/null
@@ -1,120 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Copyright [2023-11-28]
-import torch
-
-from wenet.transformer.swish import Swish
-from wenet.transformer.subsampling import (
-    LinearNoSubsampling,
-    EmbedinigNoSubsampling,
-    Conv1dSubsampling2,
-    Conv2dSubsampling4,
-    Conv2dSubsampling6,
-    Conv2dSubsampling8,
-)
-from wenet.efficient_conformer.subsampling import Conv2dSubsampling2
-from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4
-from wenet.transformer.embedding import (PositionalEncoding,
-                                         RelPositionalEncoding,
-                                         WhisperPositionalEncoding,
-                                         LearnablePositionalEncoding,
-                                         NoPositionalEncoding)
-from wenet.transformer.attention import (MultiHeadedAttention,
-                                         RelPositionMultiHeadedAttention)
-from wenet.efficient_conformer.attention import GroupedRelPositionMultiHeadedAttention
-from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder
-from wenet.branchformer.encoder import BranchformerEncoder
-from wenet.e_branchformer.encoder import EBranchformerEncoder
-from wenet.squeezeformer.encoder import SqueezeformerEncoder
-from wenet.efficient_conformer.encoder import EfficientConformerEncoder
-from wenet.ctl_model.encoder import DualConformerEncoder
-from wenet.ctl_model.encoder import DualTransformerEncoder
-from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
-from wenet.transformer.ctc import CTC
-from wenet.transformer.asr_model import ASRModel
-from wenet.ctl_model.asr_model_ctl import CTLModel
-from wenet.whisper.whisper import Whisper
-from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor,
-                                        RNNPredictor)
-from wenet.transducer.joint import TransducerJoint
-from wenet.k2.model import K2Model
-from wenet.transducer.transducer import Transducer
-
-WENET_ACTIVATION_CLASSES = {
-    "hardtanh": torch.nn.Hardtanh,
-    "tanh": torch.nn.Tanh,
-    "relu": torch.nn.ReLU,
-    "selu": torch.nn.SELU,
-    "swish": getattr(torch.nn, "SiLU", Swish),
-    "gelu": torch.nn.GELU,
-}
-
-WENET_RNN_CLASSES = {
-    "rnn": torch.nn.RNN,
-    "lstm": torch.nn.LSTM,
-    "gru": torch.nn.GRU,
-}
-
-WENET_SUBSAMPLE_CLASSES = {
-    "linear": LinearNoSubsampling,
-    "embed": EmbedinigNoSubsampling,
-    "conv1d2": Conv1dSubsampling2,
-    "conv2d2": Conv2dSubsampling2,
-    "conv2d": Conv2dSubsampling4,
-    "dwconv2d4": DepthwiseConv2dSubsampling4,
-    "conv2d6": Conv2dSubsampling6,
-    "conv2d8": Conv2dSubsampling8,
-}
-
-WENET_EMB_CLASSES = {
-    "embed": PositionalEncoding,
-    "abs_pos": PositionalEncoding,
-    "rel_pos": RelPositionalEncoding,
-    "no_pos": NoPositionalEncoding,
-    "abs_pos_whisper": WhisperPositionalEncoding,
-    "embed_learnable_pe": LearnablePositionalEncoding,
-}
-
-WENET_ATTENTION_CLASSES = {
-    "selfattn": MultiHeadedAttention,
-    "rel_selfattn": RelPositionMultiHeadedAttention,
-    "grouped_rel_selfattn": GroupedRelPositionMultiHeadedAttention,
-}
-
-WENET_ENCODER_CLASSES = {
-    "transformer": TransformerEncoder,
-    "conformer": ConformerEncoder,
-    "squeezeformer": SqueezeformerEncoder,
-    "efficientConformer": EfficientConformerEncoder,
-    "branchformer": BranchformerEncoder,
-    "e_branchformer": EBranchformerEncoder,
-    "dual_transformer": DualTransformerEncoder,
-    "dual_conformer": DualConformerEncoder,
-}
-
-WENET_DECODER_CLASSES = {
-    "transformer": TransformerDecoder,
-    "bi_transformer": BiTransformerDecoder,
-}
-
-WENET_CTC_CLASSES = {
-    "ctc": CTC,
-}
-
-WENET_PREDICTOR_CLASSES = {
-    "rnn": RNNPredictor,
-    "embedding": EmbeddingPredictor,
-    "conv": ConvPredictor,
-}
-
-WENET_JOINT_CLASSES = {
-    "transducerjoint": TransducerJoint,
-}
-
-WENET_MODEL_CLASSES = {
-    "asrmodel": ASRModel,
-    "ctlmodel": CTLModel,
-    "whisper": Whisper,
-    "k2model": K2Model,
-    "transducer": Transducer,
-}