This code is based on Fairseq v0.10.2 and requires:
- PyTorch version >= 1.9.0
- Python version >= 3.6
To obtain the inter-group relations, first prepare a raw data file and install the Stanford CoreNLP
software by following the steps at https://github.com/stanfordnlp/CoreNLP
Then run the following script:
bash preprocess_group.sh
python3 -u train.py data-bin/$data_dir \
--distributed-world-size 8 -s src -t tgt \
--task dp_tree_group_phrase_translation \
--arch phrase_transformer_t2t_wmt_en_de \
--optimizer adam --clip-norm 0.0 \
--adam-betas '(0.9, 0.997)' \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 16000 \
--lr 0.002 --min-lr 1e-09 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 4096 \
--update-freq 2 \
--max-epoch 30 \
--attention-dropout 0.1 --relu-dropout 0.1 \
--no-progress-bar \
--log-interval 100 \
--ddp-backend no_c10d \
--seed 1 \
--fp16 \
--save-dir $save_dir \
--keep-last-epochs 10
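The learning-rate flags in the command above (--lr-scheduler inverse_sqrt, --warmup-init-lr, --warmup-updates, --lr) describe a linear warmup followed by inverse-square-root decay. The sketch below illustrates that schedule with the values from the command. It assumes Fairseq's standard inverse_sqrt behaviour; the function name `inverse_sqrt_lr` is hypothetical and the code is not taken from this repository.

```python
def inverse_sqrt_lr(step, peak_lr=0.002, warmup_init_lr=1e-07, warmup_updates=16000):
    """Approximate learning rate after `step` updates (illustrative only)."""
    if step < warmup_updates:
        # Linear warmup from warmup_init_lr up to peak_lr.
        return warmup_init_lr + step * (peak_lr - warmup_init_lr) / warmup_updates
    # After warmup, decay proportionally to 1 / sqrt(step),
    # chosen so that the rate equals peak_lr at step == warmup_updates.
    return peak_lr * (warmup_updates / step) ** 0.5


if __name__ == "__main__":
    for step in (0, 8000, 16000, 64000):
        print(step, inverse_sqrt_lr(step))
```

Passing warmup_updates=8000 gives the corresponding curve for a shorter warmup.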
python3 -u train.py data-bin/$data_dir \
--distributed-world-size 8 -s src -t tgt \
--task dp_tree_group_phrase_translation \
--arch phrase_transformer_t2t_wmt_en_de \
--share-all-embeddings \
--optimizer adam --clip-norm 0.0 \
--adam-betas '(0.9, 0.997)' \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 8000 \
--lr 0.002 --min-lr 1e-09 \
--weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 4096 \
--update-freq 4 \
--max-epoch 30 \
--dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
--truncate-source --skip-invalid-size-inputs-valid-test --max-source-positions 500 \
--no-progress-bar \
--log-interval 100 \
--ddp-backend no_c10d \
--seed 1 \
--fp16 \
--save-dir $save_dir \
--keep-last-epochs 10
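If you train on a different number of GPUs, it helps to know the effective batch size these commands imply. Each optimizer step accumulates gradients over --update-freq batches of at most --max-tokens tokens on each of the --distributed-world-size workers. A rough sketch of that arithmetic, assuming one worker per GPU (the helper name `tokens_per_update` is hypothetical, and the result is an upper bound because batches are rarely packed to the limit):

```python
def tokens_per_update(max_tokens, world_size, update_freq):
    """Upper bound on the number of tokens contributing to one optimizer step."""
    return max_tokens * world_size * update_freq


if __name__ == "__main__":
    # Values taken from the two training commands above.
    print("update-freq 2:", tokens_per_update(4096, 8, 2))
    print("update-freq 4:", tokens_per_update(4096, 8, 4))
    # Same budget on 4 workers: double update_freq to compensate.
    print("4 workers, update-freq 8:", tokens_per_update(4096, 4, 8))
```

Keeping this product constant when changing the number of GPUs is a common way to stay close to the original training setup.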
python3 generate.py \
data-bin/wmt-en2de \
--task dp_tree_group_phrase_translation \
--path $model_dir/$checkpoint \
--gen-subset test \
--batch-size 64 \
--beam 4 \
--lenpen 0.6 \
--output hypo.txt \
--quiet \
--remove-bpe
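The --lenpen flag controls how beam search trades hypothesis length against log-probability. The sketch below shows the usual length normalization, assuming the stock Fairseq behaviour of dividing the summed log-probability by length raised to the power lenpen. The function name is hypothetical and this repository's generator may differ.

```python
def length_normalized_score(token_logprobs, lenpen):
    """Sum of token log-probabilities divided by (length ** lenpen)."""
    total = sum(token_logprobs)
    return total / (len(token_logprobs) ** lenpen)


if __name__ == "__main__":
    short = [-0.5] * 4   # 4 tokens, total log-prob -2.0
    long = [-0.5] * 8    # 8 tokens, total log-prob -4.0
    for lenpen in (0.0, 0.6, 2.0):
        print(lenpen,
              length_normalized_score(short, lenpen),
              length_normalized_score(long, lenpen))
```

Because log-probabilities are negative, a larger lenpen makes longer hypotheses score relatively better.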
We use pyrouge as the scoring script.
python3 generate.py \
data-bin/$data_dir \
--path $model_dir/$checkpoint \
--gen-subset test \
--truncate-source \
--batch-size 32 \
--lenpen 2.0 \
--min-len 55 \
--max-len-b 140 \
--max-source-positions 500 \
--beam 4 \
--no-repeat-ngram-size 3 \
--remove-bpe
python3 get_rouge.py --decodes_filename $model_dir/hypo.sorted.tok --targets_filename cnndm.test.target.tok
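The --no-repeat-ngram-size 3 flag in the summarization generation command prevents the decoder from producing any trigram twice in one hypothesis. A minimal sketch of the idea, not the Fairseq implementation (the function name `banned_next_tokens` is hypothetical): at each step, find the tokens that would complete an n-gram already present in the prefix, and exclude them from the candidates.

```python
def banned_next_tokens(prefix, n=3):
    """Return the tokens that would repeat an n-gram already in `prefix`."""
    if len(prefix) < n:
        # No complete n-gram exists yet, so nothing can be repeated.
        return set()
    # The last n-1 tokens form the start of the next n-gram.
    context = tuple(prefix[len(prefix) - (n - 1):])
    banned = set()
    for i in range(len(prefix) - n + 1):
        # If an earlier n-gram starts with the same n-1 tokens,
        # its final token must not be generated next.
        if tuple(prefix[i:i + n - 1]) == context:
            banned.add(prefix[i + n - 1])
    return banned


if __name__ == "__main__":
    prefix = ["the", "cat", "sat", "on", "the", "cat"]
    print(banned_next_tokens(prefix, n=3))
```

In a real decoder the banned tokens would have their scores set to negative infinity before the beam is expanded.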