Merging various fixes for Colab, Cloud TPU, TPU Pod, ... #247

Open · wants to merge 11 commits into master
data_utils.py (10 changes: 6 additions & 4 deletions)
@@ -32,7 +32,6 @@
     "<eop>"  : 8,
 }

-VOCAB_SIZE = 32000
 UNK_ID = special_symbols["<unk>"]
 CLS_ID = special_symbols["<cls>"]
 SEP_ID = special_symbols["<sep>"]
@@ -188,7 +187,7 @@ def create_data(_):
   # Create and dump corpus_info from task 0
   if FLAGS.task == 0:
     corpus_info = {
-        "vocab_size": VOCAB_SIZE,
+        "vocab_size": FLAGS.n_token,
         "bsz_per_host": FLAGS.bsz_per_host,
         "num_core_per_host": FLAGS.num_core_per_host,
         "seq_len": FLAGS.seq_len,
@@ -762,6 +761,8 @@ def parser(record):
 def get_input_fn(
     tfrecord_dir,
     split,
+    task,
+    pass_id,
     bsz_per_host,
     seq_len,
     reuse_len,
@@ -778,7 +779,7 @@

   # Merge all record infos into a single one
   record_glob_base = format_filename(
-      prefix="record_info-{}-*".format(split),
+      prefix="record_info-{}-{}-{}".format(split, task, pass_id),
       bsz_per_host=bsz_per_host,
       seq_len=seq_len,
       bi_data=bi_data,
@@ -883,7 +884,8 @@ def input_fn(params):
 flags.DEFINE_integer("reuse_len", 256,
                      help="Number of token that can be reused as memory. "
                      "Could be half of `seq_len`.")
-flags.DEFINE_bool("uncased", True, help="Use uncased inputs or not.")
+flags.DEFINE_bool("uncased", False, help="Use uncased inputs or not.")
+flags.DEFINE_integer("n_token", 32000, help="Vocab size")
 flags.DEFINE_bool("bi_data", True,
                   help="whether to create bidirectional data")
 flags.DEFINE_integer("mask_alpha", default=6,
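With this change, record-info files are written and read per (task, pass_id) shard instead of being globbed across all workers and passes. A minimal sketch of the new naming scheme, using a pared-down stand-in for `format_filename` (the real function also encodes batch size, sequence length, and other fields):

  def record_info_prefix(split, task, pass_id):
    # Old prefix globbed everything: "record_info-train-*"
    # New prefix pins one shard:     "record_info-train-0-0"
    return "record_info-{}-{}-{}".format(split, task, pass_id)

  print(record_info_prefix("train", 0, 0))  # record_info-train-0-0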
model_utils.py (3 changes: 2 additions & 1 deletion)
@@ -19,7 +19,7 @@ def configure_tpu(FLAGS):
   if FLAGS.use_tpu:
     tpu_cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
         FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
-    master = tpu_cluster.get_master()
+    master = None
   else:
     tpu_cluster = None
     master = FLAGS.master
@@ -42,6 +42,7 @@ def configure_tpu(FLAGS):

   per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
   run_config = tf.contrib.tpu.RunConfig(
+      cluster=tpu_cluster,
       master=master,
       model_dir=FLAGS.model_dir,
       session_config=session_config,
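Passing the resolver through `cluster=` lets `RunConfig` resolve the TPU master itself, which is what makes the same code path work on Colab, a single Cloud TPU, and a TPU pod; `master` is then deliberately left as `None`. A minimal TF 1.x sketch under those assumptions (the TPU name and model directory are placeholders):

  import os
  import tensorflow as tf  # TF 1.x, where tf.contrib exists

  # Resolver for a Colab or Cloud TPU; TPU_NAME is a placeholder.
  tpu_cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
      os.environ.get("TPU_NAME"))

  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster,  # the resolver supplies the master address
      master=None,          # no explicit master once cluster= is set
      model_dir="gs://my-bucket/model",  # placeholder path
      tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))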
modeling.py (2 changes: 1 addition & 1 deletion)
@@ -233,7 +233,7 @@ def relative_positional_encoding(qlen, klen, d_model, clamp_len, attn_type,

   if bsz is not None:
     # With bi_data, the batch size should be divisible by 2.
-    assert bsz%2 == 0
+    tf.debugging.assert_equal(bsz % 2, 0)
     fwd_pos_emb = positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
     bwd_pos_emb = positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
   else:
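The Python `assert` only works when `bsz` is a plain integer; when the batch size arrives as a symbolic Tensor (as it can on TPU), evaluating a Tensor as a Python bool raises a TypeError in graph mode. `tf.debugging.assert_equal` instead adds a graph op that performs the check at run time. A small TF 1.x sketch of the difference:

  import tensorflow as tf

  bsz = tf.placeholder(tf.int32, shape=[])  # symbolic batch size

  # assert bsz % 2 == 0  # would raise: using a tf.Tensor as a bool
  check = tf.debugging.assert_equal(bsz % 2, 0)  # graph-mode check

  with tf.control_dependencies([check]):
    half = bsz // 2  # computed only after the divisibility check runs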
run_classifier.py (2 changes: 1 addition & 1 deletion)
@@ -21,7 +21,7 @@

 import sentencepiece as spm

-from data_utils import SEP_ID, VOCAB_SIZE, CLS_ID
+from data_utils import SEP_ID, CLS_ID
 import model_utils
 import function_builder
 from classifier_utils import PaddingInputExample
run_race.py (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@
 import tensorflow as tf
 import sentencepiece as spm

-from data_utils import SEP_ID, VOCAB_SIZE, CLS_ID
+from data_utils import SEP_ID, CLS_ID
 import model_utils
 import function_builder
 from classifier_utils import PaddingInputExample
run_squad.py (2 changes: 1 addition & 1 deletion)
@@ -28,7 +28,7 @@

 import function_builder
 import model_utils
 import squad_utils
-from data_utils import SEP_ID, CLS_ID, VOCAB_SIZE
+from data_utils import SEP_ID, CLS_ID

 SPIECE_UNDERLINE = u'▁'

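Since `VOCAB_SIZE` is gone as a module constant, the fine-tuning scripts now import only the special-symbol IDs; the vocabulary size travels through `FLAGS.n_token` and the `corpus_info` dict that `create_data` dumps. A hedged sketch of recovering it from that dump; the filename `corpus_info.json` and the helper are assumptions, only the dict keys are confirmed by this diff:

  import json
  import tensorflow as tf

  def load_vocab_size(corpus_info_path):
    # Hypothetical helper: reads the corpus_info dict written by
    # data_utils.create_data and returns its "vocab_size" entry.
    with tf.gfile.Open(corpus_info_path) as f:
      return json.load(f)["vocab_size"]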
train.py (9 changes: 8 additions & 1 deletion)
@@ -140,6 +140,12 @@
 flags.DEFINE_float("init_range", default=0.1,
                    help="Initialization std when init is uniform.")

+# TFRecord Path
+flags.DEFINE_integer("pass_id", 0, help="ID of the current pass. "
+                     "Different passes sample different negative segments.")
+flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
+                     "using multiple workers to identify each worker.")
+
 FLAGS = flags.FLAGS

@@ -226,6 +232,8 @@ def get_input_fn(split):
   input_fn, record_info_dict = data_utils.get_input_fn(
       tfrecord_dir=FLAGS.record_info_dir,
       split=split,
+      task=FLAGS.task,
+      pass_id=FLAGS.pass_id,
       bsz_per_host=batch_size // FLAGS.num_hosts,
       seq_len=FLAGS.seq_len,
       reuse_len=FLAGS.reuse_len,
@@ -251,7 +259,6 @@ def main(unused_argv):
   assert FLAGS.seq_len > 0
   assert FLAGS.perm_size > 0

-  FLAGS.n_token = data_utils.VOCAB_SIZE
   tf.logging.info("n_token {}".format(FLAGS.n_token))

   if not tf.gfile.Exists(FLAGS.model_dir):
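Together, the new flags let each worker in a pod read its own record-info shard and let repeated passes over the corpus use differently sampled data, while `n_token` (registered in data_utils.py and picked up here via import, as the logging line above suggests) replaces the old hard-coded constant. A hypothetical invocation sketch; paths are placeholders and all other required flags (TPU address, data paths, model sizes) are omitted:

  python train.py \
    --record_info_dir=gs://my-bucket/tfrecords \
    --model_dir=gs://my-bucket/model \
    --n_token=32000 \
    --task=0 \
    --pass_id=0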
train_gpu.py (11 changes: 8 additions & 3 deletions)
@@ -127,6 +127,11 @@
 flags.DEFINE_float("init_range", default=0.1,
                    help="Initialization std when init is uniform.")

+# TFRecord Path
+flags.DEFINE_integer("pass_id", 0, help="ID of the current pass. "
+                     "Different passes sample different negative segments.")
+flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
+                     "using multiple workers to identify each worker.")

 FLAGS = flags.FLAGS

@@ -186,6 +191,8 @@ def train(ps_device):
   train_input_fn, record_info_dict = data_utils.get_input_fn(
       tfrecord_dir=FLAGS.record_info_dir,
       split="train",
+      task=FLAGS.task,
+      pass_id=FLAGS.pass_id,
       bsz_per_host=FLAGS.train_batch_size,
       seq_len=FLAGS.seq_len,
       reuse_len=FLAGS.reuse_len,
@@ -293,7 +300,7 @@ def train(ps_device):
         total_loss += loss_np

         if curr_step > 0 and curr_step % FLAGS.iterations == 0:
-          curr_loss = total_loss / (curr_step - prev_step)
+          curr_loss = total_loss / FLAGS.iterations
           tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                           "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                               curr_step, fetched[-3], fetched[-2],
@@ -314,8 +321,6 @@ def main(unused_argv):

   tf.logging.set_verbosity(tf.logging.INFO)

-  # Get corpus info
-  FLAGS.n_token = data_utils.VOCAB_SIZE
   tf.logging.info("n_token {}".format(FLAGS.n_token))

   if not tf.gfile.Exists(FLAGS.model_dir):
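The logging fix above divides the accumulated loss by the fixed logging interval rather than by `curr_step - prev_step`, which drifts whenever `prev_step` is not updated on every log line. A small sketch of the corrected bookkeeping, assuming `total_loss` is reset after each report (the reset itself is outside this diff):

  # Sketch of the corrected running-loss report, assuming total_loss
  # is zeroed after each log line (not shown in the diff above).
  iterations = 100  # FLAGS.iterations: steps between log lines
  total_loss = 0.0
  for curr_step in range(1, 501):
    loss_np = 2.0  # stand-in for the fetched training loss
    total_loss += loss_np
    if curr_step % iterations == 0:
      curr_loss = total_loss / iterations  # mean loss over the window
      print(curr_step, curr_loss)          # 100 2.0, 200 2.0, ...
      total_loss = 0.0                     # reset for the next window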