Skip to content

Commit

Permalink
Sampling text is now based under a multinomial distribution instead o…
Browse files Browse the repository at this point in the history
…f random guessing.
  • Loading branch information
gugarosa committed Feb 18, 2019
1 parent 31723b9 commit 0772942
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/neurals/training_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@
print(''.join(pred_text))

# Generating new text
gen_text = rnn.generate_text(dataset=d, start_text=pred_input, length=100)
gen_text = rnn.generate_text(dataset=d, start_text=pred_input, length=100, temperature=0.2)
print(''.join(gen_text))
38 changes: 33 additions & 5 deletions nalp/neurals/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,14 +325,45 @@ def predict(self, input_batch, model_path=None, probability=1):

return predict

def generate_text(self, dataset, start_text='', length=1, model_path=None):
def _sample_from_multinomial(self, probs, temperature):
    """Samples a vocabulary index from a temperature-scaled multinomial distribution.

    Args:
        probs (np.array): An array of probabilities from 'tf.nn.softmax'.
        temperature (float): The amount of diversity to include when sampling.
            Must be greater than zero. Values below 1 sharpen the distribution
            towards the most likely index, values above 1 flatten it.

    Returns:
        The index of sampled character or word.

    Raises:
        ValueError: If 'temperature' is not greater than zero, or if 'probs'
            does not hold at least one strictly positive probability.

    """

    # The log-probabilities are divided by the temperature, so zero, negative
    # or NaN values cannot produce a valid distribution
    if not temperature > 0:
        raise ValueError(
            "'temperature' should be greater than 0, got {}".format(temperature))

    # Converting to float64 to avoid multinomial distribution errors
    probs = np.asarray(probs).astype('float64')

    # Without any positive mass there is nothing to sample from
    if not np.any(probs > 0):
        raise ValueError(
            "'probs' should hold at least one positive probability")

    # Then, we calculate the log of probs and divide by temperature. A zero
    # probability legitimately maps to -inf, hence the muted divide warning
    with np.errstate(divide='ignore'):
        logits = np.log(probs) / temperature

    # Shifting by the maximum keeps the largest exponent at exp(0) = 1, so
    # small temperatures can not underflow every entry to zero (0 / 0 = NaN)
    exp_probs = np.exp(logits - np.max(logits))

    # Finally, we normalize it
    norm_probs = exp_probs / np.sum(exp_probs)

    # Sampling a single one-hot draw from the multinomial distribution
    dist_probs = np.random.multinomial(1, norm_probs, 1)

    # The predicted index will be the argmax of the one-hot draw
    pred_idx = np.argmax(dist_probs)

    return pred_idx

def generate_text(self, dataset, start_text='', length=1, temperature=1.0, model_path=None):
"""Generates a maximum length of new text based on the probability of next char
ocurring.
Args:
dataset (OneHot): A OneHot object.
start_text (str): The initial text for generating new text.
length (int): Maximum amount of generated text.
temperature (float): The amount of diversity to include when sampling.
model_path (str): If needed, will load a different model from the previously trained.
Returns:
Expand Down Expand Up @@ -374,10 +405,7 @@ def generate_text(self, dataset, start_text='', length=1, model_path=None):
predict = sess.run([self.predictor_prob], feed_dict={self.x: seed})

# Chooses a index based on the predictions probability distribution
pred_idx = np.random.choice(
range(dataset.vocab_size),
p=predict[0][-1]
)
pred_idx = self._sample_from_multinomial(predict[0][-1], temperature)

# Removing first indexated token
tokens_idx = np.delete(tokens_idx, 0, 0)
Expand Down

0 comments on commit 0772942

Please sign in to comment.