1 parent c4e2a39, commit 385c0e7
Showing 7 changed files with 2,929 additions and 0 deletions.
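@@ -0,0 +1,127 @@
## Introduction

### Task description
The input to a machine translation system is normally a sentence in the source language. In many real systems, however, such as the output of a speech recognizer or pinyin-based text input, the source sentence often contains many homophone errors, which lead to unexpected translation errors. Because pronunciation information is available at the same time, we propose a translation method that adds pronunciation information at the input and fuses textual and phonetic information in the model's embedding layer. This greatly improves the translation model's robustness to homophone errors.

Paper: https://arxiv.org/abs/1810.06729
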
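### Results

We train on the LDC Chinese-to-English dataset. The Chinese pronunciation dictionary is [DaCiDian](https://github.com/aishell-foundation/DaCiDian). Evaluation on newstest2006 gives the following results:

| beta=0 | beta=0.50 | beta=0.85 | beta=0.95 |
|-|-|-|-|
| 47.96 | 48.71 | 48.85 | 48.46 |

Here beta is the weight given to the pronunciation information. The results show that translation quality remains good even when most of the weight is placed on pronunciation, while the system's robustness to homophone errors improves greatly.
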
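The role of beta can be illustrated with a small sketch. It assumes the fused embedding of a token is a weighted sum of its word embedding and the average of the embeddings of its pronunciation units. The function name and the averaging step are assumptions made for illustration, not code from this repository.

```python
def fuse_embeddings(word_emb, phone_embs, beta):
    """Mix one token's word embedding with its phoneme embeddings.

    word_emb:   list of floats, the token's word embedding.
    phone_embs: list of embeddings, one per pronunciation unit of the token.
    beta:       weight given to the pronunciation information, in [0, 1].
    """
    dim = len(word_emb)
    if phone_embs:
        # Assumption of this sketch: units of one token are averaged.
        phone_avg = [sum(e[i] for e in phone_embs) / len(phone_embs)
                     for i in range(dim)]
    else:
        # A token without pronunciation contributes a zero phonetic vector.
        phone_avg = [0.0] * dim
    return [(1.0 - beta) * w + beta * p for w, p in zip(word_emb, phone_avg)]


# beta=0 keeps only the word embedding; larger beta shifts weight to phonetics.
mixed = fuse_embeddings([1.0, 1.0], [[2.0, 2.0], [4.0, 4.0]], beta=0.5)
```

With `beta=0` the function returns the word embedding unchanged, which corresponds to the first column of the table above.
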
## Installation

1. Install PaddlePaddle

This project requires PaddlePaddle Fluid 1.3.1 or later. See the [installation guide](http://www.paddlepaddle.org/#quick-start) for how to install it.

2. Environment dependencies

See the PaddlePaddle [installation instructions](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/install/index_cn.html).

## Training

1. Data format

The data format is the same as in the [Paddle machine translation](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/neural_machine_translation/transformer) example. To obtain the pronunciation of the input sentences, you additionally need to provide a file of basic pronunciation units for the source language and a pronunciation lexicon.

A) Pronunciation unit file

For Chinese, the basic pronunciation unit is the pinyin syllable. Put all pinyin syllables in one file, for example:

<unk>

bo

li

...

B) Pronunciation lexicon

Based on DaCiDian, assign one or more pronunciations to each token of the BPE-segmented source text, for example:

▁玻利维亚 bo li wei ya

▁举行 ju xing

▁总统 zong tong

▁与 yu

巴斯 ba si

▁这个 zhei ge|zhe ge

...

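The two files described above could be read roughly as follows. This is a sketch under assumptions: one unit per line in the unit file, a token followed by whitespace and its pronunciations in the lexicon, and `|` separating alternative pronunciations, as in the examples. The function names are hypothetical and are not taken from the repository's reader.

```python
def load_phoneme_vocab(lines):
    """Map each pronunciation unit to its index; the first unit gets id 0."""
    units = [line.strip() for line in lines if line.strip()]
    return {unit: idx for idx, unit in enumerate(units)}


def load_lexicon(lines, phoneme_vocab, unk="<unk>"):
    """Map each token to a list of pronunciations, each a list of unit ids."""
    lexicon = {}
    for line in lines:
        parts = line.strip().split(None, 1)
        if len(parts) != 2:
            continue  # skip blank or malformed lines
        token, prons = parts
        lexicon[token] = [
            # Units missing from the vocabulary fall back to <unk>.
            [phoneme_vocab.get(u, phoneme_vocab[unk]) for u in pron.split()]
            for pron in prons.split("|")
        ]
    return lexicon


vocab = load_phoneme_vocab(["<unk>", "bo", "li", "zhei", "ge", "zhe"])
lexicon = load_lexicon(["▁这个 zhei ge|zhe ge", "巴斯 ba si"], vocab)
```

In this example `▁这个` receives two pronunciations, while both units of `巴斯` map to `<unk>` because `ba` and `si` are not in the small demo vocabulary.
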
2. Train the model

Once the data is ready, train with the `train.py` script. For example:

```sh
python train.py \
  --src_vocab_fpath nist_data/vocab_all.28000 \
  --trg_vocab_fpath nist_data/vocab_all.28000 \
  --train_file_pattern nist_data/nist_train.txt \
  --phoneme_vocab_fpath nist_data/zh_pinyins.txt \
  --lexicon_fpath nist_data/zh_lexicon.txt \
  --batch_size 2048 \
  --use_token_batch True \
  --sort_type pool \
  --pool_size 200000 \
  --use_py_reader False \
  --enable_ce False \
  --fetch_steps 1 \
  pass_num 100 \
  learning_rate 2.0 \
  warmup_steps 8000 \
  beta2 0.997 \
  d_model 512 \
  d_inner_hid 2048 \
  n_head 8 \
  weight_sharing True \
  max_length 256 \
  save_freq 10000 \
  beta 0.85 \
  model_dir pinyin_models_beta085 \
  ckpt_dir pinyin_ckpts_beta085
```

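The command above sets the data-related arguments: the source vocabulary path (`src_vocab_fpath`), the target vocabulary path (`trg_vocab_fpath`), the training data files (`train_file_pattern`, wildcards supported), the pronunciation unit file (`phoneme_vocab_fpath`) and the pronunciation lexicon (`lexicon_fpath`). It also sets reader-related arguments such as how batches are built (`use_token_batch` selects whether a batch is sized by number of tokens or by number of sequences). For more detail on these arguments, run:

```sh
python train.py --help
```
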
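The difference between the two batching modes selected by `use_token_batch` can be sketched as below. This is a simplified illustration, not the repository's reader: it assumes the token budget is measured on the padded batch (longest sequence times number of sequences), and both function names are hypothetical.

```python
def batch_by_tokens(lengths, batch_size):
    """Group sequence indices so each padded batch holds at most batch_size tokens."""
    batches, current, max_len = [], [], 0
    for idx, length in enumerate(lengths):
        new_max = max(max_len, length)
        # Padded size if this sequence joined the current batch.
        if current and new_max * (len(current) + 1) > batch_size:
            batches.append(current)
            current, new_max = [], length
        current.append(idx)
        max_len = new_max
    if current:
        batches.append(current)
    return batches


def batch_by_sequences(lengths, batch_size):
    """Group sequence indices so each batch holds at most batch_size sequences."""
    indices = list(range(len(lengths)))
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


lengths = [3, 4, 5, 2, 6]
token_batches = batch_by_tokens(lengths, batch_size=10)
sequence_batches = batch_by_sequences(lengths, batch_size=2)
```

Token-based batching keeps the amount of computation per step roughly constant when sentence lengths vary, which is why `batch_size` is much larger in that mode (2048 in the command above).
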
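Further training-related parameters are defined in `ModelHyperParams` and `TrainTaskConfig` in `config.py`. `ModelHyperParams` holds model hyperparameters such as the embedding dimension, and `TrainTaskConfig` holds training parameters such as the number of warmup steps. The defaults follow the base model configuration from the Transformer paper and can be edited in that file. They can also be set on the command line of the training script, as with the trailing `key value` pairs in the example above; values passed this way are merged into, and override, those in `config.py`.

Note that if you change the model configuration for training, you must use the same configuration when predicting with `infer.py`. Training uses all GPUs by default; set the `CUDA_VISIBLE_DEVICES` environment variable to select specific GPUs.
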
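For reference, the warmup-based learning rate schedule from the Transformer paper can be written as the sketch below. The function name is hypothetical, and whether `train.py` computes its rate exactly this way is an assumption. The default values are the ones used in the training command above.

```python
def noam_learning_rate(step, d_model=512, warmup_steps=8000, learning_rate=2.0):
    """Transformer-paper schedule, scaled by a static learning_rate factor.

    The rate grows linearly for warmup_steps steps, then decays with the
    inverse square root of the step number.
    """
    step = max(step, 1)  # avoid division by zero at step 0
    scale = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
    return learning_rate * scale


# The rate peaks at the end of warmup and falls off afterwards.
rates = [noam_learning_rate(s) for s in (4000, 8000, 16000)]
```
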
## Inference

With the data and model described above, run inference as follows. The translations are printed to standard output:

```sh
python infer.py \
  --src_vocab_fpath nist_data/vocab_all.28000 \
  --trg_vocab_fpath nist_data/vocab_all.28000 \
  --test_file_pattern nist_data/nist_test.txt \
  --phoneme_vocab_fpath nist_data/zh_pinyins.txt \
  --lexicon_fpath nist_data/zh_lexicon.txt \
  --batch_size 32 \
  model_path pinyin_models_beta085/iter_200000.infer.model \
  beam_size 5 \
  max_out_len 255 \
  beta 0.85
```
@@ -0,0 +1,130 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TrainTaskConfig(object):
    # support both CPU and GPU now.
    use_gpu = True
    # the epoch number to train.
    pass_num = 30
    # the number of sequences contained in a mini-batch.
    # deprecated, set batch_size in args.
    batch_size = 32
    # the hyper parameters for Adam optimizer.
    # This static learning_rate is multiplied by the learning rate derived
    # from the LearningRateScheduler to get the final learning rate.
    learning_rate = 2.0
    beta1 = 0.9
    beta2 = 0.997
    eps = 1e-9
    # the parameters for learning rate scheduling.
    warmup_steps = 8000
    # the weight used to mix up the ground-truth distribution and the fixed
    # uniform distribution in label smoothing when training.
    # Set this to zero if label smoothing is not wanted.
    label_smooth_eps = 0.1
    # the directory for saving trained models.
    model_dir = "trained_models"
    # the directory for saving checkpoints.
    ckpt_dir = "trained_ckpts"
    # the directory for loading checkpoint.
    # If provided, continue training from the checkpoint.
    ckpt_path = None
    # the parameter to initialize the learning rate scheduler.
    # It should be provided when using checkpoints, since the checkpoint
    # doesn't include the training step counter currently.
    start_step = 0
    # the frequency to save trained models.
    save_freq = 10000


class InferTaskConfig(object):
    use_gpu = True
    # the number of examples in one run for sequence generation.
    batch_size = 10
    # the parameters for beam search.
    beam_size = 5
    max_out_len = 256
    # the number of decoded sentences to output.
    n_best = 1
    # the flags indicating whether to output the special tokens.
    output_bos = False
    output_eos = False
    output_unk = True
    # the directory for loading the trained model.
    model_path = "trained_models/pass_1.infer.model"


class ModelHyperParams(object):
    # The following vocabulary-related configurations will be set
    # automatically according to the passed vocabulary path and special tokens.
    # size of source word dictionary.
    src_vocab_size = 10000
    # size of target word dictionary
    trg_vocab_size = 10000
    # size of phone dictionary
    phone_vocab_size = 1000
    # ratio of phoneme embeddings
    beta = 0.0
    # index for <bos> token
    bos_idx = 0
    # index for <eos> token
    eos_idx = 1
    # index for <unk> token
    unk_idx = 2
    # index for <unk> in phonemes
    phone_pad_idx = 0
    # max length of sequences deciding the size of position encoding table.
    max_length = 256
    # the dimension for word embeddings, which is also the last dimension of
    # the input and output of multi-head attention, position-wise feed-forward
    # networks, encoder and decoder.
    d_model = 512
    # size of the hidden layer in position-wise feed-forward networks.
    d_inner_hid = 2048
    # the dimension that keys are projected to for dot-product attention.
    d_key = 64
    # the dimension that values are projected to for dot-product attention.
    d_value = 64
    # number of heads used in multi-head attention.
    n_head = 8
    # number of sub-layers to be stacked in the encoder and decoder.
    n_layer = 6
    # dropout rates of different modules.
    prepostprocess_dropout = 0.1
    attention_dropout = 0.1
    relu_dropout = 0.1
    # to process before each sub-layer
    preprocess_cmd = "n"  # layer normalization
    # to process after each sub-layer
    postprocess_cmd = "da"  # dropout + residual connection
    # random seed used in dropout for CE.
    dropout_seed = None
    # the flag indicating whether to share embedding and softmax weights.
    # vocabularies in source and target should be same for weight sharing.
    weight_sharing = True


def merge_cfg_from_list(cfg_list, g_cfgs):
    """
    Set the above global configurations using the cfg_list.

    cfg_list alternates keys and values: [key1, value1, key2, value2, ...].
    Each key is set on the first config in g_cfgs that defines it.
    """
    assert len(cfg_list) % 2 == 0
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                try:
                    # turn strings such as "0.85" or "True" into Python values
                    value = eval(value)
                except Exception:  # for file path
                    pass
                setattr(g_cfg, key, value)
                break
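
The sketch below shows how alternating key and value strings from the command line end up on the config classes. It is a self-contained variant for illustration: `TrainCfg`, `ModelCfg` and `merge_pairs` are hypothetical stand-ins, and it parses values with `ast.literal_eval` instead of `eval`, which accepts only Python literals. Unlike `eval`, it leaves expressions such as `2*3` as strings.

```python
import ast


class TrainCfg(object):  # hypothetical stand-in for TrainTaskConfig
    learning_rate = 2.0
    model_dir = "trained_models"


class ModelCfg(object):  # hypothetical stand-in for ModelHyperParams
    beta = 0.0
    weight_sharing = True


def merge_pairs(cfg_list, g_cfgs):
    """Apply alternating key/value strings to the first config that has the key."""
    assert len(cfg_list) % 2 == 0
    for key, raw in zip(cfg_list[0::2], cfg_list[1::2]):
        try:
            value = ast.literal_eval(raw)  # numbers, booleans, None, ...
        except (ValueError, SyntaxError):
            value = raw  # keep plain strings such as file paths
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                setattr(g_cfg, key, value)
                break


merge_pairs(["beta", "0.85", "model_dir", "pinyin_models_beta085"],
            [TrainCfg, ModelCfg])
```

After the call, `ModelCfg.beta` is the float `0.85` and `TrainCfg.model_dir` is the string `pinyin_models_beta085`. Keys that no config defines are silently ignored, as in the original function.
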
@@ -0,0 +1,114 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for sequence length in compile time.
seq_len = 256
# The placeholder for phoneme sequence length in compile time.
phone_len = 16
# The placeholder for head number in compile time.
n_head = 8
# The placeholder for model dim in compile time.
d_model = 512
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
    # The actual data shape of src_word is:
    # [batch_size, max_src_len_in_batch, 1]
    "src_word": [(batch_size, seq_len, 1), "int64", 2],
    # The actual data shape of src_pos is:
    # [batch_size, max_src_len_in_batch, 1]
    "src_pos": [(batch_size, seq_len, 1), "int64"],
    # This input is used to remove attention weights on paddings in the
    # encoder.
    # The actual data shape of src_slf_attn_bias is:
    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
    "src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
    # Phoneme ids for each source token, plus a mask with the same shape
    # except for the trailing dimension of size 1.
    "src_phone": [(batch_size, seq_len, phone_len, 1), "int64"],
    "src_phone_mask": [(batch_size, seq_len, phone_len), "int64"],
    # The actual data shape of trg_word is:
    # [batch_size, max_trg_len_in_batch, 1]
    "trg_word": [(batch_size, seq_len, 1), "int64",
                 2],  # lod_level is only used in fast decoder.
    # The actual data shape of trg_pos is:
    # [batch_size, max_trg_len_in_batch, 1]
    "trg_pos": [(batch_size, seq_len, 1), "int64"],
    # This input is used to remove attention weights on paddings and
    # subsequent words in the decoder.
    # The actual data shape of trg_slf_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
    "trg_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
    # This input is used to remove attention weights on paddings of the source
    # input in the encoder-decoder attention.
    # The actual data shape of trg_src_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
    "trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
    # This input is used in independent decoder program for inference.
    # The actual data shape of enc_output is:
    # [batch_size, max_src_len_in_batch, d_model]
    "enc_output": [(batch_size, seq_len, d_model), "float32"],
    # The actual data shape of label_word is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_word": [(batch_size * seq_len, 1), "int64"],
    # This input is used to mask out the loss of padding tokens.
    # The actual data shape of label_weight is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_weight": [(batch_size * seq_len, 1), "float32"],
    # This input is used in beam-search decoder.
    "init_score": [(batch_size, 1), "float32", 2],
    # This input is used in beam-search decoder for the first gather
    # (cell states update)
    "init_idx": [(batch_size, ), "int32"],
}

# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
    "src_word_emb_table",
    "trg_word_emb_table", )

phone_emb_param_name = "phone_emb_table"

# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
    "src_pos_enc_table",
    "trg_pos_enc_table", )
# separated inputs for different usages.
encoder_data_input_fields = (
    "src_word",
    "src_pos",
    "src_slf_attn_bias",
    "src_phone",
    "src_phone_mask", )
decoder_data_input_fields = (
    "trg_word",
    "trg_pos",
    "trg_slf_attn_bias",
    "trg_src_attn_bias",
    "enc_output", )
label_data_input_fields = (
    "lbl_word",
    "lbl_weight", )
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
    "trg_word",
    "init_score",
    "init_idx",
    "trg_src_attn_bias", )

# Set seed for CE
dropout_seed = None
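
The `src_phone` and `src_phone_mask` descriptors above imply a rectangular layout over batch, token position and phoneme position. The sketch below shows one way such inputs could be built from variable-length phoneme id lists. It is an illustration under assumptions: the function name is hypothetical, padding uses index 0, the mask marks real phonemes with 1, and the trailing dimension of size 1 in `src_phone` is omitted. How the repository's reader actually pads or handles multiple pronunciations is not shown in this diff.

```python
def pad_phonemes(batch, pad_idx=0):
    """Pad per-token phoneme id lists to a [batch, seq, phone] layout.

    batch: non-empty list of sentences; each sentence is a list of tokens;
           each token is a list of phoneme ids.
    Returns (phones, mask), nested lists of identical shape.
    """
    max_seq = max(len(sent) for sent in batch)
    max_phone = max((len(tok) for sent in batch for tok in sent), default=1)
    phones, mask = [], []
    for sent in batch:
        sent_ids, sent_mask = [], []
        for pos in range(max_seq):
            # Positions past the end of the sentence become all-padding tokens.
            tok = sent[pos] if pos < len(sent) else []
            pad = max_phone - len(tok)
            sent_ids.append(list(tok) + [pad_idx] * pad)
            sent_mask.append([1] * len(tok) + [0] * pad)
        phones.append(sent_ids)
        mask.append(sent_mask)
    return phones, mask


phones, mask = pad_phonemes([[[1, 2], [3]], [[4, 5, 6]]])
```

A mask of this kind lets the model ignore padded phoneme slots, for example when averaging the phoneme embeddings of a token.
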