-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathscan_locker.py
101 lines (83 loc) · 2.81 KB
/
scan_locker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
'''
Date: 2024-10-01 13:49:57
LastEditors: Jiaqi Gu && [email protected]
LastEditTime: 2024-10-01 13:52:33
FilePath: /ONN_Reliable/scan_locker.py
'''
import argparse
import os
import pickle
from copy import deepcopy
import torch
from pyutils.config import configs
from pyutils.general import ensure_dir
from pyutils.torch_train import load_model, set_torch_deterministic
from core.builder import make_criterion, make_dataloader, make_model
from core.models.attack_defense.post_locker import smart_locker
from core.models.layers.gemm_conv2d import GemmConv2d
from core.models.layers.gemm_linear import GemmLinear
def reset_model(model):
    """Load the configured checkpoint into ``model``.

    The checkpoint path comes from ``configs.checkpoint.restore_checkpoint``.
    ``configs.checkpoint.no_linear`` is cast to ``int`` and forwarded to
    ``load_model`` as ``ignore_size_mismatch``. Presumably this lets a
    checkpoint whose final linear layer has a different shape still load;
    confirm against ``pyutils.torch_train.load_model``.

    Returns ``None``. Whatever ``load_model`` returns is discarded, so the
    model is presumably updated in place.
    """
    ckpt_cfg = configs.checkpoint
    ckpt_path = ckpt_cfg.restore_checkpoint
    skip_size_mismatch = int(ckpt_cfg.no_linear)
    load_model(model, ckpt_path, ignore_size_mismatch=skip_size_mismatch)
def generate_statistics(model, criterion, device, validation_loader, eta: float):
    """Run ``smart_locker`` on ``model`` and return its locking statistics.

    A ``smart_locker`` is built with ``cluster_method="normal"``, and
    ``smart_locking`` is run over ``validation_loader`` with the given
    ``eta``. Afterwards ``model.calculate_signature`` is called with the
    returned group size, and ``locker.calculate_mem_ov()`` is called. Both
    calls are made only for their side effects; their return values are not
    used. ``calculate_signature`` presumably stores signature data on
    ``model``, which should be confirmed.

    Returns:
        The ``(L_K, W_K, G_size)`` triple produced by ``smart_locking``.
    """
    locker_kwargs = dict(
        model=model,
        criterion=criterion,
        cluster_method="normal",
        device=device,
    )
    locker = smart_locker(**locker_kwargs)

    lk, wk, g_size = locker.smart_locking(eta=eta, val_loader=validation_loader)

    # Side-effect-only calls; results intentionally discarded.
    model.calculate_signature(G_size=g_size)
    locker.calculate_mem_ov()

    return lk, wk, g_size
def _resolve_device(model, device=None):
    """Return ``device`` if given, else the device of ``model``'s first parameter (CPU if none)."""
    if device is not None:
        return device
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def scan_locker(model, validation_loader, criterion, eta: float = 0.0, device=None):
    """Run sensitivity-aware locking on a copy of ``model`` and pickle the results.

    ``model`` is deep-copied first, so the locking pass (including
    ``calculate_signature``) runs on the copy and the caller's model is left
    alone. The three values returned by ``generate_statistics`` are each
    pickled to their own file under
    ``./EXP_data/Locker/<configs.model.name>/sens-aware``::

        <N_bits>_bit_NoO_grad_LK_<eta>.pkl
        <N_bits>_bit_NoO_grad_WK_<eta>.pkl
        <N_bits>_bit_NoO_grad_G_<eta>.pkl

    Args:
        model: Model to analyse. It is not modified.
        validation_loader: Loader forwarded to ``smart_locking`` as ``val_loader``.
        criterion: Loss forwarded to ``smart_locker``.
        eta: Forwarded to ``smart_locking``. Also used to derive the RNG seed
            and to name the output files.
        device: Device forwarded to ``smart_locker``. If ``None``, the
            device of ``model``'s parameters is used. Before this parameter
            existed, the function relied on a module-level ``device`` that
            is only defined when this file runs as a script.
    """
    device = _resolve_device(model, device)
    model_copy = deepcopy(model)

    # NOTE(review): ``eta * 100`` makes the seed a float (e.g. 0.29 * 100 ->
    # 28.999...). It is kept as-is so existing runs stay reproducible;
    # confirm how set_torch_deterministic coerces it.
    set_torch_deterministic(configs.noise.random_state + eta * 100)

    L_K, W_K, G_size = generate_statistics(
        model=model_copy,
        validation_loader=validation_loader,
        criterion=criterion,
        eta=eta,
        device=device,
    )

    folder = f"./EXP_data/Locker/{configs.model.name}/sens-aware"  # + Customized Locking
    ensure_dir(folder)

    prefix = f"{configs.quantize.N_bits}_bit_NoO_grad"
    for tag, payload in (("LK", L_K), ("WK", W_K), ("G", G_size)):
        path = os.path.join(folder, f"{prefix}_{tag}_{eta}.pkl")
        # ``with`` guarantees each file is flushed and closed; the original
        # leaked all three handles.
        with open(path, "wb") as f_save:
            pickle.dump(payload, f_save)
if __name__ == "__main__":
    # Entry point: load config, build data/model, restore the checkpoint,
    # switch quantizers, then run the locker scan for configs.defense.eta.
    parser = argparse.ArgumentParser()
    parser.add_argument("config", metavar="FILE", help="config file")
    args, opts = parser.parse_known_args()

    configs.load(args.config, recursive=True)
    # Apply extra command-line overrides captured by parse_known_args;
    # previously ``opts`` was parsed and then silently ignored.
    configs.update(opts)

    # Kept at module scope on purpose: code in this file may read ``device``
    # as a global. Falls back to CPU instead of crashing without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    _, validation_loader = make_dataloader()
    criterion = make_criterion().to(device)
    model = make_model(device=device)
    reset_model(model)

    for module in model.modules():
        if isinstance(module, (GemmConv2d, GemmLinear)):
            # Presumably switches the weight quantizer to a two's-complement
            # representation (method name ``to_two_com``); confirm.
            module.weight_quantizer.to_two_com()

    scan_locker(
        model=model,
        validation_loader=validation_loader,
        eta=configs.defense.eta,
        criterion=criterion,
    )