Removing hardcoded vocab size
ewsheng committed Dec 5, 2020
1 parent 3214c7a commit 2f6d38b
Showing 1 changed file with 7 additions and 7 deletions.
src/create_adv_token.py (14 changes: 7 additions & 7 deletions)
@@ -22,18 +22,18 @@ def extract_grad_hook(module, grad_in, grad_out):


 # Returns the wordpiece embedding weight matrix.
-def get_embedding_weight(language_model):
+def get_embedding_weight(language_model, vocab_size):
     for module in language_model.modules():
         if isinstance(module, torch.nn.Embedding):
-            if module.weight.shape[0] == 50257: # Only add a hook to wordpiece embeddings, not position embeddings.
+            if module.weight.shape[0] == vocab_size: # Only add a hook to wordpiece embeddings, not position embeddings.
                 return module.weight.detach()


 # Add hooks for embeddings.
-def add_hooks(language_model):
+def add_hooks(language_model, vocab_size):
     for module in language_model.modules():
         if isinstance(module, torch.nn.Embedding):
-            if module.weight.shape[0] == 50257: # Only add a hook to wordpiece embeddings, not position.
+            if module.weight.shape[0] == vocab_size: # Only add a hook to wordpiece embeddings, not position.
                 module.weight.requires_grad = True
                 module.register_backward_hook(extract_grad_hook)

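For reference, a minimal sketch (not part of the commit; the stock 'gpt2' checkpoint is assumed) of the two nn.Embedding modules that the shape check above distinguishes: only the wordpiece table has vocab_size rows, so only it receives the gradient hook, while the position table is skipped.

import torch
from transformers import AutoModelWithLMHead, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelWithLMHead.from_pretrained('gpt2')

for module in model.modules():
    if isinstance(module, torch.nn.Embedding):
        print(tuple(module.weight.shape))
# (50257, 768)  -> wordpiece embeddings; rows == len(tokenizer), gets the hook
# (1024, 768)   -> position embeddings; skipped by the vocab_size check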
@@ -287,13 +287,13 @@ def run_model():
     model = AutoModelWithLMHead.from_pretrained(params.model_name_or_path)
     tokenizer = AutoTokenizer.from_pretrained(
         params.tokenizer_name if params.tokenizer_name else params.model_name_or_path)
+    total_vocab_size = len(tokenizer)
     model.eval()
     model.to(device)

-    add_hooks(model) # add gradient hooks to embeddings
-    embedding_weight = get_embedding_weight(model) # save the word embedding matrix
+    add_hooks(model, total_vocab_size) # add gradient hooks to embeddings
+    embedding_weight = get_embedding_weight(model, total_vocab_size) # save the word embedding matrix

-    total_vocab_size = 50257 # total number of subword pieces in the GPT-2 model
     enc_trigger_init = tokenizer.encode('The ' + params.trigger_init)[1:]
     trigger_init_len = len(enc_trigger_init)
     old_num_trigger_tokens = params.num_trigger_tokens
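A hedged usage sketch (illustrative, not from this commit; the added token is hypothetical) of why deriving total_vocab_size from len(tokenizer) is more robust than the hardcoded 50257: once extra tokens are added and the embedding matrix is resized, only the dynamic value still matches module.weight.shape[0], so add_hooks and get_embedding_weight still find the wordpiece embedding.

# add_hooks and get_embedding_weight are the helpers from src/create_adv_token.py above.
from transformers import AutoModelWithLMHead, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelWithLMHead.from_pretrained('gpt2')

tokenizer.add_tokens(['<bias-ctx>'])           # hypothetical extra token
model.resize_token_embeddings(len(tokenizer))  # wordpiece table grows to 50258 rows

total_vocab_size = len(tokenizer)              # 50258; a hardcoded 50257 would no longer match
add_hooks(model, total_vocab_size)             # hook attaches to the resized wordpiece embedding
embedding_weight = get_embedding_weight(model, total_vocab_size)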
