Skip to content

Commit

Permalink
[refactor] simplify code and keep API consisstent
Browse files Browse the repository at this point in the history
  • Loading branch information
xingchensong committed Dec 8, 2023
1 parent 7bece3a commit b4f20a0
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 8 deletions.
2 changes: 2 additions & 0 deletions examples/aishell/rnnt/conf/conformer_rnnt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ encoder_conf:


joint_conf:
enc_output_size: 256
pred_output_size: 256
join_dim: 512
prejoin_linear: True
postjoin_linear: false
Expand Down
2 changes: 2 additions & 0 deletions examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ encoder_conf:


joint_conf:
enc_output_size: 256
pred_output_size: 256
join_dim: 512
prejoin_linear: True
postjoin_linear: false
Expand Down
3 changes: 3 additions & 0 deletions examples/aishell/rnnt/conf/example_embedding_predictor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ encoder_conf:


joint_conf:
enc_output_size: 256
pred_output_size: 320
join_dim: 320
prejoin_linear: true
postjoin_linear: false
Expand All @@ -26,6 +28,7 @@ joint_conf:
predictor: embedding
predictor_conf:
embed_size: 320
output_size: 320
embed_dropout: 0.1
n_head: 4
history_size: 5
Expand Down
4 changes: 4 additions & 0 deletions wenet/transducer/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ class EmbeddingPredictor(PredictorBase):
def __init__(self,
voca_size: int,
embed_size: int,
output_size: int,
embed_dropout: float,
n_head: int,
history_size: int = 2,
Expand All @@ -226,6 +227,7 @@ def __init__(self,
layer_norm_epsilon: float = 1e-5) -> None:

super().__init__()
assert output_size == embed_size
# multi head
self.num_heads = n_head
self.embed_size = embed_size
Expand Down Expand Up @@ -379,13 +381,15 @@ class ConvPredictor(PredictorBase):
def __init__(self,
voca_size: int,
embed_size: int,
output_size: int,
embed_dropout: float,
history_size: int = 2,
activation: str = "relu",
bias: bool = False,
layer_norm_epsilon: float = 1e-5) -> None:
super().__init__()

assert embed_size == output_size
assert history_size >= 0
self.embed_size = embed_size
self.context_size = history_size + 1
Expand Down
8 changes: 0 additions & 8 deletions wenet/utils/init_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,11 @@ def init_model(args, configs):
elif predictor_type == 'embedding':
predictor = EmbeddingPredictor(vocab_size,
**configs['predictor_conf'])
configs['predictor_conf']['output_size'] = configs[
'predictor_conf']['embed_size']
elif predictor_type == 'conv':
predictor = ConvPredictor(vocab_size, **configs['predictor_conf'])
configs['predictor_conf']['output_size'] = configs[
'predictor_conf']['embed_size']
else:
raise NotImplementedError(
"only rnn, embedding and conv type support now")
configs['joint_conf']['enc_output_size'] = configs['encoder_conf'][
'output_size']
configs['joint_conf']['pred_output_size'] = configs['predictor_conf'][
'output_size']
joint = TransducerJoint(vocab_size, **configs['joint_conf'])
model = Transducer(vocab_size=vocab_size,
blank=0,
Expand Down

0 comments on commit b4f20a0

Please sign in to comment.