embedding.py
import operator
import string

import nltk
import numpy as np

from utils import docstream  # project-local corpus iterator


class EmbeddingModel(object):
    """
    Base class for models that define word embeddings.

    Subclasses are expected to supply ``vocab`` (list of words),
    ``word_indices`` (word -> index), ``word_vecs`` (embedding matrix)
    and ``win_size`` (context window radius), which the methods below use.
    """
    # Requires the nltk 'stopwords' and 'punkt' data packages.
    stopwords = nltk.corpus.stopwords.words('english')
    tokenizer = nltk.load('tokenizers/punkt/english.pickle')
    min_len = 4  # minimum sentence length, in words

    @staticmethod
    def softmax(z):
        # Normalize scores into a distribution over the vocabulary axis.
        return np.exp(z) / np.sum(np.exp(z), axis=0)

    @staticmethod
    def sigmoid(z):
        return 1.0 / (1 + np.exp(-z))

    def get_sents(self, doc):
        # Split a document into cleaned, lowercased sentences restricted
        # to the model's vocabulary.
        strip_table = str.maketrans('', '', string.punctuation + string.digits)
        sen_list = self.tokenizer.tokenize(doc)
        sen_list = [s.replace('\n', ' ') for s in sen_list]
        sen_list = [s.translate(strip_table) for s in sen_list]
        sen_list = [s for s in sen_list if len(s.split()) >= self.min_len]
        sen_list = [[w.lower() for w in s.split()] for s in sen_list]
        sen_list = [[w for w in s if w in self.vocab] for s in sen_list]
        return sen_list

    def get_context(self, pos, sen):
        # Collect up to win_size words on each side of position ``pos``.
        context = []
        for i in range(self.win_size):
            if pos + i + 1 < len(sen):
                context.append(sen[pos + i + 1])
            if pos - i - 1 >= 0:
                context.append(sen[pos - i - 1])
        return list(set(context))

    def get_binvec(self, context):
        # Bag-of-words indicator vector for a set of context words.
        binvec = np.zeros(len(self.vocab))
        for word in context:
            binvec += self.get_onehot(word)
        return binvec

    def get_onehot(self, word):
        onehot = np.zeros(len(self.vocab))
        onehot[self.word_indices[word]] = 1
        return onehot

    def get_synonyms(self, word):
        # Rank the whole vocabulary by dot-product similarity to ``word``.
        probe = self.word_vecs[self.word_indices[word], :]
        self.rank_words(np.dot(self.word_vecs, probe))

    def indices_to_words(self, v):
        # Map the nonzero entries of an indicator vector back to words.
        indices = np.where(v != 0)[0]
        words = []
        for key, val in self.word_indices.items():
            if val in indices:
                words.append(key)
        return words

    def rank_words(self, scores):
        # Print the five highest-scoring words with their scores.
        rank = sorted(zip(range(len(self.vocab)), scores),
                      key=operator.itemgetter(1), reverse=True)
        top_words = [(self.vocab[x[0]], x[1]) for x in rank[:10]]
        print('')
        for word in top_words[:5]:
            print(word[0], word[1])

    def data(self, size, model):
        """Yield one (xs, ys) matrix pair per sentence for up to ``size``
        documents. With model='cbow', each column pairs a bag-of-context
        input with a one-hot target word; 'skipgram' swaps the roles."""
        counter = 0
        for doclist in docstream():
            for doc in doclist:
                if counter >= size:
                    return
                counter += 1
                for sen in self.get_sents(doc):
                    sen = [w for w in sen if w not in self.stopwords]
                    if len(sen) < 4:  # too short after stopword removal
                        continue
                    xs = np.zeros((len(self.vocab), len(sen)))
                    ys = np.zeros((len(self.vocab), len(sen)))
                    if model == 'cbow':
                        for i in range(len(sen)):
                            context = self.get_context(i, sen)
                            xs[:, i] = self.get_binvec(context)
                            ys[:, i] = self.get_onehot(sen[i])
                        yield xs, ys
                    elif model == 'skipgram':
                        for i in range(len(sen)):
                            context = self.get_context(i, sen)
                            xs[:, i] = self.get_onehot(sen[i])
                            ys[:, i] = self.get_binvec(context)
                        yield xs, ys

    def ns_data(self, size, model):
        """Like ``data``, but yield one (x, y) vector pair per word
        position, suited to per-example training such as negative
        sampling."""
        counter = 0
        for doclist in docstream():
            for doc in doclist:
                if counter >= size:
                    return
                counter += 1
                for sen in self.get_sents(doc):
                    sen = [w for w in sen if w not in self.stopwords]
                    if len(sen) < 4:  # too short after stopword removal
                        continue
                    if model == 'cbow':
                        for i in range(len(sen)):
                            context = self.get_context(i, sen)
                            x = self.get_binvec(context)
                            y = self.get_onehot(sen[i])
                            yield x, y
                    elif model == 'skipgram':
                        for i in range(len(sen)):
                            context = self.get_context(i, sen)
                            x = self.get_onehot(sen[i])
                            y = self.get_binvec(context)
                            yield x, y
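

# A minimal usage sketch, assuming a subclass supplies the attributes the
# base class expects. ``ToyModel``, its vocabulary, and the random
# (untrained) ``word_vecs`` below are hypothetical stand-ins; a real
# subclass would build them from the corpus served by ``utils.docstream``.
if __name__ == '__main__':
    class ToyModel(EmbeddingModel):
        def __init__(self, vocab, dim=16, win_size=2):
            self.vocab = vocab
            self.word_indices = {w: i for i, w in enumerate(vocab)}
            self.word_vecs = np.random.randn(len(vocab), dim)
            self.win_size = win_size

    model = ToyModel(['king', 'queen', 'royal', 'crown', 'throne'])

    # Encode position 0 of a toy sentence the way data()/ns_data() would.
    sen = ['king', 'crown', 'throne', 'queen']
    context = model.get_context(0, sen)   # words within win_size of 'king'
    x = model.get_binvec(context)         # CBOW input: bag of context words
    y = model.get_onehot(sen[0])          # CBOW target: one-hot 'king'
    print(model.indices_to_words(x), '->', model.indices_to_words(y))

    # With trained vectors this would print the five nearest neighbours.
    model.get_synonyms('king')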