support RL training #1

Merged 4 commits on Aug 21, 2017

Changes from all commits
136 changes: 117 additions & 19 deletions build_model.py
@@ -7,9 +7,12 @@
#
import json

import numpy
import tensorflow as tf

import sequencing as sq
from sequencing import TIME_MAJOR, MODE
from sequencing.utils.metrics import Delta_BLEU


def optimistic_restore(session, save_file):
@@ -41,22 +44,93 @@ def optimistic_restore(session, save_file):


def cross_entropy_sequence_loss(logits, targets, sequence_length):
with tf.name_scope('cross_entropy_sequence_loss'):
total_length = tf.to_float(tf.reduce_sum(sequence_length))

entropy_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=targets)

# Mask out the losses we don't care about
loss_mask = tf.sequence_mask(
tf.to_int32(sequence_length), tf.to_int32(tf.shape(targets)[0]))
loss_mask = tf.transpose(tf.to_float(loss_mask), [1, 0])

losses = entropy_losses * loss_mask
# losses.shape: T * B
# sequence_length: B
total_loss_avg = tf.reduce_sum(losses) / total_length

return total_loss_avg


def rl_sequence_loss(logits, targets, sequence_length, baseline_states, reward):
# reward: T * B
with tf.name_scope('rl_sequence_loss'):
total_length = tf.to_float(tf.reduce_sum(sequence_length))

entropy_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=targets)

# Mask out the losses we don't care about
loss_mask = tf.sequence_mask(
tf.to_int32(sequence_length), tf.to_int32(tf.shape(targets)[0]))
loss_mask = tf.transpose(tf.to_float(loss_mask), [1, 0])

reward_predicted = tf.contrib.layers.fully_connected(baseline_states, 1,
activation_fn=None)
reward_predicted = tf.squeeze(reward_predicted)

reward_losses = tf.pow(reward_predicted - reward, 2)

reward_loss_rmse = tf.sqrt(tf.reduce_sum(reward_losses * loss_mask) /
total_length)

reward_entropy_losses = (reward - tf.stop_gradient(reward_predicted)) \
* entropy_losses * loss_mask

# Average the reward-weighted loss over all tokens and add the baseline error
total_loss_avg = tf.reduce_sum(
reward_entropy_losses) / total_length + reward_loss_rmse

# the baseline's prediction at the first step estimates the total reward
return total_loss_avg, \
tf.reduce_sum(entropy_losses * loss_mask) / total_length, \
tf.reduce_mean(tf.slice(reward_predicted, [0, 0], [1, -1]))
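
For reference, rl_sequence_loss is a REINFORCE-style objective with a learned baseline: each sampled token's cross entropy is weighted by its advantage (reward minus the predicted baseline, with stop_gradient keeping the baseline out of the policy gradient), and the baseline itself is fit with an RMSE term on the observed reward. A minimal numpy sketch of the same arithmetic on toy T x B arrays (illustration only, not part of this diff):

import numpy as np

entropy_losses = np.array([[1.0, 0.5], [0.8, 0.4], [0.2, 0.1]])      # per-token CE, T x B
reward = np.array([[0.3, 0.6], [0.2, 0.1], [0.1, 0.0]])              # reward-to-go, T x B
reward_predicted = np.array([[0.25, 0.5], [0.15, 0.2], [0.1, 0.1]])  # baseline output
loss_mask = np.array([[1., 1.], [1., 1.], [1., 0.]])                 # sequence lengths [3, 2]
total_length = loss_mask.sum()

# Policy term: advantage-weighted cross entropy (baseline treated as a constant).
policy_loss = ((reward - reward_predicted) * entropy_losses * loss_mask).sum() / total_length
# Baseline term: RMSE between predicted and observed reward.
baseline_loss = np.sqrt(((reward_predicted - reward) ** 2 * loss_mask).sum() / total_length)
total_loss = policy_loss + baseline_loss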


def _py_func(predict_target_ids, ground_truth_ids, eos_id):
n = 4 # 4-gram
delta = True # delta future reward
batch_size = predict_target_ids.shape[1]
length = numpy.zeros(batch_size, dtype=numpy.int32)
reward = numpy.zeros_like(predict_target_ids, dtype=numpy.float32)

for i in range(batch_size):
p_id = predict_target_ids[:, i].tolist()
p_len = p_id.index(eos_id) + 1 if eos_id in p_id else len(p_id)
length[i] = p_len
p_id = p_id[:p_len]

t_id = ground_truth_ids[:, i].tolist()
t_len = t_id.index(eos_id) + 1 if eos_id in t_id else len(t_id)
t_id = t_id[:t_len]

bleu_scores = Delta_BLEU(p_id, t_id, n)
reward_i = bleu_scores[:, n - 1].copy()

if delta:
reward_i[1:] = reward_i[1:] - reward_i[:-1]
reward[:p_len, i] = reward_i[::-1].cumsum()[::-1]
else:
reward[:p_len, i] = reward_i[-1]

return reward, length
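
A small numpy illustration of the delta reward computed above, with made-up BLEU numbers (Delta_BLEU is assumed to return one n-gram BLEU score per generated prefix): taking per-step differences and then a reversed cumulative sum gives, at every position, the BLEU still to be earned from that token onward, so the entry at step 0 equals the final sentence-level BLEU.

import numpy

# Hypothetical per-prefix BLEU scores, one per generated token.
reward_i = numpy.array([0.10, 0.30, 0.25, 0.40], dtype=numpy.float32)

# Per-step change in BLEU contributed by each token.
delta = reward_i.copy()
delta[1:] = reward_i[1:] - reward_i[:-1]      # [0.10, 0.20, -0.05, 0.15]

# Reward-to-go: sum of the future deltas from each position onward.
reward_to_go = delta[::-1].cumsum()[::-1]     # [0.40, 0.30, 0.10, 0.15]

This is also why rl_sequence_loss reports the baseline prediction at the first step as the estimated total reward.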


def build_attention_model(params, src_vocab, trg_vocab, source_ids,
source_seq_length, target_ids, target_seq_length,
beam_size=1, mode=MODE.TRAIN,
teacher_rate=1.0, max_step=100):
"""
Build a model.

@@ -83,7 +157,7 @@ def build_attention_model(params, src_vocab, trg_vocab, source_ids,
:param source_ids: placeholder
:param source_seq_length: placeholder
:param target_ids: placeholder
:param target_seq_length: placeholder
:param beam_size: used in beam inference
:param mode:
:return:
@@ -105,14 +179,19 @@ def build_attention_model(params, src_vocab, trg_vocab, source_ids,
source_embedded = source_embedding_table(source_ids)

encoder = sq.StackBidirectionalRNNEncoder(encoder_params, name='stack_rnn',
mode=mode)
encoded_representation = encoder.encode(source_embedded, source_seq_length)
attention_keys = encoded_representation.attention_keys
attention_values = encoded_representation.attention_values
attention_length = encoded_representation.attention_length

# feedback
if mode == MODE.RL:
tf.logging.info('BUILDING RL TRAIN FEEDBACK......')
dynamical_batch_size = tf.shape(attention_keys)[1]
feedback = sq.RLTrainingFeedBack(trg_vocab, dynamical_batch_size,
max_step=max_step)
elif mode == MODE.TRAIN:
tf.logging.info('BUILDING TRAIN FEEDBACK WITH {} TEACHER_RATE'
'......'.format(teacher_rate))
feedback = sq.TrainingFeedBack(target_ids, target_seq_length,
@@ -168,20 +247,39 @@ def build_attention_model(params, src_vocab, trg_vocab, source_ids,
decoder_output, decoder_final_state = sq.dynamic_decode(decoder,
scope='decoder')

# not training
if mode == MODE.EVAL or mode == MODE.INFER:
return decoder_output, decoder_final_state

# bos is added in feedback
# so target_ids line up with the decoder outputs and serve as the ground truth
if not TIME_MAJOR:
ground_truth_ids = tf.transpose(target_ids, [1, 0])
else:
ground_truth_ids = target_ids

# construct the loss
if mode == MODE.RL:
baseline_states = tf.stop_gradient(decoder_output.baseline_states)
predict_ids = tf.stop_gradient(decoder_output.predicted_ids)

reward, sequence_length = tf.py_func(
func=_py_func,
inp=[predict_ids, ground_truth_ids, trg_vocab.eos_id],
Tout=[tf.float32, tf.int32],
name='reward')
sequence_length.set_shape((None,))
total_loss_avg, entropy_loss_avg, reward_predicted = rl_sequence_loss(
logits=decoder_output.logits,
targets=predict_ids,
sequence_length=sequence_length,
baseline_states=baseline_states,
reward=reward)
return decoder_output, total_loss_avg, entropy_loss_avg, reward_predicted
else:

total_loss_avg = cross_entropy_sequence_loss(
logits=decoder_output.logits,
targets=ground_truth_ids,
sequence_length=target_seq_length)
return decoder_output, total_loss_avg, total_loss_avg, tf.to_float(0.)
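
For context, a minimal sketch of how the RL branch's return values might be wired into a training op; the optimizer choice and the surrounding placeholder names are assumptions for illustration, while the learning-rate and clipping fields are the ones defined in config.py below:

# Hypothetical wiring of the RL loss into a clipped-gradient train op.
decoder_output, total_loss_avg, entropy_loss_avg, reward_predicted = \
    build_attention_model(params, src_vocab, trg_vocab,
                          source_ids, source_seq_length,
                          target_ids, target_seq_length,
                          mode=MODE.RL, max_step=config.max_step)

optimizer = tf.train.AdamOptimizer(config.rl_lr_rate)
grads_and_vars = optimizer.compute_gradients(total_loss_avg)
grads, variables = zip(*grads_and_vars)
clipped_grads, _ = tf.clip_by_global_norm(grads, config.rl_clip_gradient_norm)
train_op = optimizer.apply_gradients(zip(clipped_grads, variables))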
100 changes: 100 additions & 0 deletions config.py
@@ -0,0 +1,100 @@
from collections import namedtuple

from build_inputs import build_vocab

TrainingConfigs = namedtuple('TrainingConfigs',
['src_vocab', 'trg_vocab', 'params',
'train_src_file', 'train_trg_file',
'test_src_file', 'test_trg_file',
'beam_size', 'batch_size',
'max_step', 'model_dir',
'lr_rate', 'rl_lr_rate',
'clip_gradient_norm', 'rl_clip_gradient_norm',
'train_steps'])


def get_config(config_name):
configs = [word2pos.__name__, en2zh.__name__]

if config_name in configs:
return eval(config_name)()
else:
raise Exception('Config not found')


def word2pos():
# load vocab
src_vocab = build_vocab('data/vocab.word', 256, ' ')
trg_vocab = build_vocab('data/vocab.tag', 32, ' ')

params = {'encoder': {'rnn_cell': {'state_size': 512,
'cell_name': 'BasicLSTMCell',
'num_layers': 1,
'input_keep_prob': 1.0,
'output_keep_prob': 1.0},
'attention_key_size': 256},
'decoder': {'rnn_cell': {'cell_name': 'BasicLSTMCell',
'state_size': 512,
'num_layers': 1,
'input_keep_prob': 1.0,
'output_keep_prob': 1.0},
'logits': {'input_keep_prob': 1.0}}}

configs = TrainingConfigs(
src_vocab=src_vocab,
trg_vocab=trg_vocab,
params=params,
train_src_file='data/train.word',
train_trg_file='data/train.tag',
test_src_file='data/test.word',
test_trg_file='data/test.tag',
beam_size=1,
batch_size=64,
max_step=100,
model_dir='models',
lr_rate=0.001,
rl_lr_rate=0.0001,
clip_gradient_norm=5.,
rl_clip_gradient_norm=1.,
train_steps=200000)

return configs


def en2zh():
# load vocab
src_vocab = build_vocab('data/vocab.en', 512, ' ')
trg_vocab = build_vocab('data/vocab.zh', 512, '')

params = {'encoder': {'rnn_cell': {'state_size': 1024,
'cell_name': 'BasicLSTMCell',
'num_layers': 1,
'input_keep_prob': 1.0,
'output_keep_prob': 1.0},
'attention_key_size': 512},
'decoder': {'rnn_cell': {'cell_name': 'BasicLSTMCell',
'state_size': 1024,
'num_layers': 1,
'input_keep_prob': 1.0,
'output_keep_prob': 1.0},
'logits': {'input_keep_prob': 1.0}}}

configs = TrainingConfigs(
src_vocab=src_vocab,
trg_vocab=trg_vocab,
params=params,
train_src_file='data/en.tok.shuf.filter',
train_trg_file='data/zh.tok.shuf.filter',
test_src_file='data/test.en',
test_trg_file='data/test.zh',
beam_size=5,
batch_size=128,
max_step=150,
model_dir='models',
lr_rate=0.0005,
rl_lr_rate=0.0001,
clip_gradient_norm=5.,
rl_clip_gradient_norm=1.,
train_steps=200000)

return configs
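
A quick usage sketch for get_config; running it requires the vocab files listed above to exist on disk, and the printed values simply echo the fields defined in word2pos():

# Illustration only: look up a configuration by name and read its fields.
config = get_config('word2pos')

print(config.batch_size)   # 64
print(config.lr_rate)      # 0.001 for cross-entropy training
print(config.rl_lr_rate)   # 0.0001 for the RL phase

# An unknown name raises the Exception defined in get_config:
# get_config('no_such_config')  ->  Exception('Config not found')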