translation_test.py
# Beam Search Module
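# Example invocation (illustrative only; the flag values below are assumptions,
# and a matching checkpoint must already exist under checkpoints/):
#   python translation_test.py --gpu 0 --N 6 --dataset multi30k --checkpoint 10 --print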
import argparse
import os

import numpy as np
import torch as th
from dataset import *
from modules import *
from tqdm import tqdm

k = 5  # Beam size

if __name__ == "__main__":
    argparser = argparse.ArgumentParser("testing translation model")
    argparser.add_argument("--gpu", default=-1, type=int, help="gpu id")
    argparser.add_argument("--N", default=6, type=int, help="num of layers")
    argparser.add_argument("--dataset", default="multi30k", help="dataset")
    argparser.add_argument("--batch", default=64, type=int, help="batch size")
    argparser.add_argument(
        "--universal", action="store_true", help="use universal transformer"
    )
    argparser.add_argument(
        "--checkpoint", type=int, help="checkpoint: you must specify it"
    )
    argparser.add_argument(
        "--print", action="store_true", help="whether to print translated text"
    )
    args = argparser.parse_args()
    # Arguments that do not affect the checkpoint name.
    args_filter = ["batch", "gpu", "print"]
    exp_setting = "-".join(
        "{}".format(v) for k, v in vars(args).items() if k not in args_filter
    )
    device = "cpu" if args.gpu == -1 else "cuda:{}".format(args.gpu)

    dataset = get_dataset(args.dataset)
    V = dataset.vocab_size
    dim_model = 512

    # Output files: model hypotheses and reference translations.
    fpred = open("pred.txt", "w")
    fref = open("ref.txt", "w")

    graph_pool = GraphPool()
    model = make_model(V, V, N=args.N, dim_model=dim_model)
    with open("checkpoints/{}.pkl".format(exp_setting), "rb") as f:
        model.load_state_dict(
            th.load(f, map_location=lambda storage, loc: storage)
        )
    model = model.to(device)
    model.eval()

    test_iter = dataset(
        graph_pool, mode="test", batch_size=args.batch, device=device, k=k
    )
    for i, g in enumerate(test_iter):
        with th.no_grad():
            # Beam-search decoding with length penalty alpha=0.6.
            output = model.infer(
                g, dataset.MAX_LENGTH, dataset.eos_id, k, alpha=0.6
            )
        for line in dataset.get_sequence(output):
            if args.print:
                print(line)
            print(line, file=fpred)

    # Write reference translations, then score the predictions with the BLEU script.
    for line in dataset.tgt["test"]:
        print(line.strip(), file=fref)

    fpred.close()
    fref.close()
    os.system(r"bash scripts/bleu.sh pred.txt ref.txt")
    os.remove("pred.txt")
    os.remove("ref.txt")