forked from kmeng01/rome
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrome_main.py
177 lines (143 loc) · 5.56 KB
/
rome_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from copy import deepcopy
from typing import Dict, List, Tuple
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from util import nethook
from util.generate import generate_fast
from .compute_u import compute_u
from .compute_v import compute_v
from .rome_hparams import ROMEHyperParams
# Process-wide cache of the context templates built by get_context_templates().
# Stays None until the first call to that function populates it; afterwards the
# same list object is returned on every call.
CONTEXT_TEMPLATES_CACHE = None
def apply_rome_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: ROMEHyperParams,
    copy=False,
    return_orig_weights=False,
) -> Tuple[AutoModelForCausalLM, Dict[str, torch.Tensor]]:
    """
    Returns a model with the desired changes.

    Requests are processed one after another; each request's update is computed
    by `execute_rome` against the model as already edited by earlier requests.

    :param model: Model to edit. Modified in place unless `copy` is true.
    :param tok: Tokenizer, forwarded to `execute_rome`.
    :param requests: Edit requests, forwarded one at a time to `execute_rome`.
    :param hparams: ROME hyperparameters, forwarded to `execute_rome`.
    :param copy: If true, will preserve the original model while creating a new one to edit.
        Note that you are responsible for deallocating the new model's memory to avoid leaks.
    :param return_orig_weights: If true, a clone of every edited weight is taken
        right before its first modification and returned. If false, the returned
        dict is empty.
    :return: (1) the updated model, (2) a dict mapping parameter name -> original
        (pre-edit) copy of each weight that changed
    """

    if copy:
        model = deepcopy(model)

    weights_copy = {}

    for i, request in enumerate(requests):
        # Maps parameter name -> (left vector, right vector) of the update.
        deltas = execute_rome(model, tok, request, hparams)

        with torch.no_grad():
            for w_name, (delta_u, delta_v) in deltas.items():
                # Outer product of the two vectors gives the update matrix
                # (assumes both are 1-D -- TODO confirm against compute_u/compute_v).
                upd_matrix = delta_u.unsqueeze(1) @ delta_v.unsqueeze(0)
                w = nethook.get_parameter(model, w_name)
                upd_matrix = upd_matrix_match_shape(upd_matrix, w.shape)

                if return_orig_weights and w_name not in weights_copy:
                    # Every request touches the same set of weights, so a name
                    # can only be unseen while handling the first request.
                    assert i == 0
                    weights_copy[w_name] = w.detach().clone()

                # In-place add so the model's own parameter tensor is edited.
                w[...] += upd_matrix

        print(f"New weights successfully inserted into {list(deltas.keys())}")

    return model, weights_copy
