-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate.py
48 lines (36 loc) · 1.03 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import fitlog
fitlog.debug()
from config import C
from model.generate import generate as model_generate
import torch as tc
from dataloader import chinese_tokenizer
import pdb
import pickle
# Load the pickled model object from the path configured in C.model_save.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point C.model_save at checkpoints you trust.
with open(C.model_save, "rb") as fil:
    model = pickle.load(fil)

# If the pickled object is a DataParallel wrapper, keep only the wrapped
# network (exposed as `.module`) so it can run on a single device.
if isinstance(model, tc.nn.DataParallel):
    model = model.module

# Put the module in evaluation mode, then move it to the first configured GPU.
model.eval()
model = model.cuda(C.gpus[0])
def generate_from_sents(model, sent, return_index=False):
    """Generate text from the prompt string ``sent`` using ``model``.

    Spaces, tabs and newlines are deleted from ``sent``. The remaining text is
    split by ``chinese_tokenizer`` and each token is mapped to an index via
    ``model.vocab.to_index``. The resulting ids are moved to GPU
    ``C.gpus[0]`` as a LongTensor and passed to ``model_generate``.

    Args:
        model: the loaded model; must expose ``vocab`` with ``to_index`` and
            ``to_word`` and be accepted by ``model_generate``.
        sent: raw prompt text.
        return_index: when True, return the raw output of ``model_generate``
            (presumably token ids -- confirm against ``model.generate``)
            instead of converting it back to text.

    Returns:
        The raw ``model_generate`` output if ``return_index`` is True,
        otherwise each produced item mapped through ``model.vocab.to_word``
        and concatenated into one string.
    """
    # One C-level pass deleting " ", "\t" and "\n" -- same result as chaining
    # three .replace() calls.
    cleaned = sent.translate(str.maketrans("", "", " \t\n"))
    token_ids = [model.vocab.to_index(tok) for tok in chinese_tokenizer(cleaned)]
    prompt = tc.LongTensor(token_ids).cuda(C.gpus[0])

    generated = model_generate(model, prompt)
    if return_index:
        return generated
    return "".join(model.vocab.to_word(idx) for idx in generated)
if __name__ == "__main__":
    if C.gene_input:
        # One-shot mode: generate from the prompt given in the config.
        print(generate_from_sents(model, C.gene_input))
    else:
        # Interactive mode: read prompts until the user types "q" or stdin is
        # closed. Example prompt: 晚霞,山川,云
        while True:
            try:
                sent = input(">>")
            except EOFError:
                # input() raises EOFError at end of input (Ctrl-D or the end
                # of piped stdin); leave the loop instead of crashing with a
                # traceback. The newline keeps the shell prompt tidy.
                print()
                break
            if sent == "q":
                break
            print()
            print(generate_from_sents(model, sent))
            print()