diff --git a/trax/data/tf_inputs.py b/trax/data/tf_inputs.py index 38670d9c4..49ab2fc1a 100644 --- a/trax/data/tf_inputs.py +++ b/trax/data/tf_inputs.py @@ -186,7 +186,7 @@ def append_targets(example): # Skip a random fraction at the beginning of the stream. The skip is # essential for synchronous highly-parallel training to avoid multiple # replicas reading the same data in lock-step. - dataset = dataset.skip(random.randint(0, _MAX_SKIP_EXAMPLES)) + dataset = dataset.skip(random.randint(0, int(_MAX_SKIP_EXAMPLES))) dataset = preprocess_fn(dataset, training) dataset = dataset.shuffle(shuffle_buffer_size) return dataset.prefetch(8)