def execute_rome(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Executes the ROME update algorithm for the specified update at the specified layer(s).
    Invariant: model at beginning of function == model at end of function.
    The edited weights are restored in a `finally` clause, so the invariant also
    holds when computing an update raises.

    :param model: Model whose weights are temporarily edited while computing the update.
    :param tok: Tokenizer, forwarded to the helper routines.
    :param request: Edit request; the keys read here are "prompt", "subject" and
        "target_new" -> "str". A deep copy is taken, so the caller's dict is not modified.
    :param hparams: Hyperparameters; this function reads `layers`,
        `rewrite_module_tmp` and `context_template_length_params`.
    :return: Dict mapping each rewritten weight's parameter name to the detached
        (left_vector, right_vector) pair that defines its update.
    """

    # Update target and print info
    request = deepcopy(request)
    if request["target_new"]["str"][0] != " ":
        # Space required for correct tokenization
        request["target_new"]["str"] = " " + request["target_new"]["str"]
    print(
        f"Executing ROME algorithm for the update: "
        f"[{request['prompt'].format(request['subject'])}] -> [{request['target_new']['str']}]"
    )

    # Retrieve weights that user desires to change
    weight_names = {
        layer: f"{hparams.rewrite_module_tmp.format(layer)}.weight"
        for layer in hparams.layers
    }
    weights = {
        name: nethook.get_parameter(model, name) for name in weight_names.values()
    }
    # Save old weights for future restoration
    weights_copy = {k: v.detach().clone() for k, v in weights.items()}

    # Update loop: sequentially intervene at each specified layer
    deltas = {}
    try:
        for layer in sorted(hparams.layers):
            # Cached after the first call, so fetching once per layer is cheap.
            context_templates = get_context_templates(
                model, tok, hparams.context_template_length_params
            )

            # Compute rank-1 update matrix
            left_vector: torch.Tensor = compute_u(
                model,
                tok,
                request,
                hparams,
                layer,
                context_templates,
            )
            print("Left vector shape:", left_vector.shape)
            right_vector: torch.Tensor = compute_v(
                model,
                tok,
                request,
                hparams,
                layer,
                left_vector,
                context_templates,
            )
            print("Right vector shape:", right_vector.shape)

            with torch.no_grad():
                # Determine correct transposition of delta matrix
                weight_name = weight_names[layer]
                upd_matrix = left_vector.unsqueeze(1) @ right_vector.unsqueeze(0)
                upd_matrix = upd_matrix_match_shape(
                    upd_matrix, weights[weight_name].shape
                )

                # Update model weights and record desired changes in `deltas`.
                # The edit stays applied while later layers are processed.
                weights[weight_name][...] += upd_matrix
                deltas[weight_name] = (
                    left_vector.detach(),
                    right_vector.detach(),
                )
    finally:
        # Restore state of original model, even if an update step failed,
        # so a raised exception never leaves the model partially edited.
        with torch.no_grad():
            for k, v in weights.items():
                v[...] = weights_copy[k]

    print(f"Deltas successfully computed for {list(weights.keys())}")

    return deltas
def upd_matrix_match_shape(matrix: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """
    Orient an update matrix so that it fits a weight of the given shape.

    GPT-2 and GPT-J have transposed weight representations, so the computed
    update may need to be transposed before it can be added to the weight.

    :param matrix: Update matrix to orient.
    :param shape: Shape of the weight the update will be added to.
    :return: `matrix` itself if it already has `shape`, otherwise its transpose
        if that has `shape`.
    :raises ValueError: If neither orientation matches `shape`.
    """

    # Already in the right orientation: hand back the very same tensor.
    if matrix.shape == shape:
        return matrix

    # Otherwise try the transposed view.
    transposed = matrix.T
    if transposed.shape == shape:
        return transposed

    raise ValueError(
        "Update matrix computed by ROME does not match original weight shape. "
        "Check for bugs in the code?"
    )
def get_context_templates(model, tok, length_params):
    """
    Return the shared list of context templates, building it on first use.

    The list always starts with the bare template "{}". For every
    (max_out_len, n_gen_per_prompt) pair in `length_params`, text is generated
    from the prompt "<|endoftext|>" via `generate_fast`, and each generated
    string is turned into a template by appending ". {}".

    The result is stored in the module-level CONTEXT_TEMPLATES_CACHE. Once that
    cache is populated, every later call returns it unchanged and ignores all
    three arguments.

    :param model: Model passed to `generate_fast` (first call only).
    :param tok: Tokenizer passed to `generate_fast` (first call only).
    :param length_params: Iterable of (length, n_gen) pairs (first call only).
    :return: The cached list of template strings.
    """
    global CONTEXT_TEMPLATES_CACHE

    # Fast path: an earlier call already generated the templates.
    if CONTEXT_TEMPLATES_CACHE is not None:
        return CONTEXT_TEMPLATES_CACHE

    templates = ["{}"]
    for max_len, num_samples in length_params:
        samples = generate_fast(
            model,
            tok,
            ["<|endoftext|>"],
            n_gen_per_prompt=num_samples,
            max_out_len=max_len,
        )
        templates.extend(sample + ". {}" for sample in samples)

    # Publish only after everything was generated successfully.
    CONTEXT_TEMPLATES_CACHE = templates
    print(f"Cached context templates {CONTEXT_TEMPLATES_CACHE}")

    return CONTEXT_TEMPLATES_CACHE