This repository has been archived by the owner on Jul 21, 2020. It is now read-only.
forked from yandexdataschool/Practical_RL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbasic_model_tf.py
240 lines (196 loc) · 9.43 KB
/
basic_model_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import tensorflow as tf
import keras.layers as L
# This code implements a single-GRU seq2seq model. You will have to improve it later in the assignment.
# Note 1: when using several recurrent layers TF can mixed up the weights of different recurrent layers.
# In that case, make sure you both create AND use each rnn/gru/lstm/custom layer in a unique variable scope
# e.g. with tf.variable_scope("first_lstm"): new_cell, new_out = self.lstm_1(...)
# with tf.variable_scope("second_lstm"): new_cell2, new_out2 = self.lstm_2(...)
# Note 2: everything you need for decoding should be stored in model state (output list of both encode and decode)
# e.g. for attention, you should store all encoder sequence and input mask
# there in addition to lstm/gru states.
class BasicTranslationModel:
def __init__(self, name, inp_voc, out_voc,
emb_size, hid_size,):
self.name = name
self.inp_voc = inp_voc
self.out_voc = out_voc
with tf.variable_scope(name):
self.emb_inp = L.Embedding(len(inp_voc), emb_size)
self.emb_out = L.Embedding(len(out_voc), emb_size)
self.enc0 = tf.nn.rnn_cell.GRUCell(hid_size)
self.dec_start = L.Dense(hid_size)
self.dec0 = tf.nn.rnn_cell.GRUCell(hid_size)
self.logits = L.Dense(len(out_voc))
# run on dummy output to .build all layers (and therefore create
# weights)
inp = tf.placeholder('int32', [None, None])
out = tf.placeholder('int32', [None, None])
h0 = self.encode(inp)
h1 = self.decode(h0, out[:, 0])
# h2 = self.decode(h1,out[:,1]) etc.
self.weights = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope=name)
def encode(self, inp, **flags):
"""
Takes symbolic input sequence, computes initial state
:param inp: matrix of input tokens [batch, time]
:return: a list of initial decoder state tensors
"""
inp_lengths = infer_length(inp, self.inp_voc.eos_ix)
inp_emb = self.emb_inp(inp)
_, enc_last = tf.nn.dynamic_rnn(
self.enc0, inp_emb,
sequence_length=inp_lengths,
dtype=inp_emb.dtype)
dec_start = self.dec_start(enc_last)
return [dec_start]
def decode(self, prev_state, prev_tokens, **flags):
"""
Takes previous decoder state and tokens, returns new state and logits
:param prev_state: a list of previous decoder state tensors
:param prev_tokens: previous output tokens, an int vector of [batch_size]
:return: a list of next decoder state tensors, a tensor of logits [batch,n_tokens]
"""
[prev_dec] = prev_state
prev_emb = self.emb_out(prev_tokens[:, None])[:, 0]
new_dec_out, new_dec_state = self.dec0(prev_emb, prev_dec)
output_logits = self.logits(new_dec_out)
return [new_dec_state], output_logits
def symbolic_score(self, inp, out, eps=1e-30, **flags):
"""
Takes symbolic int32 matrices of hebrew words and their english translations.
Computes the log-probabilities of all possible english characters given english prefices and hebrew word.
:param inp: input sequence, int32 matrix of shape [batch,time]
:param out: output sequence, int32 matrix of shape [batch,time]
:return: log-probabilities of all possible english characters of shape [bath,time,n_tokens]
NOTE: log-probabilities time axis is synchronized with out
In other words, logp are probabilities of __current__ output at each tick, not the next one
therefore you can get likelihood as logprobas * tf.one_hot(out,n_tokens)
"""
first_state = self.encode(inp, **flags)
batch_size = tf.shape(inp)[0]
bos = tf.fill([batch_size], self.out_voc.bos_ix)
first_logits = tf.log(tf.one_hot(bos, len(self.out_voc)) + eps)
def step(blob, y_prev):
h_prev = blob[:-1]
h_new, logits = self.decode(h_prev, y_prev, **flags)
return list(h_new) + [logits]
results = tf.scan(step, initializer=list(first_state) + [first_logits],
elems=tf.transpose(out))
# gather state and logits, each of shape [time,batch,...]
states_seq, logits_seq = results[:-1], results[-1]
# add initial state and logits
logits_seq = tf.concat((first_logits[None], logits_seq), axis=0)
# convert from [time,batch,...] to [batch,time,...]
logits_seq = tf.transpose(logits_seq, [1, 0, 2])
return tf.nn.log_softmax(logits_seq)
def symbolic_translate(
self,
inp,
greedy=False,
max_len=None,
eps=1e-30,
**flags):
"""
takes symbolic int32 matrix of hebrew words, produces output tokens sampled
from the model and output log-probabilities for all possible tokens at each tick.
:param inp: input sequence, int32 matrix of shape [batch,time]
:param greedy: if greedy, takes token with highest probablity at each tick.
Otherwise samples proportionally to probability.
:param max_len: max length of output, defaults to 2 * input length
:return: output tokens int32[batch,time] and
log-probabilities of all tokens at each tick, [batch,time,n_tokens]
"""
first_state = self.encode(inp, **flags)
batch_size = tf.shape(inp)[0]
bos = tf.fill([batch_size], self.out_voc.bos_ix)
first_logits = tf.log(tf.one_hot(bos, len(self.out_voc)) + eps)
max_len = tf.reduce_max(tf.shape(inp)[1]) * 2
def step(blob, t):
h_prev, y_prev = blob[:-2], blob[-1]
h_new, logits = self.decode(h_prev, y_prev, **flags)
y_new = (
tf.argmax(logits, axis=-1) if greedy
else tf.multinomial(logits, 1)[:, 0]
)
return list(h_new) + [logits, tf.cast(y_new, y_prev.dtype)]
results = tf.scan(
step,
initializer=list(first_state) + [first_logits, bos],
elems=[tf.range(max_len)],
)
# gather state, logits and outs, each of shape [time,batch,...]
states_seq, logits_seq, out_seq = (
results[:-2], results[-2], results[-1]
)
# add initial state, logits and out
logits_seq = tf.concat((first_logits[None], logits_seq), axis=0)
out_seq = tf.concat((bos[None], out_seq), axis=0)
states_seq = [
tf.concat((init[None], states), axis=0)
for init, states in zip(first_state, states_seq)
]
# convert from [time,batch,...] to [batch,time,...]
logits_seq = tf.transpose(logits_seq, [1, 0, 2])
out_seq = tf.transpose(out_seq)
states_seq = [
tf.transpose(states, [1, 0] + list(range(2, states.shape.ndims)))
for states in states_seq
]
return out_seq, tf.nn.log_softmax(logits_seq)
### Utility functions ###
def initialize_uninitialized(sess=None):
"""
Initialize unitialized variables, doesn't affect those already initialized
:param sess: in which session to initialize stuff. Defaults to tf.get_default_session()
"""
sess = sess or tf.get_default_session()
global_vars = tf.global_variables()
is_not_initialized = sess.run(
[tf.is_variable_initialized(var) for var in global_vars]
)
not_initialized_vars = [
v for (v, f)
in zip(global_vars, is_not_initialized)
if not f
]
if len(not_initialized_vars):
sess.run(tf.variables_initializer(not_initialized_vars))
def infer_length(seq, eos_ix, time_major=False, dtype=tf.int32):
"""
compute length given output indices and eos code
:param seq: tf matrix [time,batch] if time_major else [batch,time]
:param eos_ix: integer index of end-of-sentence token
:returns: lengths, int32 vector of shape [batch]
"""
axis = 0 if time_major else 1
is_eos = tf.cast(tf.equal(seq, eos_ix), dtype)
count_eos = tf.cumsum(is_eos, axis=axis, exclusive=True)
lengths = tf.reduce_sum(tf.cast(tf.equal(count_eos, 0), dtype), axis=axis)
return lengths
def infer_mask(seq, eos_ix, time_major=False, dtype=tf.float32):
"""
compute mask given output indices and eos code
:param seq: tf matrix [time,batch] if time_major else [batch,time]
:param eos_ix: integer index of end-of-sentence token
:returns: mask, float32 matrix with '0's and '1's of same shape as seq
"""
axis = 0 if time_major else 1
lengths = infer_length(seq, eos_ix, time_major=time_major)
mask = tf.sequence_mask(lengths, maxlen=tf.shape(seq)[axis], dtype=dtype)
if time_major:
mask = tf.transpose(mask)
return mask
def select_values_over_last_axis(values, indices):
"""
Auxiliary function to select logits corresponding to chosen tokens.
:param values: logits for all actions: float32[batch,tick,action]
:param indices: action ids int32[batch,tick]
:returns: values selected for the given actions: float[batch,tick]
"""
assert values.shape.ndims == 3 and indices.shape.ndims == 2
batch_size, seq_len = tf.shape(indices)[0], tf.shape(indices)[1]
batch_i = tf.tile(tf.range(0, batch_size)[:, None], [1, seq_len])
time_i = tf.tile(tf.range(0, seq_len)[None, :], [batch_size, 1])
indices_nd = tf.stack([batch_i, time_i, indices], axis=-1)
return tf.gather_nd(values, indices_nd)