-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
51 lines (39 loc) · 1.21 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#J Doran 2019
import json
import nltk
import nltk.lm
nltk.download('stopwords')
nltk.download('punkt')
from nltk.lm.preprocessing import padded_everygram_pipeline
from nltk.tokenize import TweetTokenizer
from nltk.lm import MLE
# Load a Facebook Messenger JSON export and train an n-gram language
# model (MLE) on every message sent by one chosen participant.
with open(input("Location of message.json: ")) as fp:
    # Named `data`, not `json`, so the json module is not shadowed.
    data = json.load(fp)
messages = data["messages"]
ps = [p["name"] for p in data["participants"]]
name = input("Which participant (" + " or ".join(ps) + "): ")
n = int(input("n-gram size (probably 2 or 3): "))
maxLength = 20  # int(input("Maximum sentence length: "))
tkn = TweetTokenizer()
# One tokenized token-list per message from the chosen sender.
# Messages without a "content" key (photos, stickers, calls) are
# skipped instead of raising KeyError.
corpus = [tkn.tokenize(m["content"])
          for m in messages
          if m.get("sender_name") == name and "content" in m]
# Pad each sentence and build the everygram training stream + vocab.
train, vocab = padded_everygram_pipeline(n, corpus)
lm = MLE(n)
lm.fit(train, vocabulary_text=vocab)
def clean(wordList):
    """Join generated tokens into a display string.

    Skips the start-of-sentence marker '<s>', stops at the first
    end-of-sentence marker '</s>', and prefixes every kept token
    with a single space (so the result starts with a space when
    any token is kept).
    """
    pieces = []
    for token in wordList:
        if token == '</s>':
            # End of sentence: discard everything after it.
            break
        if token != '<s>':
            pieces.append(' ')
            pieces.append(token)
    return ''.join(pieces)
print("Finished training! ('exit' to quit)")
previous = ""
while (previous != "exit"):
previous = input(clean(lm.generate(maxLength, text_seed=['<s>'])))