Commit

Removing unnecessary utils
ewsheng committed Nov 15, 2020
1 parent 23d6506 commit 4dfeeef
Showing 2 changed files with 14 additions and 167 deletions.
19 changes: 14 additions & 5 deletions src/create_adv_token.py
@@ -7,12 +7,20 @@
 from transformers import AutoModelWithLMHead, AutoTokenizer, GPT2Tokenizer, GPT2LMHeadModel
 import attacks
 import constants
-import utils
 import collections
 import heapq
 import string


+# hook used in add_hooks()
+extracted_grads = []
+
+
+def extract_grad_hook(module, grad_in, grad_out):
+    global extracted_grads
+    extracted_grads.append(grad_out[0])
+
+
 # Returns the wordpiece embedding weight matrix.
 def get_embedding_weight(language_model):
     for module in language_model.modules():
@@ -27,7 +35,7 @@ def add_hooks(language_model):
         if isinstance(module, torch.nn.Embedding):
             if module.weight.shape[0] == 50257: # Only add a hook to wordpiece embeddings, not position.
                 module.weight.requires_grad = True
-                module.register_backward_hook(utils.extract_grad_hook)
+                module.register_backward_hook(extract_grad_hook)


 # Gets the loss of the target_tokens using the triggers as the context.
@@ -214,6 +222,7 @@ def keep_candidate_token(candidate):


 def run_model():
+    global extracted_grads

     parser = argparse.ArgumentParser()
     parser.add_argument('--neg_sample_file', default='', help='File of negative regard target samples.')
@@ -599,7 +608,7 @@ def run_model():
         print(tokenizer.decode(trigger_tokens), trigger_tokens)

         model.zero_grad()
-        utils.extracted_grads = [] # Each element is (batch_size, sample_length, 768_embed_dim).
+        extracted_grads = [] # Each element is (batch_size, sample_length, 768_embed_dim).
         loss_types = [] # Order of `add` and `sub` loss types.
         demo_types = [] # Order of `neg` or `pos` demographic types.
         for idx, (typ, demo_type, target_tokens) in enumerate(all_items):
@@ -649,7 +658,7 @@ def run_model():
         add_indices = [i for i, loss_type in enumerate(loss_types) if loss_type == 'add']
         add_extracted_grads = []
         for i in add_indices:
-            extracted_grad = utils.extracted_grads[i]
+            extracted_grad = extracted_grads[i]
             if params.use_weighted_neg and demo_types[i] == 'neg': # Amplify neg associations.
                 extracted_grad *= 2
             add_extracted_grads.append(extracted_grad)
@@ -663,7 +672,7 @@ def run_model():
         sub_indices = [i for i, loss_type in enumerate(loss_types) if loss_type == 'sub']
         sub_extracted_grads = []
         for i in sub_indices:
-            extracted_grad = utils.extracted_grads[i]
+            extracted_grad = extracted_grads[i]
             if params.use_weighted_neg and demo_types[i] == 'neg': # Amplify neg associations.
                 extracted_grad *= 2
             sub_extracted_grads.append(extracted_grad)
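For reference, the backward-hook pattern that this commit inlines into create_adv_token.py can be exercised on its own. The sketch below is illustrative rather than the repository's code: it assumes a stock Hugging Face GPT-2 checkpoint, and the sample sentence and printed shapes are made up for the example.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

extracted_grads = []  # filled by the hook; one tensor per backward pass


def extract_grad_hook(module, grad_in, grad_out):
    # grad_out[0] is the gradient w.r.t. the embedding layer's output,
    # shaped (batch_size, sample_length, embed_dim).
    extracted_grads.append(grad_out[0])


model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

for module in model.modules():
    # Hook only the wordpiece embedding (vocab size 50257), not the position embedding.
    if isinstance(module, torch.nn.Embedding) and module.weight.shape[0] == 50257:
        module.weight.requires_grad = True
        module.register_backward_hook(extract_grad_hook)

input_ids = tokenizer.encode('This is a test sentence.', return_tensors='pt')
outputs = model(input_ids, labels=input_ids)
loss = outputs[0]  # first element is the LM loss whether outputs is a tuple or a ModelOutput
loss.backward()
print(len(extracted_grads), extracted_grads[0].shape)  # one gradient, shaped (1, seq_len, 768)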
162 changes: 0 additions & 162 deletions src/utils.py

This file was deleted.
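The add/sub gradient bookkeeping in the last two hunks of create_adv_token.py can also be read in isolation. Below is a toy, self-contained rendering of that weighted accumulation: the gradient tensors, loss_types, demo_types, and the final averaging step are invented for illustration (the diff collapses what follows the append), so this is a sketch of the pattern, not the project's implementation.

import torch

# Fabricated stand-ins: one gradient per backward pass, in the order the passes ran.
extracted_grads = [torch.randn(1, 6, 768) for _ in range(4)]
loss_types = ['add', 'add', 'sub', 'sub']  # order of 'add' / 'sub' loss types, as labeled in the diff
demo_types = ['neg', 'pos', 'neg', 'pos']  # order of 'neg' / 'pos' demographic types, as labeled in the diff
use_weighted_neg = True

add_indices = [i for i, loss_type in enumerate(loss_types) if loss_type == 'add']
add_extracted_grads = []
for i in add_indices:
    extracted_grad = extracted_grads[i]
    if use_weighted_neg and demo_types[i] == 'neg':  # Amplify neg associations.
        extracted_grad = extracted_grad * 2          # out-of-place here, to leave the stored tensor untouched
    add_extracted_grads.append(extracted_grad)

# Assumed follow-up (not shown in the diff): collapse to one gradient per position
# before searching for candidate trigger tokens.
add_grad = torch.stack(add_extracted_grads).mean(dim=0)
print(add_grad.shape)  # torch.Size([1, 6, 768])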
