-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathmain_build-dict.py
69 lines (51 loc) · 2.26 KB
/
main_build-dict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
from parser import cfg
import PatchAttack.TextureDict_builder as TD_builder
from PatchAttack.PatchAttack_config import configure_PA
import PatchAttack.AdvPatchDict_builder as AP_builder
from PatchAttack import utils
import torchvision.models as Models
# Device argument read from cfg; main() passes it to `.cuda()` when it loads
# the torchvision model for the 'AdvPatch' dictionary.
torch_cuda = cfg.torch_cuda
def _build_imagenet_data_agent():
    """Create the ImageNet data agent used by both dictionary builders."""
    return utils.data_agent(
        ImageNet_train_dir=cfg.ImageNet_train_dir,
        ImageNet_val_dir=cfg.ImageNet_val_dir,
        data_name='ImageNet',
        train_transform=utils.data_agent.process_PIL,
    )


def main():
    """Build the dictionary selected by ``cfg.dict``.

    ``configure_PA`` is always called first with ``cfg.tdict_dir`` and
    ``cfg.t_labels``. Then:

    * ``cfg.dict == 'Texture'``: calls ``TD_builder.build(DA, cfg.t_labels)``.
    * ``cfg.dict == 'AdvPatch'``: loads the torchvision model named
      ``cfg.arch.lower() + str(cfg.depth)`` with ``pretrained=True``, moves it
      to the CUDA device ``torch_cuda``, switches it to eval mode and calls
      ``AP_builder.build(model, cfg.t_labels, DA)``.

    In both cases ``DA`` is the ImageNet data agent when
    ``cfg.t_data == 'ImageNet'``.

    Raises:
        AssertionError: if ``cfg.t_data == 'custom'``; the custom data agent
            has to be wired in by hand (see the commented requirements in
            the corresponding branch).
        ValueError: if ``cfg.dict`` or ``cfg.t_data`` holds an unsupported
            value.
    """
    configure_PA(cfg.tdict_dir, cfg.t_labels)

    if cfg.dict == 'Texture':
        if cfg.t_data == 'ImageNet':
            # ImageNet data agent for texture generation
            DA = _build_imagenet_data_agent()
            TD_builder.build(DA, cfg.t_labels)
        elif cfg.t_data == 'custom':
            # custom dataset requirement:
            # attribute -- targets: list consisting of the labels (int)
            # methods -- __getitem__(): return image (torch.Tensor), label (int)
            #custom_dataset = ...
            #DA = TD_builder.custom_data_agent(custom_dataset)
            #TD_builder.build(DA, cfg.t_labels)
            # Raised explicitly (not `assert False`) so the check survives
            # `python -O`, which strips assert statements.
            raise AssertionError(
                'Please see the commented requirements to build custom data agent'
            )
        else:
            raise ValueError(
                'Unsupported cfg.t_data: {!r}'.format(cfg.t_data)
            )

    elif cfg.dict == 'AdvPatch':
        # model
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=` -- confirm the targeted version before changing.
        model_factory = getattr(Models, cfg.arch.lower() + str(cfg.depth))
        model = model_factory(pretrained=True).cuda(torch_cuda).eval()

        if cfg.t_data == 'ImageNet':
            # ImageNet data agent for AdvPatch generation
            DA = _build_imagenet_data_agent()
            AP_builder.build(model, cfg.t_labels, DA)
        elif cfg.t_data == 'custom':
            # custom dataset requirement:
            # methods -- __getitem__(): return image (torch.Tensor), label (int)
            #custom_dataset = ...
            #DA = AP_builder.custom_data_agent(custom_dataset)
            #AP_builder.build(model, cfg.t_labels, DA)
            # Raised explicitly (not `assert False`) so the check survives
            # `python -O`, which strips assert statements.
            raise AssertionError(
                'Please see the commented requirements to build custom data agent'
            )
        else:
            raise ValueError(
                'Unsupported cfg.t_data: {!r}'.format(cfg.t_data)
            )

    else:
        # Previously an unknown dictionary type fell through silently.
        raise ValueError('Unsupported cfg.dict: {!r}'.format(cfg.dict))
# Script entry point: build the dictionary only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()