minor: Tiny fixes here and there
major: initial GUI interface
draj committed Oct 6, 2021
1 parent b334db5 commit d3843f1
Showing 7 changed files with 190 additions and 88 deletions.
146 changes: 75 additions & 71 deletions common_utils.py
@@ -310,6 +310,23 @@ def shard_files_mono(files, args):
print("File for language", lang, "has been sharded.")
sys.stdout.flush()

def shard_files_mono_lm(files, args):
"""This method shards files into N parts containing the same number of lines. Each shard will go to a different GPU which may even be located on another machine. This method is run when the 'shard_files' argument is passed."""
print("Sharding files into", args.world_size, "parts")
for lang in files:
infile = open(files[lang]).readlines()
num_lines = len(infile)
lines_per_shard = math.ceil(num_lines/args.world_size)
print("For language:",lang," the total number of lines are:", num_lines, "and number of lines per shard are:", lines_per_shard)
for shard_id in range(args.world_size):
outfile = open(files[lang]+"."+"%02d" % shard_id, "w")
for line in infile[shard_id*lines_per_shard:(shard_id+1)*lines_per_shard]:
outfile.write(line)
outfile.flush()
outfile.close()
print("File for language", lang, "has been sharded.")
sys.stdout.flush()
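As a minimal, self-contained sketch of the sharding logic above (the helper name, file name, and world size below are made up for illustration):

import math

def shard_lines(path, world_size):
    # Split one monolingual file into world_size shards of at most
    # ceil(num_lines / world_size) lines each, mirroring shard_files_mono_lm.
    lines = open(path).readlines()
    per_shard = math.ceil(len(lines) / world_size)
    for shard_id in range(world_size):
        with open(path + "." + "%02d" % shard_id, "w") as out:
            out.writelines(lines[shard_id * per_shard:(shard_id + 1) * per_shard])

# e.g. shard_lines("train.en", 4) would write train.en.00 through train.en.03.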

def shard_files_bi(files, args):
"""This method shards files into N parts containing the same number of lines. Each shard will go to a different GPU which may even be located on another machine. This method is run when the 'shard_files' argument is passed."""
print("Sharding files into", args.world_size, "parts")
@@ -570,7 +587,7 @@ def generate_batches_monolingual_masked(tok, args, files, rank):
yield input_ids, input_masks, decoder_input_ids, labels


def generate_batches_lm(tok, args, files, rank): ## Address compatibilities of the meta tokens when using official models
def generate_batches_lm(tok, args, rank, files): ## Address compatibilities of the meta tokens when using official models
"""Generates the source, target and source attention masks for denoising. Long sequences are truncated and short sequences are ignored."""

batch_count = 0
@@ -590,50 +607,61 @@ def generate_batches_lm(tok, args, files, rank): ## Address compatibilities of t
probs = [probs_temp[lang] for lang in language_list]
num_langs = len(language_list)
language_indices = list(range(num_langs))
has_rem=False
while batch_count != args.num_batches:
curr_batch_count = 0
input_batch = []
label_batch = []
batch_count += 1
max_sent_len = 0
prev_max_sent_len = 0
start = time.time()
sents_in_batch = 0
while True:
language_idx = random.choices(language_indices, probs)[0]
sentence = next(language_file_dict[language_list[language_idx]]).strip()
lang = "<2"+language_list[language_idx]+">"
sentence_split = sentence.split(" ")
sent_len = len(sentence_split)
if sent_len < 1:
continue
if sent_len > args.max_length: ## Initial truncation
sentence_split = sentence_split[:args.max_length]
sentence = " ".join(sentence_split)
sent_len = args.max_length
iids = tok(lang + " " + sentence, add_special_tokens=False, return_tensors="pt").input_ids
curr_sent_len = len(iids[0])
if curr_sent_len > max_sent_len:
prev_max_sent_len = max_sent_len
max_sent_len = curr_sent_len
if not has_rem:
language_idx = random.choices(language_indices, probs)[0]
sentence = next(language_file_dict[language_list[language_idx]]).strip()
lang = "<2"+language_list[language_idx]+">"
sentence_split = sentence.split(" ")
sent_len = len(sentence_split)
if sent_len < 1:
continue
if args.train_with_meta and not has_rem and random.random() <= 0.2: ## Use the first part of the document only 20% of the time.
randidx = 0
sentence_split_curr = sentence_split[randidx:randidx+args.max_length]
sentence_curr=" ".join(sentence_split_curr)

if args.use_official_pretrained:
input_batch.append(sentence_curr)
else:
input_batch.append(lang + " " + sentence_curr)

sents_in_batch += 1
if sents_in_batch == args.batch_size: ## We will drop this sentence for now. It may be used in a future iteration.
has_rem=True
break

has_rem=False
randidx = random.randint(0,max(sent_len-args.max_length,0))
sentence_split = sentence_split[randidx:randidx+args.max_length]
sentence=" ".join(sentence_split)

potential_batch_count = max_sent_len*(sents_in_batch+1)
if potential_batch_count > args.batch_size: ## We will drop this sentence for now. It may be used in a future iteration.
max_sent_len = prev_max_sent_len
break
input_batch.append(lang + " " + sentence)
label_batch.append(sentence + " </s>")
if args.use_official_pretrained:
input_batch.append(sentence)
else:
input_batch.append(lang + " " + sentence)

sents_in_batch += 1
if sents_in_batch == args.batch_size: ## We will drop this sentence for now. It may be used in a future iteration.
break

if len(input_batch) == 0:
print("Zero size batch due to an abnormal example. Skipping empty batch.")
continue
input_ids = tok(input_batch, add_special_tokens=False, return_tensors="pt", padding=True, max_length=max_sent_len).input_ids
if args.use_official_pretrained:
input_ids = tok(input_batch, return_tensors="pt", padding=True).input_ids
else:
input_ids = tok(input_batch, add_special_tokens=False, return_tensors="pt", padding=True).input_ids
if args.hard_truncate_length > 0 and len(input_ids[0]) > args.hard_truncate_length: ## Truncate again if we exceed the maximum sequence length.
input_ids = input_ids[:,:args.hard_truncate_length]
labels = tok(label_batch, add_special_tokens=False, return_tensors="pt", padding=True, max_length=max_sent_len).input_ids
if args.hard_truncate_length > 0 and len(labels[0]) > args.hard_truncate_length: ## Truncate again if we exceed the maximum sequence length.
labels = labels[:,:args.hard_truncate_length]
labels = input_ids[:,1:]
input_ids = input_ids[:,:-1]
end = time.time()
yield input_ids, labels
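The final label shift above is the standard causal-LM preparation; a tiny self-contained sketch with made-up token ids:

import torch

input_ids = torch.tensor([[64000, 11, 12, 13, 2]])  # made-up ids for: <2xx> w1 w2 w3 </s>
labels = input_ids[:, 1:]        # targets: each position predicts the next token
input_ids = input_ids[:, :-1]    # inputs: everything except the final token
print(input_ids.tolist(), labels.tolist())
# [[64000, 11, 12, 13]] [[11, 12, 13, 2]]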

@@ -1283,48 +1311,21 @@ def generate_batches_for_decoding_lm(tok, args):
"""Generates the source sentences for the test set."""
src_file = open(args.test_src)
lang = args.lang
curr_batch_count = 0
input_batch = []
max_sent_len = 0
lang = slang if args.use_official_pretrained else "<2"+slang+">"
lang = lang if args.use_official_pretrained else "<2"+lang+">"

for line in src_file:
start = time.time()
sent = line.strip()
sent_split = sent.split(" ")
sent_len = len(sent_split)
if sent_len > args.max_length: ## Initial truncation
sent_split = sent_split[:args.max_length]
sent = " ".join(sent_split)
sent_len = args.max_length

iids = tok(lang + " " + src_sent, add_special_tokens=False, return_tensors="pt").input_ids
curr_sent_len = len(iids[0])

if curr_sent_len > max_sent_len:
max_sent_len = curr_sent_len
if args.use_official_pretrained:
input_ids = tok([sent], return_tensors="pt", padding=True).input_ids
else:
input_ids = tok([lang + " " + sent], add_special_tokens=False, return_tensors="pt", padding=True).input_ids

input_batch.append(lang + " " + src_sent)
end = time.time()

yield input_ids

curr_batch_count += 1
if curr_batch_count == args.batch_size:
input_ids = tok(input_batch, add_special_tokens=False, return_tensors="pt", padding=True, max_length=max_sent_len).input_ids
if args.hard_truncate_length > 0 and len(input_ids[0]) > args.hard_truncate_length:
input_ids = input_ids[:,:args.hard_truncate_length]
end = time.time()

yield input_ids, input_masks

curr_batch_count = 0
input_batch = []
max_sent_len = 0

if len(input_batch) != 0:
input_ids = tok(input_batch, add_special_tokens=False, return_tensors="pt", padding=True, max_length=max_sent_len).input_ids
if args.hard_truncate_length > 0 and len(input_ids[0]) > args.hard_truncate_length:
input_ids = input_ids[:,:args.hard_truncate_length]
yield input_ids, input_masks
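In the rewritten generator each test line is tokenized and yielded on its own instead of being packed into fixed-size batches. A self-contained sketch of the per-line preparation, with a made-up sentence and language code standing in for args.lang and args.use_official_pretrained:

use_official_pretrained = False          # stand-in for args.use_official_pretrained
lang = "hi"                              # stand-in for args.lang
lang_tok = lang if use_official_pretrained else "<2" + lang + ">"
sent = "a made-up test sentence"
text = sent if use_official_pretrained else lang_tok + " " + sent
# 'text' is what gets passed to tok(...) above, one sentence at a time, so each
# yielded input_ids tensor holds a single test sentence.
print(text)   # <2hi> a made-up test sentence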


def plot_attention(data, X_label=None, Y_label=None, num_layers=None, num_heads=None, file_name=None, plot_title=None):
'''
@@ -1369,9 +1370,12 @@ def plot_attention(data, X_label=None, Y_label=None, num_layers=None, num_heads=
plt.close(fig) # close the figure


def generate_batches_monolingual_masked_or_bilingual(tok, args, rank, files, train_files, ctr):
def generate_batches_monolingual_masked_or_bilingual(tok, args, rank, files, train_files):
"""This will return masked monolingual or bilingual batches according to a fixed ratio."""
if args.bilingual_train_frequency != -1 and ctr % args.bilingual_train_frequency == 0:
return generate_batches_bilingual(tok, args, train_files, rank)
else:
return generate_batches_monolingual_masked(tok, args, files, rank)
bilingual_generator = generate_batches_bilingual(tok, args, train_files, rank)
monolingual_generator = generate_batches_monolingual_masked(tok, args, files, rank)
while True:
if args.bilingual_train_frequency != 0.0 and random.random() <= args.bilingual_train_frequency:
yield next(bilingual_generator), True
else:
yield next(monolingual_generator), False
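A sketch of how a training loop might consume this mixed generator; the loop body is illustrative, while the signature and the (batch, is_bilingual) pairs follow the code above:

batch_generator = generate_batches_monolingual_masked_or_bilingual(tok, args, rank, files, train_files)
for step in range(args.num_batches):
    batch, is_bilingual = next(batch_generator)
    # is_bilingual tells the caller which objective produced this batch, e.g. so
    # that monolingual (masked) and bilingual losses can be logged separately.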
11 changes: 7 additions & 4 deletions decode_nmt.py
@@ -102,7 +102,7 @@ def model_create_load_decode(gpu, args):
if "mbart" in args.model_path:
model = MBartForConditionalGeneration.from_pretrained(args.model_path) ## This is only to avoid having to specify the hyperparams manually assuming you fine-tuned an official model. If you know the hyperparams then don't use this.
elif "bart" in args.model_path:
model = BartForConditionalGeneration.from_pretrained(args.model_path) ## This is only to avoid having to specify the hyperparams manually assuming you fine-tuned an official model. If you know the hyperparams then don't use this.
model = BartForConditionalGeneration.from_pretrained(args.model_path, force_bos_token_to_be_generated=True) ## This is only to avoid having to specify the hyperparams manually assuming you fine-tuned an official model. If you know the hyperparams then don't use this.
else:
config = MBartConfig(vocab_size=len(tok), encoder_layers=args.encoder_layers, decoder_layers=args.decoder_layers, dropout=args.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, encoder_attention_heads=args.encoder_attention_heads, decoder_attention_heads=args.decoder_attention_heads, encoder_ffn_dim=args.encoder_ffn_dim, decoder_ffn_dim=args.decoder_ffn_dim, d_model=args.d_model, no_embed_norm=args.no_embed_norm, scale_embedding=args.scale_embedding, pad_token_id=tok.pad_token_id, eos_token_id=tok(["</s>"], add_special_tokens=False).input_ids[0][0], bos_token_id=tok(["<s>"], add_special_tokens=False).input_ids[0][0], encoder_tying_config=args.encoder_tying_config, decoder_tying_config=args.decoder_tying_config, multilayer_softmaxing=args.multilayer_softmaxing, wait_k=args.wait_k, additional_source_wait_k=args.additional_source_wait_k, unidirectional_encoder=args.unidirectional_encoder, multi_source=args.multi_source, multi_source_method=args.multi_source_method, softmax_temperature=args.softmax_temperature, temperature_calibration=args.temperature_calibration, no_scale_attention_embedding=args.no_scale_attention_embedding, positional_encodings=args.positional_encodings) ## Configuration.
model = MBartForConditionalGeneration(config)
@@ -118,7 +118,7 @@ def model_create_load_decode(gpu, args):
else:
if args.use_official_pretrained and args.locally_fine_tuned_model_path is not None: ## If we want to decode a locally fine-tuned version of an official model.
args.model_path = args.locally_fine_tuned_model_path
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
map_location = {'cuda:%d' % 0: 'cuda:%d' % gpu}
checkpoint_dict = torch.load(args.model_path, map_location=map_location)
if type(checkpoint_dict) == dict:
model.load_state_dict(remap_embeddings_eliminate_components_and_eliminate_mismatches(model.state_dict(), remap_layers(checkpoint_dict['model'], 4, args), args), strict=True if (args.remap_encoder == "" and args.remap_decoder == "" and not args.eliminate_encoder_before_initialization and not args.eliminate_decoder_before_initialization and not args.eliminate_embeddings_before_initialization) else False) ## Modification needed if we want to load a partial model trained using multilayer softmaxing.
@@ -143,9 +143,12 @@ def model_create_load_decode(gpu, args):
with torch.no_grad():
translations = model.module.generate(input_ids.to(gpu), use_cache=True, num_beams=args.beam_size, max_length=int((len(input_ids[0])*args.max_decode_length_multiplier) if args.max_decode_length_multiplier > 0 else -args.max_decode_length_multiplier), min_length=int((len(input_ids[0])*args.min_decode_length_multiplier) if args.min_decode_length_multiplier > 0 else -args.min_decode_length_multiplier), early_stopping=True, attention_mask=input_masks.to(gpu), pad_token_id=tok.pad_token_id, eos_token_id=tok(["</s>"], add_special_tokens=False).input_ids[0][0], decoder_start_token_id=tok([args.tlang if args.use_official_pretrained else "<2"+args.tlang+">"], add_special_tokens=False).input_ids[0][0], bos_token_id=tok(["<s>"], add_special_tokens=False).input_ids[0][0], length_penalty=args.length_penalty, repetition_penalty=args.repetition_penalty, encoder_no_repeat_ngram_size=args.encoder_no_repeat_ngram_size, no_repeat_ngram_size=args.no_repeat_ngram_size, num_return_sequences=args.beam_size if args.return_all_sequences else 1, additional_input_ids=input_ids_parent.to(gpu) if args.multi_source else None, additional_input_ids_mask=input_masks_parent.to(gpu) if args.multi_source else None) ## We translate the batch.
print(len(input_ids), "in and", len(translations), "out")
if args.return_all_sequences:
input_ids = input_ids.repeat(args.beam_size,1)
for input_id, translation in zip(input_ids, translations):
translation = tok.decode(translation, skip_special_tokens=args.no_skip_special_tokens, clean_up_tokenization_spaces=False)
input_id = tok.decode(input_id, skip_special_tokens=args.no_skip_special_tokens, clean_up_tokenization_spaces=False) ### Get the raw sentences.
# print(input_id, " ### ", translation)
outf.write(translation+"\n")
outf.flush()
hyp.append(translation)
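The max_length/min_length arguments in the generate call above follow a small convention: a positive multiplier scales the source length, while a negative value is taken as an absolute token count. A self-contained sketch of that computation with made-up numbers:

def decode_length(src_len, multiplier):
    # Mirrors int(src_len * multiplier) if multiplier > 0 else -multiplier,
    # as passed to model.module.generate above.
    return int(src_len * multiplier) if multiplier > 0 else int(-multiplier)

print(decode_length(20, 1.5))   # 30  -> cap grows with the source length
print(decode_length(20, -64))   # 64  -> fixed absolute cap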
@@ -425,8 +428,8 @@ def run_demo():
help='Should we return all beam sequences?')
parser.add_argument('--no_skip_special_tokens', action='store_false',
help='Should we return outputs without special tokens? We may need this to deal with situations where the user specified control tokens must be in the output.')
parser.add_argument('--multilayer_softmaxing', action='store_true',
help='Should we apply a softmax for each decoder layer? Unsupported for distillation. Only for vanilla training.')
parser.add_argument('--multilayer_softmaxing', default=None,
help='Should we apply a softmax for each decoder layer? Unsupported for distillation. Only for vanilla training. You have to specify a comma separated list of the intermediate layers which you want to softmax. These go from 0 for the embedding layer to L-2 for the penultimate layer.')
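# Illustrative only (not part of this commit): following the help text above, with a
# 6-layer decoder one could pass --multilayer_softmaxing 0,2,4 to also apply the
# softmax to the embedding-layer output and to two intermediate decoder layers.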
parser.add_argument('--remap_encoder', default='', type=str,
help='This indicates the remappings for the layer. Example: 1-2,2-4,3-6. The plan is to use these remappings to cut down the model prior to decoding or training. Suppose we have a 6 layer model but we only want to utilize the 2nd, 4th and 6th layer then we will copy the content of the 2nd, 4th and 6th layers to the 1st, 2nd and 3rd layer and delete the former layers from the parameter dictionary. This counts as layer pruning. IMPORTANT NOTE: Ensure that you specify ALL child layer indices you wish mapped. For example if you want 1-2,2-1,3-3 you MUST NOT skip the 3-3 part else it will be deleted from the model dictionary and will be randomly initialized. The loading mechanism is not strict so it will ignore missing or non matching keys. ADDITIONAL NOTE: Load a checkpoint with only the model and not the optimizer to prevent failure as we are not sure if remapping optimizers and learning rate schedulers make sense or not.')
parser.add_argument('--remap_decoder', default='', type=str,