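# translate.py: translate a source file with a trained OpenNMT-py model and
# write the n-best hypotheses for each sentence to the -output file.
# Example invocation (file names are illustrative; the full flag list, including
# the model checkpoint path, is defined in opts.translate_opts):
#   python translate.py -src test.txt -output pred.txt -gpu 0 -n_best 5 -verbose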
from __future__ import division
from builtins import bytes
import os
import argparse
import math
import codecs

import torch

import onmt
import onmt.IO
import opts
from itertools import takewhile, count

try:
    from itertools import zip_longest
except ImportError:
    from itertools import izip_longest as zip_longest
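
# The available translation flags (source/target files, output path, gpu,
# n_best, verbosity, beam dump, ...) are defined centrally in opts.translate_opts.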
parser = argparse.ArgumentParser(description='translate.py')
opts.add_md_help_argument(parser)
opts.translate_opts(parser)

opt = parser.parse_args()
if opt.batch_size != 1:
    print("WARNING: -batch_size isn't supported currently, "
          "we set it to 1 for now!")
    opt.batch_size = 1


def report_score(name, score_total, words_total):
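    """Print the per-word average model score and the corresponding perplexity."""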
    print("%s AVG SCORE: %.4f, %s PPL: %.4f" % (
        name, score_total / words_total,
        name, math.exp(-score_total / words_total)))


def get_src_words(src_indices, index2str):
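    """Map source token indices back to a space-joined string, stopping at padding."""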
    raw_words = (index2str[i] for i in src_indices)
    words = takewhile(lambda w: w != onmt.IO.PAD_WORD, raw_words)
    return " ".join(words)


def main():
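    """Translate opt.src with the loaded model and write n-best hypotheses to opt.output."""
    # Re-parse the model-side options (opts.model_opts) with their defaults so
    # they can be handed to the Translator alongside the translation options.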
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]
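
    # A non-negative -gpu id selects that CUDA device; a negative value runs on CPU.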
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)

    print(opt, dummy_opt)
    translator = onmt.Translator(opt, dummy_opt.__dict__)
    out_file = codecs.open(opt.output, 'w', 'utf-8')
    pred_score_total, pred_words_total = 0, 0
    gold_score_total, gold_words_total = 0, 0
    if opt.dump_beam != "":
        import json
        translator.initBeamAccum()
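
    # Build the dataset from the source (and optional target) file using the
    # translator's vocabulary fields, then iterate over it in order.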
    data = onmt.IO.ONMTDataset(opt.src, opt.tgt, translator.fields, None)

    test_data = onmt.IO.OrderedIterator(
        dataset=data, device=opt.gpu,
        batch_size=opt.batch_size, train=False, sort=False,
        shuffle=False)

    counter = count(1)
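
    # Translate batch by batch, accumulating corpus-level prediction scores
    # (and gold scores when a target file is provided).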
    for batch in test_data:
        pred_batch, gold_batch, pred_scores, gold_scores, attn, src \
            = translator.translate(batch, data)
        pred_score_total += sum(score[0] for score in pred_scores)
        pred_words_total += sum(len(x[0]) for x in pred_batch)
        if opt.tgt:
            gold_score_total += sum(gold_scores)
            gold_words_total += sum(len(x) for x in batch.tgt[1:])

        # z_batch: an iterator over the predictions, their scores,
        # the gold sentence, its score, and the source sentence for each
        # sentence in the batch. It has to be zip_longest instead of
        # plain-old zip because the gold_batch has length 0 if the target
        # is not included.
        z_batch = zip_longest(
            pred_batch, gold_batch,
            pred_scores, gold_scores,
            (sent.squeeze(1) for sent in src.split(1, dim=1)))
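
        # Write the n-best hypotheses for each sentence to the output file and,
        # with -verbose, echo the source, prediction and gold details to stdout.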
        for pred_sents, gold_sent, pred_score, gold_score, src_sent in z_batch:
            n_best_preds = [" ".join(pred) for pred in pred_sents[:opt.n_best]]
            out_file.write('\n'.join(n_best_preds))
            out_file.write('\n')
            out_file.flush()

            if opt.verbose:
                sent_number = next(counter)
                words = get_src_words(
                    src_sent, translator.fields["src"].vocab.itos)

                os.write(1, bytes('\nSENT %d: %s\n' %
                                  (sent_number, words), 'UTF-8'))

                best_pred = n_best_preds[0]
                best_score = pred_score[0]
                os.write(1, bytes('PRED %d: %s\n' %
                                  (sent_number, best_pred), 'UTF-8'))
                print("PRED SCORE: %.4f" % best_score)

                if opt.tgt:
                    tgt_sent = ' '.join(gold_sent)
                    os.write(1, bytes('GOLD %d: %s\n' %
                                      (sent_number, tgt_sent), 'UTF-8'))
                    print("GOLD SCORE: %.4f" % gold_score)

                if len(n_best_preds) > 1:
                    print('\nBEST HYP:')
                    for score, sent in zip(pred_score, n_best_preds):
                        os.write(1, bytes("[%.4f] %s\n" % (score, sent),
                                          'UTF-8'))
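
    # Corpus-level score summary (and optional beam search dump).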
    report_score('PRED', pred_score_total, pred_words_total)
    if opt.tgt:
        report_score('GOLD', gold_score_total, gold_words_total)

    if opt.dump_beam:
        json.dump(translator.beam_accum,
                  codecs.open(opt.dump_beam, 'w', 'utf-8'))


if __name__ == "__main__":
    main()