optimizer.py
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from logging import getLogger
from typing import Dict, Any
import itertools

logger = getLogger(__name__)


def optimize_v1(result: Dict[str, Any], bits: float) -> None:
    """
    Optimizes the quantization settings to achieve a target number of bits per weight while
    minimizing the maximum per-layer quantization error.

    Args:
        result: Per-layer results, each containing the layer's element count ("numel") and a
            list of candidate quantization options with their measured error, bits per weight
            and total bits.
        bits: The target number of bits per weight.
    """
    eps = 0.0001
    numel = 0
    max_rfn = 0.0
    for _, layer in result.items():
        numel += layer["numel"]
        for option in layer["options"]:
            max_rfn = max(max_rfn, option["err"])
    # max_rfn -= eps
    min_rfn = 0
    best_rfn = 10000.0
    target_bpw = bits

    # Binary search for the combination of settings that minimizes the max rfn error while
    # keeping the overall bitrate at or below the target
    invalid = False
    min_diff = 0.00001
    while max_rfn - min_rfn > min_diff or invalid:
        target_rfn = (min_rfn + max_rfn) / 2
        invalid = False
        current_total_bits = 0
        for layer in result.values():
            best_option = None
            best_bpw = 10000.0
            for option in layer["options"]:
                if option["bpw"] < best_bpw and option["err"] <= target_rfn:
                    best_bpw = option["bpw"]
                    best_option = option
            layer["best_option_max"] = best_option
            if best_option is None:
                invalid = True
                break
            current_total_bits += int(layer["best_option_max"]["total_bits"])
        current_bpw = current_total_bits / numel
        if not invalid:
            logger.info(f" -- rfn max: {target_rfn:2.5f} bpw: {current_bpw:2.5f}")
        else:
            logger.info(f" -- rfn max: {target_rfn:2.5f} (not possible)")
        if current_bpw <= target_bpw and not invalid:
            best_rfn = min(best_rfn, target_rfn)
            max_rfn = target_rfn
        else:
            min_rfn = target_rfn
            max_rfn += eps

    # We've found the smallest error that can be met by _all_ layers while staying below the
    # set number of bits. Now select a minimum target to allow some layers to use more
    # accurate settings if we didn't meet the target bitrate
    max_rfn = max(target_rfn, best_rfn)
    min_rfn = 0
    min_diff = 0.00001
    while max_rfn - min_rfn > min_diff:
        target_rfn = (min_rfn + max_rfn) / 2
        invalid = False
        current_total_bits = 0
        for layer in result.values():
            best_option = None
            best_rfn = 10000.0
            for option in layer["options"]:
                if best_rfn > option["err"] >= target_rfn and option["err"] < layer["best_option_max"]["err"]:
                    best_rfn = option["err"]
                    best_option = option
            if best_option is None:
                layer["best_option"] = layer["best_option_max"]
            else:
                layer["best_option"] = best_option
            current_total_bits += int(layer["best_option"]["total_bits"])
        current_bpw = current_total_bits / numel
        logger.info(f" -- rfn min: {target_rfn:2.5f} bpw: {current_bpw:2.5f}")
        if current_bpw <= target_bpw:
            max_rfn = target_rfn
        else:
            min_rfn = target_rfn
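
# The `result` mapping consumed above is built elsewhere; the sketch below is an assumption
# inferred from the fields optimize_v1 reads ("numel", "options", "err", "bpw", "total_bits")
# and is shown only as a commented, hypothetical illustration of how it might be called:
#
#     result = {
#         "model.layers.0.mlp": {                           # hypothetical layer key
#             "numel": 1_000_000,                           # number of weights in the layer
#             "options": [                                  # candidate qparams for the layer
#                 {"bpw": 2.2, "total_bits": 2_200_000, "err": 0.09},
#                 {"bpw": 4.1, "total_bits": 4_100_000, "err": 0.02},
#             ],
#         },
#         # ... one entry per quantizable layer
#     }
#     optimize_v1(result, bits=3.0)                         # writes "best_option" into each layer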


def optimize_v2(result: Dict[str, Any], target_bpw: float) -> None:
    """
    Selects one quantization option per layer so that the total number of bits stays within
    the budget implied by target_bpw, maximizing the product of (transformed) per-layer
    accuracies. Writes the chosen option into each layer's "best_option" field.
    """
    error_norm = 2.4
    max_step_size = 2

    numel = 0
    for _, layer in result.items():
        numel += layer["numel"]
    weight_budget = numel * target_bpw

    # Compile options
    def fn(x):
        return 1 - ((1 - x) ** error_norm)

    weights = []
    values = []
    params = []
    for _, layer in result.items():
        v = [fn(option["accuracy"]) for option in layer["options"]]
        w = [option["total_bits"] for option in layer["options"]]
        weights.append(w)
        values.append(v)
        params.append(layer["options"])

    # Sort options by weight, eliminate strictly worse options
    logger.info(" -- Pruning...")
    for i in range(len(weights)):
        combined = sorted(zip(weights[i], values[i], params[i]))
        w_, v_, p_ = zip(*combined)
        w_ = list(w_)
        v_ = list(v_)
        p_ = list(p_)
        j = 1
        while j < len(v_):
            if v_[j] <= v_[j - 1]:
                w_.pop(j)
                v_.pop(j)
                p_.pop(j)
            else:
                j += 1
        weights[i] = w_
        values[i] = v_
        params[i] = p_

    # Quick and dirty iterative solver
    logger.info(" -- Solving...")
    f_solution = [0] * len(weights)
    weight = 0
    value = 1
    for w, v in zip(weights, values):
        weight += w[0]
        value *= v[0]

    while True:
        # Greedily upgrade the lowest-value layer whose next option still fits the budget
        min_idx = -1
        min_value = float("inf")
        for i in range(len(f_solution)):
            s = f_solution[i]
            if values[i][s] < min_value:
                if s < len(weights[i]) - 1:
                    added_w = weights[i][s + 1] - weights[i][s]
                    if added_w + weight <= weight_budget:
                        min_idx = i
                        min_value = values[i][s]
        if min_idx == -1:
            break
        s = f_solution[min_idx]
        weight += weights[min_idx][s + 1] - weights[min_idx][s]
        value *= values[min_idx][s + 1] / values[min_idx][s]
        f_solution[min_idx] += 1

    bpw = weight / numel
    logger.info(f" -- Score: {value:.8f} bpw: {bpw:.4f}")

    def improve(solution, s_weight, hold=None):
        # Find the single upgrade (outside `hold`) with the largest value gain per added bit
        # that still fits within the weight budget
        if hold is None:
            hold = []
        best_idx = -1
        best_ratio = 0
        best_add_w = 0
        best_add_v = 0
        for idx in range(len(solution)):
            if idx in hold:
                continue
            si = solution[idx]
            if si == len(weights[idx]) - 1:
                continue
            add_w = weights[idx][si + 1] - weights[idx][si]
            if s_weight + add_w > weight_budget:
                continue
            add_v = values[idx][si + 1] / values[idx][si]
            ratio = add_v / add_w
            if ratio > best_ratio:
                best_ratio = ratio
                best_idx = idx
                best_add_w = add_w
                best_add_v = add_v
        return best_idx, best_add_w, best_add_v

    # Local search: downgrade pairs of layers by step_size, then greedily redistribute the
    # freed bits among the remaining layers, keeping any combination that improves the score
    best_value = value
    prev_best_value = value
    step_size = 1
    while True:
        for i, j in itertools.permutations(range(len(f_solution)), 2):
            t_solution = f_solution.copy()
            t_solution[i] = max(t_solution[i] - step_size, 0)
            t_solution[j] = max(t_solution[j] - step_size, 0)
            t_weight = 0
            t_value = 1
            for k, idx in enumerate(t_solution):
                t_weight += weights[k][idx]
                t_value *= values[k][idx]
            while True:
                b_idx, b_add_w, b_add_v = improve(t_solution, t_weight, [i, j])
                if b_idx == -1:
                    break
                t_solution[b_idx] += 1
                t_weight += b_add_w
                t_value *= b_add_v
            if t_value > best_value:
                f_solution = t_solution
                best_value = t_value
                break

        if best_value == prev_best_value:
            step_size += 1
            if step_size > max_step_size:
                break
            continue

        bpw = t_weight / numel
        logger.info(f" -- Score: {best_value:.8f} bpw: {bpw:.4f}")
        prev_best_value = best_value

    # Save strategy
    logger.info(" -- Quantization strategy:")
    for i, (name, layer) in enumerate(result.items()):
        param = params[i][f_solution[i]]
        layer["best_option"] = param
        bpw = param["total_bits"] / layer["numel"]
        err = 1 - param["accuracy"]
        logger.info(f" -- {name:50} {bpw:1.4f} bpw - exp. error: {err:1.8f}")
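

if __name__ == "__main__":
    # Minimal smoke test for optimize_v2 with synthetic data. The layer names, sizes and
    # accuracy numbers below are hypothetical and only meant to exercise the solver; real
    # inputs would come from measuring each layer's candidate quantization options.
    import logging
    logging.basicConfig(level=logging.INFO)

    demo_result = {
        f"model.layers.{i}": {
            "numel": 1_000_000,
            "options": [
                {"bpw": 2.2, "total_bits": 2_200_000, "err": 0.10, "accuracy": 0.90},
                {"bpw": 3.1, "total_bits": 3_100_000, "err": 0.05, "accuracy": 0.95},
                {"bpw": 4.1, "total_bits": 4_100_000, "err": 0.02, "accuracy": 0.98},
                {"bpw": 6.1, "total_bits": 6_100_000, "err": 0.01, "accuracy": 0.99},
            ],
        }
        for i in range(4)
    }
    optimize_v2(demo_result, target_bpw=3.5)
    for name, layer in demo_result.items():
        print(f"{name}: {layer['best_option']['bpw']} bpw")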