Adding a method to use validation datasets while training.
gugarosa committed Aug 20, 2019
1 parent 75d1008 commit de8b68b
Showing 5 changed files with 72 additions and 28 deletions.
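In short, the new validation argument lets train report held-out loss and accuracy at the end of every epoch. A minimal sketch of the API change, based on the examples below (the val_d dataset is illustrative, not from the commit):

# Before: a single dataset, passed as `dataset`
# rnn.train(dataset=d, batch_size=128, epochs=100)

# After: the training set is passed as `train`, plus an optional validation set
rnn.train(train=d, validation=val_d, batch_size=128, epochs=100)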
2 changes: 1 addition & 1 deletion examples/applications/audio_generation.py
@@ -30,7 +30,7 @@
 rnn = RNN(vocab_size=d.vocab_size, hidden_size=128, learning_rate=0.001)

 # Training the network
-rnn.train(dataset=d, batch_size=128, epochs=5)
+rnn.train(train=d, batch_size=128, epochs=5)

 # Generating new notes
 gen_notes = rnn.generate_text(
2 changes: 1 addition & 1 deletion examples/applications/text_generation.py
@@ -25,7 +25,7 @@
 rnn = RNN(vocab_size=d.vocab_size, hidden_size=64, learning_rate=0.001)

 # Training the network
-rnn.train(dataset=d, batch_size=128, epochs=100)
+rnn.train(train=d, batch_size=128, epochs=100)

 # Predicting using the same input (just for checking what it has learnt)
 preds = rnn.predict(d.X)
6 changes: 1 addition & 5 deletions examples/neurals/train_rnn.py
@@ -6,17 +6,13 @@
 # Loading a text
 sentences = l.load_txt('data/text/chapter1_harry.txt')

-# Defining a predition input
-start_text = 'Mr. Dursley'
-
 # Creates a pre-processing pipeline
 pipe = p.pipeline(
     p.tokenize_to_char
 )

 # Applying pre-processing pipeline to sentences and start token
 tokens = pipe(sentences)
-start_token = pipe(start_text)

 # Creating a OneHot dataset
 d = OneHot(tokens, max_length=10)
@@ -25,4 +21,4 @@
 rnn = RNN(vocab_size=d.vocab_size, hidden_size=64, learning_rate=0.001)

 # Training the network
-rnn.train(dataset=d, batch_size=128, epochs=100)
+rnn.train(train=d, batch_size=128, epochs=100)
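With the new signature, this example could also feed a held-out split. A hedged sketch, not part of the commit: it assumes OneHot accepts any slice of the token list (as it does the full list above), and it glosses over the fact that vocab_size should then come from a shared vocabulary rather than from either split alone.

# Splitting the tokens 80/20 so training and validation data never overlap
split = int(len(tokens) * 0.8)
train_d = OneHot(tokens[:split], max_length=10)
val_d = OneHot(tokens[split:], max_length=10)

# Training the network with per-epoch validation metrics
rnn.train(train=train_d, validation=val_d, batch_size=128, epochs=100)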
71 changes: 56 additions & 15 deletions nalp/core/neural.py
@@ -7,14 +7,14 @@
 class Neural(tf.keras.Model):
     """A Neural class is responsible for holding vital information when defining a
     neural network.

     Note that some methods have to be redefined when using its children.

     """

     def __init__(self):
         """Initialization method.

         Note that basic variables shared by all children should be declared here.

         """
@@ -108,16 +108,17 @@ def step(self, X_batch, Y_batch):
             zip(gradients, self.trainable_variables))

         # Update the loss metric state
-        self.loss_metric.update_state(loss)
+        self.train_loss.update_state(loss)

         # Update the accuracy metric state
-        self.accuracy_metric.update_state(Y_batch, preds)
+        self.train_accuracy.update_state(Y_batch, preds)

-    def train(self, dataset, batch_size=1, epochs=100):
+    def train(self, train, validation=None, batch_size=1, epochs=100):
         """Trains a model.

         Args:
-            dataset (Dataset): A Dataset object containing already encoded data (X, Y).
+            train (Dataset): A training Dataset object containing already encoded data (X, Y).
+            validation (Dataset): A validation Dataset object containing already encoded data (X, Y).
             batch_size (int): The maximum size for each training batch.
             epochs (int): The maximum number of training epochs.
@@ -126,22 +126,62 @@ def train(self, dataset, batch_size=1, epochs=100):
         logger.info(f'Model ready to be trained for: {epochs} epochs.')
         logger.info(f'Batch size: {batch_size}.')

-        # Creating batches to further feed the network
-        batches = dataset.create_batches(dataset.X, dataset.Y, batch_size)
+        # Creating training batches to further feed the network
+        train_batches = train.create_batches(train.X, train.Y, batch_size)
+
+        # Checks if there is a validation set
+        if validation:
+            # Creating validation batches to further feed the network
+            val_batches = validation.create_batches(
+                validation.X, validation.Y, batch_size)

         # Iterate through all epochs
         for epoch in range(epochs):
             # Resetting states to further append losses and accuracies
-            self.loss_metric.reset_states()
-            self.accuracy_metric.reset_states()
+            self.train_loss.reset_states()
+            self.train_accuracy.reset_states()
+            self.val_loss.reset_states()
+            self.val_accuracy.reset_states()

-            # Iterate through all possible batches, depending on the batch size
-            for X_batch, Y_batch in batches:
+            # Iterate through all possible training batches, depending on the batch size
+            for X_train, Y_train in train_batches:
                 # Performs the optimization step
-                self.step(X_batch, Y_batch)
+                self.step(X_train, Y_train)

             logger.debug(
-                f'Epoch: {epoch+1}/{epochs} | Loss: {self.loss_metric.result().numpy():.4f} | Accuracy: {self.accuracy_metric.result().numpy():.4f}')
+                f'Epoch: {epoch+1}/{epochs} | Loss: {self.train_loss.result().numpy():.4f} | Accuracy: {self.train_accuracy.result().numpy():.4f}')
+
+            # Checks if there is a validation set
+            if validation:
+                # Iterate through all possible validation batches, depending on the batch size
+                for X_val, Y_val in val_batches:
+                    # Tests the network
+                    self.test(X_val, Y_val)
+
+                logger.debug(
+                    f'Val Loss: {self.val_loss.result().numpy():.4f} | Val Accuracy: {self.val_accuracy.result().numpy():.4f}\n')
+
+    @tf.function
+    def test(self, X_batch, Y_batch):
+        """Performs a single batch testing.
+
+        Args:
+            X_batch (tf.Tensor): A tensor containing the inputs batch.
+            Y_batch (tf.Tensor): A tensor containing the inputs' labels batch.
+
+        """
+
+        # Calculate the predictions based on inputs
+        preds = self(X_batch)
+
+        # Calculate the loss
+        loss = self.loss(Y_batch, preds)
+
+        # Update the testing loss metric state
+        self.val_loss.update_state(loss)
+
+        # Update the testing accuracy metric state
+        self.val_accuracy.update_state(Y_batch, preds)

     @tf.function
     def predict(self, X):
@@ -159,4 +200,4 @@ def predict(self, X):
         # Performs the forward pass
         preds = self(X)

-        return preds
+        return preds
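The epoch loop above follows the usual TF 2 metric cycle: reset state at the top of the epoch, accumulate batch by batch inside step and test, and read the running value once per epoch. A self-contained illustration of that cycle with a plain tf.metrics.Mean (toy losses, not the library's code):

import tensorflow as tf

loss_metric = tf.metrics.Mean(name='loss')

for epoch in range(2):
    # Start each epoch from a clean slate, as train() does with reset_states()
    loss_metric.reset_states()

    for batch_loss in [0.9, 0.7, 0.5]:
        # Stand-ins for the per-batch losses accumulated by step()/test()
        loss_metric.update_state(batch_loss)

    # result() yields the mean over everything accumulated this epoch (0.7)
    print(f'Epoch {epoch + 1}: {loss_metric.result().numpy():.4f}')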
19 changes: 13 additions & 6 deletions nalp/neurals/rnn.py
@@ -127,15 +127,22 @@ def _build_metrics(self):
         """

-        # Defining accuracy metric
-        self.accuracy_metric = tf.metrics.CategoricalAccuracy(
-            name='accuracy_metric')
+        # Defining training accuracy metric
+        self.train_accuracy = tf.metrics.CategoricalAccuracy(
+            name='train_accuracy')

-        # Defining loss metric
-        self.loss_metric = tf.metrics.Mean(name='loss_metric')
+        # Defining training loss metric
+        self.train_loss = tf.metrics.Mean(name='train_loss')
+
+        # Defining validation accuracy metric
+        self.val_accuracy = tf.metrics.CategoricalAccuracy(
+            name='val_accuracy')
+
+        # Defining validation loss metric
+        self.val_loss = tf.metrics.Mean(name='val_loss')

         logger.debug(
-            f'Accuracy: {self.loss_metric} | Mean Loss: {self.loss_metric}.')
+            f'Train Accuracy: {self.train_accuracy} | Train Loss: {self.train_loss} | Val Accuracy: {self.val_accuracy} | Val Loss: {self.val_loss}.')

     @tf.function
     def call(self, x):
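CategoricalAccuracy compares one-hot targets with predicted distributions by argmax, which is why it pairs naturally with the OneHot datasets used in the examples. A quick illustration with toy tensors (not from the repository):

import tensorflow as tf

acc = tf.metrics.CategoricalAccuracy(name='val_accuracy')

# One-hot labels and predicted class distributions for two samples
y_true = tf.constant([[0., 1.], [1., 0.]])
y_pred = tf.constant([[0.2, 0.8], [0.3, 0.7]])

# First sample is classified correctly, second is not
acc.update_state(y_true, y_pred)
print(acc.result().numpy())  # 0.5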
