changes.txt
Git https://github.com/anonymous-xmepai/MEND-Ours/commit/3993d6e339f8f92b2fb71f06a40b0031f434ada4
mend/mend/losses.py
Remove view for models other than bloom (keep it for bloom) in both lines
lines 6 - 7
pre = pre.to(torch.float32).view(-1, 1)
post = post.to(torch.float32).view(-1, 1)
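The effect of the reshape in the two lines above can be checked in isolation. The sketch below is only an illustration with made-up values, not the code from losses.py: it compares a flat tensor with its view(-1, 1) counterpart and shows what happens when the two shapes are mixed.

```python
import torch

# Made-up values standing in for the pre/post tensors.
pre = torch.tensor([0.1, 0.9, 0.4], dtype=torch.float64)

flat = pre.to(torch.float32)                # shape (3,)
column = pre.to(torch.float32).view(-1, 1)  # shape (3, 1)

# Tensors of the same shape combine element-wise.
same_shape = column - column                # shape (3, 1)

# Mixing (3, 1) with (3,) does not fail; it broadcasts to a (3, 3) matrix.
mixed = column - flat

print(tuple(flat.shape), tuple(column.shape), tuple(mixed.shape))
```

Silent broadcasting like this is one plausible reason the view call has to match what each model produces, although the note itself does not state the reason.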
mend/mend/data_classes/fever.py
Change view(-1) to view(-1, 1) in both lines for models other than bloom
lines 89 - 90
batches["predictions"] = torch.tensor(predictions).long().view(-1)
batches["labels"] = torch.tensor(labels).long().view(-1)
mend/mend/config/model/bloom-560m.yaml
Added this file
mend/mend/data/fever/bloom_560m.bin
Added fine-tuned weights
mend/mend/utils.py
Change allow_unused=False to True for bloom
line 77
def safe_backward(loss, parameters, accumulate=1, allow_unused=True):
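The body of safe_backward is not shown here, so the following is a sketch under the assumption that the flag is forwarded to torch.autograd.grad. It uses two toy parameters (hypothetical names) to show what allow_unused changes when one parameter never contributes to the loss.

```python
import torch

# Toy parameters; w_unused never enters the computation graph.
w_used = torch.tensor(2.0, requires_grad=True)
w_unused = torch.tensor(3.0, requires_grad=True)

loss = w_used * w_used

# allow_unused=True: the unused parameter gets None instead of a gradient.
grads = torch.autograd.grad(
    loss, [w_used, w_unused], allow_unused=True, retain_graph=True
)

# allow_unused=False: the same request raises a RuntimeError.
try:
    torch.autograd.grad(loss, [w_used, w_unused], allow_unused=False)
    strict_failed = False
except RuntimeError:
    strict_failed = True

print(grads, strict_failed)
```

Any code that consumes the returned gradients then has to tolerate None entries.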
Changed for better caching
lines 26 - 30
def scr():
    if os.path.exists(f"/home/{getpass.getuser()}/mend/mend/cache-ssd"):
        scr_dir = f"/home/{getpass.getuser()}/mend/mend/cache-ssd/"
    else:
        scr_dir = f"/home/{getpass.getuser()}/mend/mend/cache/"
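The fallback logic in the excerpt above can be tried without touching the real home directory. This is a minimal sketch, not the function from utils.py: pick_cache_dir is a hypothetical helper that takes the root directory as a parameter so it can be exercised against a temporary folder.

```python
import os
import tempfile


def pick_cache_dir(root):
    """Return root/cache-ssd/ if that directory exists, else root/cache/."""
    ssd_dir = os.path.join(root, "cache-ssd")
    if os.path.exists(ssd_dir):
        return ssd_dir + "/"
    return os.path.join(root, "cache") + "/"


# Exercise both branches inside a throwaway directory.
with tempfile.TemporaryDirectory() as root:
    before = pick_cache_dir(root)   # no cache-ssd yet, falls back to cache/
    os.makedirs(os.path.join(root, "cache-ssd"))
    after = pick_cache_dir(root)    # cache-ssd now exists and is preferred

print(before.endswith("cache/"), after.endswith("cache-ssd/"))
```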
For GPT - Sequence Classification Model
mend/mend/run.py
lines 46 - 49
Added lines
import transformers
if isinstance(model, transformers.GPT2ForSequenceClassification):
    model.config.pad_token_id = model.config.eos_token_id
lines 59 - 64
Added lines
if config.tests:
    train_set = BinaryAugmentedKILT(tokenizer, f"{base_dir}/data/{config.train_set}", config)
    val_set = BinaryAugmentedKILT(tokenizer, f"{base_dir}/data/{config.val_set}", config)
else:
    train_set = BinaryAugmentedKILT(tokenizer, f"{base_dir}/data/fever/fever-train-kilt.jsonl", config)
    val_set = BinaryAugmentedKILT(tokenizer, f"{base_dir}/data/fever/fever-dev-kilt.jsonl", config)
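The path selection in the added lines can be separated from dataset construction. The sketch below is an illustration only: dataset_paths is a hypothetical helper, the config objects are plain SimpleNamespace stand-ins, and the tiny-*.jsonl file names are invented for the example.

```python
from types import SimpleNamespace


def dataset_paths(base_dir, config):
    """Return (train_path, val_path) following the config.tests switch."""
    if config.tests:
        # Test runs read whatever files the config names.
        return (
            f"{base_dir}/data/{config.train_set}",
            f"{base_dir}/data/{config.val_set}",
        )
    # Regular runs use the fixed FEVER KILT files.
    return (
        f"{base_dir}/data/fever/fever-train-kilt.jsonl",
        f"{base_dir}/data/fever/fever-dev-kilt.jsonl",
    )


# Hypothetical configs: one test run with small files, one regular run.
test_cfg = SimpleNamespace(tests=True, train_set="tiny-train.jsonl", val_set="tiny-dev.jsonl")
full_cfg = SimpleNamespace(tests=False)

test_paths = dataset_paths("/tmp/mend", test_cfg)
full_paths = dataset_paths("/tmp/mend", full_cfg)
print(test_paths, full_paths)
```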
mend/mend/algs/mend.py
Added `or isinstance(self.model, transformers.GPT2ForSequenceClassification)` to the if statements on
lines 127 and 216
mend/mend/data_classes/fever.py
line 22
Added
self.tokenizer.pad_token = tokenizer.eos_token