[refactor] use model_class to unify module init #2216

Merged 9 commits on Dec 10, 2023
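This refactor replaces per-model switches in the recipe YAMLs with a single `model:` key (plus a `joint:` key for transducer joint networks), so model construction can be dispatched from one place in `wenet/utils/init_model.py`. Below is a minimal sketch of that kind of dispatch; `MODEL_REGISTRY` and `build_model` are illustrative names for this note, not the PR's actual code.

```python
# Minimal sketch of dispatching on the new 'model' config key.
# MODEL_REGISTRY and build_model are hypothetical names used only for illustration.
from typing import Any, Dict

MODEL_REGISTRY: Dict[str, type] = {
    # 'asr_model':  ASRModel,     # default hybrid CTC/attention
    # 'transducer': Transducer,   # hybrid transducer + CTC + attention
    # 'ctl_model':  CTLModel,
    # 'k2_model':   K2Model,      # LF-MMI training with k2
}


def build_model(configs: Dict[str, Any], **modules):
    """Pick the model class from configs['model'] instead of per-model flags."""
    model_type = configs.get('model', 'asr_model')
    if model_type not in MODEL_REGISTRY:
        raise ValueError(f'unknown model type: {model_type}')
    # Hand over the already-built submodules (encoder, decoder, joint, ...)
    # together with the model-specific options from model_conf.
    return MODEL_REGISTRY[model_type](**modules, **configs.get('model_conf', {}))
```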
3 changes: 2 additions & 1 deletion examples/aishell/rnnt/conf/conformer_rnnt.yaml
@@ -17,7 +17,7 @@ encoder_conf:
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'


joint: transducer_joint
joint_conf:
    enc_output_size: 256
    pred_output_size: 256
@@ -50,6 +50,7 @@ decoder_conf:
    src_attention_dropout_rate: 0.1

# hybrid transducer+ctc+attention
model: transducer
model_conf:
    transducer_weight: 0.75
    ctc_weight: 0.1
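For reference, the snippet below shows what the new keys parse to once such a config is loaded; the inline YAML is a cut-down stand-in for the file above, not the full recipe.

```python
# Sketch: parse a trimmed transducer config and inspect the keys this PR adds.
import yaml

CONFIG_TEXT = """
joint: transducer_joint
joint_conf:
    enc_output_size: 256
    pred_output_size: 256

# hybrid transducer+ctc+attention
model: transducer
model_conf:
    transducer_weight: 0.75
    ctc_weight: 0.1
"""

config = yaml.safe_load(CONFIG_TEXT)
assert config['model'] == 'transducer'
assert config['joint'] == 'transducer_joint'
print(config['model_conf'])  # {'transducer_weight': 0.75, 'ctc_weight': 0.1}
```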
3 changes: 2 additions & 1 deletion examples/aishell/rnnt/conf/conformer_u2pp_rnnt.yaml
@@ -21,7 +21,7 @@ encoder_conf:
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false


joint: transducer_joint
joint_conf:
    enc_output_size: 256
    pred_output_size: 256
@@ -54,6 +54,7 @@ decoder_conf:
    src_attention_dropout_rate: 0.1

# hybrid transducer+ctc+attention
model: transducer
model_conf:
    transducer_weight: 0.75
    ctc_weight: 0.1
3 changes: 2 additions & 1 deletion examples/aishell/rnnt/conf/example_embedding_predictor.yaml
@@ -15,7 +15,7 @@ encoder_conf:
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'


joint: transducer_joint
joint_conf:
    enc_output_size: 256
    pred_output_size: 320
@@ -46,6 +46,7 @@ decoder_conf:
    src_attention_dropout_rate: 0.1

# hybrid transducer+ctc+attention
model: transducer
model_conf:
    transducer_weight: 0.4
    ctc_weight: 0.2
4 changes: 2 additions & 2 deletions examples/aishell/s0/conf/train_unified_conformer_ctl.yaml
@@ -1,6 +1,6 @@
# network architecture
# encoder related
-encoder: conformer
+encoder: dual_conformer
encoder_conf:
    output_size: 256 # dimension of attention
    attention_heads: 4
@@ -32,8 +32,8 @@ decoder_conf:
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

-ctlmodel: true
# hybrid CTC/attention
+model: ctl_model
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1 # label smoothing option
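This recipe used to flag CTL training with a standalone `ctlmodel: true`; after the change the same information lives in `model: ctl_model`, alongside the switch to the `dual_conformer` encoder. A hypothetical migration helper, not part of the PR and shown only to make the mapping concrete, could rewrite an old config dict like this:

```python
# Hypothetical helper: map the removed 'ctlmodel: true' flag to the new keys.
from typing import Any, Dict


def migrate_ctl_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Rewrite a pre-refactor CTL config in place and return it."""
    if config.pop('ctlmodel', False):
        config['model'] = 'ctl_model'
        # The CTL recipe also switches the encoder to its dual variant.
        if config.get('encoder') == 'conformer':
            config['encoder'] = 'dual_conformer'
    return config


old = {'encoder': 'conformer', 'ctlmodel': True, 'model_conf': {'ctc_weight': 0.3}}
new = migrate_ctl_config(dict(old))
assert new == {'encoder': 'dual_conformer', 'model': 'ctl_model',
               'model_conf': {'ctc_weight': 0.3}}
```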
6 changes: 4 additions & 2 deletions examples/aishell/s0/run.sh
@@ -304,8 +304,10 @@ if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
  # 9.1 Build token level bigram fst for LF-MMI training
  tools/k2/prepare_mmi.sh data/train/ data/dev data/local/lfmmi

-  # 9.2 Run LF-MMI training from stage 4, with below new args
-  #     --lfmmi_dir data/local/lfmmi
+  # 9.2 Run LF-MMI training from stage 4, modifying the args below in train.yaml
+  #     model: k2_model
+  #     model_conf:
+  #       lfmmi_dir: data/local/lfmmi

  # 9.3 Run HLG decode from stage 8.2
fi
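Concretely, the comment now points at keys in train.yaml rather than extra command-line flags. A small sketch of patching an existing config for LF-MMI follows; the file paths are placeholders, and in practice the YAML may simply be edited by hand.

```python
# Sketch: turn an existing training config into a k2/LF-MMI one.
# The paths below are illustrative, not paths required by the recipe.
import yaml

with open('conf/train_conformer.yaml') as fin:
    config = yaml.safe_load(fin)

config['model'] = 'k2_model'
config.setdefault('model_conf', {})['lfmmi_dir'] = 'data/local/lfmmi'

with open('conf/train_conformer_lfmmi.yaml', 'w') as fout:
    yaml.safe_dump(config, fout)
```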
5 changes: 4 additions & 1 deletion examples/aishell2/rnnt/conf/conformer_rnnt.yaml
@@ -17,8 +17,10 @@ encoder_conf:
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'


joint: transducer_joint
joint_conf:
    enc_output_size: 256
    pred_output_size: 256
    join_dim: 512
    prejoin_linear: True
    postjoin_linear: false
@@ -48,6 +50,7 @@ decoder_conf:
    src_attention_dropout_rate: 0.1

# hybrid transducer+ctc+attention
model: transducer
model_conf:
    transducer_weight: 0.75
    ctc_weight: 0.1
5 changes: 4 additions & 1 deletion examples/aishell2/rnnt/conf/conformer_u2pp_rnnt.yaml
@@ -21,8 +21,10 @@ encoder_conf:
    cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster
    use_dynamic_left_chunk: false


joint: transducer_joint
joint_conf:
    enc_output_size: 256
    pred_output_size: 256
    join_dim: 512
    prejoin_linear: True
    postjoin_linear: false
@@ -52,6 +54,7 @@ decoder_conf:
    src_attention_dropout_rate: 0.1

# hybrid transducer+ctc+attention
model: transducer
model_conf:
    transducer_weight: 0.75
    ctc_weight: 0.1
5 changes: 4 additions & 1 deletion examples/librispeech/rnnt/conf/conformer_rnnt.yaml
@@ -17,8 +17,10 @@ encoder_conf:
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'


joint: transducer_joint
joint_conf:
    enc_output_size: 256
    pred_output_size: 256
    join_dim: 512
    prejoin_linear: True
    postjoin_linear: false
@@ -48,6 +50,7 @@ decoder_conf:
    src_attention_dropout_rate: 0.1

# hybrid transducer+ctc+attention
model: transducer
model_conf:
    transducer_weight: 0.75
    ctc_weight: 0.1
34 changes: 34 additions & 0 deletions test/wenet/utils/test_init_model.py
@@ -0,0 +1,34 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright [2023-12-10] <[email protected], Xingchen Song>

import glob
import yaml

from wenet.utils.init_model import init_model


class DummyArguments:
    jit = False
    enc_init = None
    checkpoint = None


def test_init_model():
    configs = glob.glob("examples/*/*/conf/*.yaml")
    args = DummyArguments()
    for c in configs:
        with open(c, 'r') as fin:
            config = yaml.load(fin, Loader=yaml.FullLoader)
        if 'fbank_conf' in config['dataset_conf']:
            input_dim = config['dataset_conf']['fbank_conf']['num_mel_bins']
        elif 'log_mel_spectrogram_conf' in config['dataset_conf']:
            input_dim = config['dataset_conf']['log_mel_spectrogram_conf'][
                'num_mel_bins']
        else:
            input_dim = config['dataset_conf']['mfcc_conf']['num_mel_bins']
        config['input_dim'] = input_dim
        # TODO(xcsong): fix vocab_size
        config['output_dim'] = 3000
        print("checking {} {}".format(c, config))
        init_model(args, config)
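Assuming the repository's tests are driven by pytest (the test layout suggests this, but it is an assumption here), the new smoke test can be run on its own from the repository root so that the `examples/*/*/conf/*.yaml` glob resolves:

```python
# Sketch: run only the new init_model smoke test.
# Assumes pytest is installed and the working directory is the repo root.
import pytest

if __name__ == '__main__':
    raise SystemExit(pytest.main(['-q', 'test/wenet/utils/test_init_model.py']))
```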
2 changes: 1 addition & 1 deletion test/wenet/whisper/test_whisper.py
@@ -361,7 +361,7 @@ def test_model(model, audio_path):

    # 6. Forward wenet.decoder
    wenet_tokens, _ = add_whisper_tokens(
-       configs['model_conf']['special_tokens'],
+       configs['tokenizer_conf']['special_tokens'],
        torch.tensor([dummy_tokens], dtype=torch.long),
        ignore_id=-1,
        task=task,
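The Whisper test now reads the special tokens from `tokenizer_conf` instead of `model_conf`. If code elsewhere still has to cope with configs written before this change, a guarded lookup along these lines would work; this is a sketch, not something the PR adds:

```python
# Sketch: prefer the new tokenizer_conf location, fall back to the old model_conf one.
from typing import Any, Dict, Optional


def get_special_tokens(configs: Dict[str, Any]) -> Optional[Dict[str, int]]:
    tok_conf = configs.get('tokenizer_conf', {})
    if 'special_tokens' in tok_conf:
        return tok_conf['special_tokens']
    return configs.get('model_conf', {}).get('special_tokens')


assert get_special_tokens({'tokenizer_conf': {'special_tokens': {'sot': 1}}}) == {'sot': 1}
assert get_special_tokens({'model_conf': {'special_tokens': {'sot': 1}}}) == {'sot': 1}
```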