Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(yaml): Config ctc/cmvn/tokenizer in train.yaml #2205

Merged
merged 26 commits into from
Dec 13, 2023
Merged

Conversation

xingchensong
Copy link
Member

@xingchensong xingchensong commented Dec 8, 2023

  1. 将模型相关的配置从train args移动到yaml,比如ctc、cmvn、tokenizer
  2. 将sos设定为恒定=2,而不是vocab_size - 1,特殊化处理,此时不同字典可以share相同的特殊token

image

TODO (current PR)

  • 验证可以训练
  • 验证可以解码

TODO (next PR)

@Mddct
Copy link
Collaborator

Mddct commented Dec 8, 2023

great work

@Mddct Mddct self-requested a review December 8, 2023 12:42
Mddct
Mddct previously approved these changes Dec 8, 2023
@Mddct
Copy link
Collaborator

Mddct commented Dec 8, 2023

TODO

@Mddct Mddct requested a review from robin1001 December 8, 2023 13:04
Copy link
Collaborator

@robin1001 robin1001 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

周哥,这个改动比较大,功能也很杂,强烈建议分开提 PR,每个 PR 只做一个功能,来自 Google 软件实践指南中的建议。

@xingchensong
Copy link
Member Author

记一个TODO,后面runtime也要针对special id做对应的修改,比如blanckid不是0,sos不是vocabsize-1等

@xingchensong
Copy link
Member Author

模型转换脚本(sos从vocabsize-1变成2)

import torch                                                                                                                                                                                                                  [3/528]

old_state = torch.load('/mnt/d/BaiduSyncdisk/downloads/ckpt/20210601_u2++_conformer_exp_aishell/final.pt')
new_state = {}
change_list = ['decoder.left_decoder.output_layer.weight',
               'decoder.left_decoder.output_layer.bias',
               'decoder.left_decoder.embed.0.weight',
               'decoder.right_decoder.output_layer.weight',
               'decoder.right_decoder.output_layer.bias',
               'decoder.right_decoder.embed.0.weight',
               'ctc.ctc_lo.weight',
               'ctc.ctc_lo.bias']
for key in old_state.keys():
    if key in change_list:
        print("processing {}, {}".format(key, old_state[key].size()))
        tensor = old_state[key]
        new_tensor = torch.zeros_like(tensor)
        if len(tensor.size()) == 2:  # weight
            new_tensor[:2, :] = tensor[:2, :]
            new_tensor[2, :] = tensor[-1, :]
            new_tensor[3:, :] = tensor[2:-1, :]
        elif len(tensor.size()) == 1:  # bias
            new_tensor[:2] = tensor[:2]
            new_tensor[2] = tensor[-1]
            new_tensor[3:] = tensor[2:-1]
        else:
            raise NotImplementedError
        new_state[key] = new_tensor
    elif "concat_linear" in key:
        continue
    else:
        new_state[key] = old_state[key]

torch.save(new_state, "/mnt/d/BaiduSyncdisk/downloads/ckpt/20210601_u2++_conformer_exp_aishell/final.sos2.pt")

转换后可以成功解码
image

@xingchensong
Copy link
Member Author

解码结果一致

image

image

Mddct
Mddct previously approved these changes Dec 12, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants