diff --git a/examples/keras_io/generative/text_generation_gpt.py b/examples/keras_io/generative/text_generation_gpt.py
index 1565fe8f3a7..5693d0b1100 100644
--- a/examples/keras_io/generative/text_generation_gpt.py
+++ b/examples/keras_io/generative/text_generation_gpt.py
@@ -49,18 +49,18 @@
 # Data
 BATCH_SIZE = 64
-SEQ_LEN = 128
-MIN_TRAINING_SEQ_LEN = 450
+MIN_STRING_LEN = 512  # Strings shorter than this will be discarded
+SEQ_LEN = 128  # Length of training sequences, in tokens

 # Model
 EMBED_DIM = 256
-FEED_FORWARD_DIM = 256
+FEED_FORWARD_DIM = 128
 NUM_HEADS = 3
 NUM_LAYERS = 2
 VOCAB_SIZE = 5000  # Limits parameters in model.

 # Training
-EPOCHS = 6
+EPOCHS = 5

 # Inference
 NUM_TOKENS_TO_GENERATE = 80
@@ -82,7 +82,7 @@
 # Load simplebooks-92 train set and filter out short lines.
 raw_train_ds = (
     tf_data.TextLineDataset(dir + "simplebooks-92-raw/train.txt")
-    .filter(lambda x: tf_strings.length(x) > MIN_TRAINING_SEQ_LEN)
+    .filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)
     .batch(BATCH_SIZE)
     .shuffle(buffer_size=256)
 )
@@ -90,7 +90,7 @@
 # Load simplebooks-92 validation set and filter out short lines.
 raw_val_ds = (
     tf_data.TextLineDataset(dir + "simplebooks-92-raw/valid.txt")
-    .filter(lambda x: tf_strings.length(x) > MIN_TRAINING_SEQ_LEN)
+    .filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)
     .batch(BATCH_SIZE)
 )
@@ -214,7 +214,7 @@ def preprocess(inputs):
 Now that we have our model, let's train it with the `fit()` method.
 """

-model.fit(train_ds, validation_data=val_ds, verbose=2, epochs=EPOCHS)
+model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

 """
 ## Inference
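
For reference, here is a minimal sketch of the length-based filtering that the renamed `MIN_STRING_LEN` constant drives, assuming TensorFlow is installed. The tiny in-memory dataset stands in for the simplebooks-92 text files and is purely illustrative; only the constant values and the `filter`/`batch`/`shuffle` chain mirror the example above.

```python
import tensorflow as tf

MIN_STRING_LEN = 512  # Strings shorter than this will be discarded
BATCH_SIZE = 64

# Illustrative stand-in for tf_data.TextLineDataset(...): one short line,
# one line long enough to survive the filter.
lines = tf.data.Dataset.from_tensor_slices(["short line", "x" * 600])

filtered = (
    lines
    # tf.strings.length returns the byte length of each string; lines at or
    # below MIN_STRING_LEN are dropped before batching.
    .filter(lambda x: tf.strings.length(x) > MIN_STRING_LEN)
    .batch(BATCH_SIZE)
    .shuffle(buffer_size=256)
)

for batch in filtered:
    print(batch.shape)  # (1,) -- only the long string was kept
```

Shuffling after batching (as in the example) reorders whole batches rather than individual lines, which is cheaper and sufficient here since the raw lines are independent.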