
Update text generation example
fchollet committed Nov 11, 2023
1 parent 66bb075 commit 1a480f4
Showing 1 changed file with 7 additions and 7 deletions.
examples/keras_io/generative/text_generation_gpt.py (14 changes: 7 additions & 7 deletions)
@@ -49,18 +49,18 @@

 # Data
 BATCH_SIZE = 64
-SEQ_LEN = 128
-MIN_TRAINING_SEQ_LEN = 450
+MIN_STRING_LEN = 512  # Strings shorter than this will be discarded
+SEQ_LEN = 128  # Length of training sequences, in tokens

 # Model
 EMBED_DIM = 256
-FEED_FORWARD_DIM = 256
+FEED_FORWARD_DIM = 128
 NUM_HEADS = 3
 NUM_LAYERS = 2
 VOCAB_SIZE = 5000  # Limits parameters in model.

 # Training
-EPOCHS = 6
+EPOCHS = 5

 # Inference
 NUM_TOKENS_TO_GENERATE = 80
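For orientation, these constants feed into the example's mini-GPT roughly as sketched below. The model-building code sits outside the hunks shown here, so treat this as an approximate reconstruction using the KerasNLP layers the example relies on, with the constants above assumed in scope, rather than a verbatim excerpt:

import keras
import keras_nlp

inputs = keras.layers.Input(shape=(None,), dtype="int32")
# Token + position embeddings, sized by VOCAB_SIZE, SEQ_LEN and EMBED_DIM.
x = keras_nlp.layers.TokenAndPositionEmbedding(
    vocabulary_size=VOCAB_SIZE,
    sequence_length=SEQ_LEN,
    embedding_dim=EMBED_DIM,
)(inputs)
# NUM_LAYERS causal decoder blocks; FEED_FORWARD_DIM (now 128) is the width
# of each block's feed-forward sublayer, NUM_HEADS the number of attention heads.
for _ in range(NUM_LAYERS):
    x = keras_nlp.layers.TransformerDecoder(
        num_heads=NUM_HEADS,
        intermediate_dim=FEED_FORWARD_DIM,
    )(x)
outputs = keras.layers.Dense(VOCAB_SIZE)(x)  # per-token logits over the vocabulary
model = keras.Model(inputs=inputs, outputs=outputs)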
@@ -82,15 +82,15 @@
 # Load simplebooks-92 train set and filter out short lines.
 raw_train_ds = (
     tf_data.TextLineDataset(dir + "simplebooks-92-raw/train.txt")
-    .filter(lambda x: tf_strings.length(x) > MIN_TRAINING_SEQ_LEN)
+    .filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)
     .batch(BATCH_SIZE)
     .shuffle(buffer_size=256)
 )

 # Load simplebooks-92 validation set and filter out short lines.
 raw_val_ds = (
     tf_data.TextLineDataset(dir + "simplebooks-92-raw/valid.txt")
-    .filter(lambda x: tf_strings.length(x) > MIN_TRAINING_SEQ_LEN)
+    .filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)
     .batch(BATCH_SIZE)
 )
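The rename also matches what the filter actually measures: tf_strings.length runs on the raw line before tokenization and returns its length in bytes, so the pipeline keeps only lines longer than 512 characters, not 512 tokens. A standalone illustration (not part of the example file):

import tensorflow as tf

line = tf.constant("A short line.")
print(tf.strings.length(line).numpy())  # 13 -- byte count, far below MIN_STRING_LEN = 512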

@@ -214,7 +214,7 @@ def preprocess(inputs):
 Now that we have our model, let's train it with the `fit()` method.
 """

-model.fit(train_ds, validation_data=val_ds, verbose=2, epochs=EPOCHS)
+model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

 """
 ## Inference
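Dropping verbose=2 lets fit() fall back to its default verbose="auto", which normally shows the per-batch progress bar. If the one-summary-line-per-epoch output is still wanted, for instance when logging to a file, it can be requested explicitly:

# Same call as in the updated example, but with one line per epoch instead of a progress bar.
model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS, verbose=2)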
