From b4f20a07b560b3d33e480e335f6cceb1f78f93f8 Mon Sep 17 00:00:00 2001 From: xingchensong Date: Fri, 8 Dec 2023 22:42:33 +0800 Subject: [PATCH] [refactor] simplify code and keep API consisstent --- examples/aishell/rnnt/conf/conformer_rnnt.yaml | 2 ++ examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml | 2 ++ .../aishell/rnnt/conf/example_embedding_predictor.yaml | 3 +++ wenet/transducer/predictor.py | 4 ++++ wenet/utils/init_model.py | 8 -------- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/aishell/rnnt/conf/conformer_rnnt.yaml b/examples/aishell/rnnt/conf/conformer_rnnt.yaml index aeab0b180..3af76ed78 100644 --- a/examples/aishell/rnnt/conf/conformer_rnnt.yaml +++ b/examples/aishell/rnnt/conf/conformer_rnnt.yaml @@ -19,6 +19,8 @@ encoder_conf: joint_conf: + enc_output_size: 256 + pred_output_size: 256 join_dim: 512 prejoin_linear: True postjoin_linear: false diff --git a/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml b/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml index 28a80d5f7..3481f20b3 100644 --- a/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml +++ b/examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml @@ -23,6 +23,8 @@ encoder_conf: joint_conf: + enc_output_size: 256 + pred_output_size: 256 join_dim: 512 prejoin_linear: True postjoin_linear: false diff --git a/examples/aishell/rnnt/conf/example_embedding_predictor.yaml b/examples/aishell/rnnt/conf/example_embedding_predictor.yaml index 6d15b2fc0..ce701b57c 100644 --- a/examples/aishell/rnnt/conf/example_embedding_predictor.yaml +++ b/examples/aishell/rnnt/conf/example_embedding_predictor.yaml @@ -17,6 +17,8 @@ encoder_conf: joint_conf: + enc_output_size: 256 + pred_output_size: 320 join_dim: 320 prejoin_linear: true postjoin_linear: false @@ -26,6 +28,7 @@ joint_conf: predictor: embedding predictor_conf: embed_size: 320 + output_size: 320 embed_dropout: 0.1 n_head: 4 history_size: 5 diff --git a/wenet/transducer/predictor.py b/wenet/transducer/predictor.py index 9f4b2aa23..6949aa0cf 100644 --- a/wenet/transducer/predictor.py +++ b/wenet/transducer/predictor.py @@ -218,6 +218,7 @@ class EmbeddingPredictor(PredictorBase): def __init__(self, voca_size: int, embed_size: int, + output_size: int, embed_dropout: float, n_head: int, history_size: int = 2, @@ -226,6 +227,7 @@ def __init__(self, layer_norm_epsilon: float = 1e-5) -> None: super().__init__() + assert output_size == embed_size # multi head self.num_heads = n_head self.embed_size = embed_size @@ -379,6 +381,7 @@ class ConvPredictor(PredictorBase): def __init__(self, voca_size: int, embed_size: int, + output_size: int, embed_dropout: float, history_size: int = 2, activation: str = "relu", @@ -386,6 +389,7 @@ def __init__(self, layer_norm_epsilon: float = 1e-5) -> None: super().__init__() + assert embed_size == output_size assert history_size >= 0 self.embed_size = embed_size self.context_size = history_size + 1 diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index adbead528..dc413eb95 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -112,19 +112,11 @@ def init_model(args, configs): elif predictor_type == 'embedding': predictor = EmbeddingPredictor(vocab_size, **configs['predictor_conf']) - configs['predictor_conf']['output_size'] = configs[ - 'predictor_conf']['embed_size'] elif predictor_type == 'conv': predictor = ConvPredictor(vocab_size, **configs['predictor_conf']) - configs['predictor_conf']['output_size'] = configs[ - 'predictor_conf']['embed_size'] else: raise NotImplementedError( "only rnn, embedding and conv type support now") - configs['joint_conf']['enc_output_size'] = configs['encoder_conf'][ - 'output_size'] - configs['joint_conf']['pred_output_size'] = configs['predictor_conf'][ - 'output_size'] joint = TransducerJoint(vocab_size, **configs['joint_conf']) model = Transducer(vocab_size=vocab_size, blank=0,