From 2f6d38b4b4af6758e1e90a57fdf7de4ad3859f22 Mon Sep 17 00:00:00 2001
From: Emily Sheng
Date: Sat, 5 Dec 2020 00:02:39 -0800
Subject: [PATCH] Removing hardcoded vocab size

---
 src/create_adv_token.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/create_adv_token.py b/src/create_adv_token.py
index f29d63e..d1c20cb 100644
--- a/src/create_adv_token.py
+++ b/src/create_adv_token.py
@@ -22,18 +22,18 @@ def extract_grad_hook(module, grad_in, grad_out):
 
 # Returns the wordpiece embedding weight matrix.
-def get_embedding_weight(language_model):
+def get_embedding_weight(language_model, vocab_size):
     for module in language_model.modules():
         if isinstance(module, torch.nn.Embedding):
-            if module.weight.shape[0] == 50257: # Only add a hook to wordpiece embeddings, not position embeddings.
+            if module.weight.shape[0] == vocab_size: # Only add a hook to wordpiece embeddings, not position embeddings.
                 return module.weight.detach()
 
 
 # Add hooks for embeddings.
-def add_hooks(language_model):
+def add_hooks(language_model, vocab_size):
     for module in language_model.modules():
         if isinstance(module, torch.nn.Embedding):
-            if module.weight.shape[0] == 50257: # Only add a hook to wordpiece embeddings, not position.
+            if module.weight.shape[0] == vocab_size: # Only add a hook to wordpiece embeddings, not position.
                 module.weight.requires_grad = True
                 module.register_backward_hook(extract_grad_hook)
 
 
@@ -287,13 +287,13 @@ def run_model():
     model = AutoModelWithLMHead.from_pretrained(params.model_name_or_path)
     tokenizer = AutoTokenizer.from_pretrained(
         params.tokenizer_name if params.tokenizer_name else params.model_name_or_path)
+    total_vocab_size = len(tokenizer)
     model.eval()
     model.to(device)
 
-    add_hooks(model) # add gradient hooks to embeddings
-    embedding_weight = get_embedding_weight(model) # save the word embedding matrix
+    add_hooks(model, total_vocab_size) # add gradient hooks to embeddings
+    embedding_weight = get_embedding_weight(model, total_vocab_size) # save the word embedding matrix
 
-    total_vocab_size = 50257 # total number of subword pieces in the GPT-2 model
     enc_trigger_init = tokenizer.encode('The ' + params.trigger_init)[1:]
     trigger_init_len = len(enc_trigger_init)
     old_num_trigger_tokens = params.num_trigger_tokens
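
Note: a minimal sketch of the behavior this change relies on, assuming the stock 'gpt2' checkpoint; the model name and the added [PAD] token below are illustrative, not part of the patch. For unmodified GPT-2, len(tokenizer) reproduces the removed 50257 constant, and it keeps tracking the wordpiece embedding if the vocabulary is later extended:

    import torch
    from transformers import AutoModelWithLMHead, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelWithLMHead.from_pretrained('gpt2')
    print(len(tokenizer))  # 50257 for stock GPT-2, matching the removed constant

    # Extending the vocabulary would break a hardcoded 50257, but
    # len(tokenizer) still matches the resized wordpiece embedding.
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))

    total_vocab_size = len(tokenizer)  # now 50258
    for module in model.modules():
        if isinstance(module, torch.nn.Embedding):
            if module.weight.shape[0] == total_vocab_size:
                embedding_weight = module.weight.detach()  # wordpiece, not position, embeddings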