import json

import torch
import torch.optim
import torch.backends.cudnn

import src.model_creator as mcSrc
class ModelCreator(mcSrc.ModelCreator):
#{{{
    def setup_optimiser(self, params, model):
        # Build the optimiser requested in params; both variants share the
        # same learning-rate / momentum / weight-decay hyperparameters
        if params.optimiser == 'sgd':
            return torch.optim.SGD(model.parameters(), lr=params.lr,
                                   momentum=params.momentum, weight_decay=params.weight_decay)
        elif params.optimiser == 'rmsprop':
            return torch.optim.RMSprop(model.parameters(), lr=params.lr,
                                       momentum=params.momentum, weight_decay=params.weight_decay)
        else:
            # Fail loudly on an unrecognised optimiser instead of returning None
            raise ValueError("Unsupported optimiser '{}'".format(params.optimiser))
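    # NOTE (assumed interface): `params` looks like an argparse-style namespace;
    # the fields used in this class (optimiser, lr, momentum, weight_decay,
    # dataset, arch, depth, pretrained, gpu_list, logs, plus the mode flags
    # checked in load_pretrained) are inferred from the attribute accesses here.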
    def read_model(self, params):
    #{{{
        # Choose the model zoo and classifier head size for the target dataset;
        # `scale` is the input-resolution ratio consumed by the EfficientNet models
        if params.dataset == 'cifar10':
            import src.ar4414.pruning.models.cifar as models
            num_classes = 10
            scale = 224/32
        elif params.dataset == 'cifar100':
            import src.ar4414.pruning.models.cifar as models
            num_classes = 100
            scale = 224/32
        else:
            import src.ar4414.pruning.models.imagenet as models
            num_classes = 1000
            scale = 1

        print("Creating Model %s" % params.arch)
        # `models.__dict__[name]` treats the imported module as a registry of
        # constructors, i.e. models.__dict__['vgg16'] is equivalent to models.vgg16
        if 'resnet' in params.arch:
            if 'cifar' in params.dataset:
                model = models.__dict__[params.arch](num_classes=num_classes, depth=params.depth)
            else:
                model = models.__dict__['resnet{}'.format(params.depth)](pretrained=False, progress=False)
        elif 'efficientnet' in params.arch:
            #TODO: make all the parameters below configurable
            model = models.__dict__[params.arch](num_classes=num_classes, scale=scale)
        else:
            model = models.__dict__[params.arch](num_classes=num_classes)

        return model
    #}}}
    def load_pretrained(self, params, model):
    #{{{
        if params.resume or params.branch or params.entropy or params.pruneFilters or params.binSearch:
            if params.pretrained is not None:
                checkpoint = torch.load(params.pretrained)
                model.load_state_dict(checkpoint)
            else:
                print("WARNING: Attempting to run something which would normally require a pretrained model without one - check model_creator.py")

        elif params.fbsPruning:
            device_id = params.gpu_list[0]
            location = 'cuda:' + str(device_id)
            checkpoint = torch.load(params.pretrained, map_location=location)
            # Strip the DataParallel 'module.' prefix; checkpoints that already
            # carry the FBS 'E_g_x' statistics are loaded without re-initialising them
            if 'E_g_x' in str(checkpoint.keys()):
                checkpoint = {k.replace('module.', ''): v for k, v in checkpoint.items()}
                model.module.load_state_dict(checkpoint, initialise=False)
            else:
                checkpoint = {k.replace('module.', ''): v for k, v in checkpoint.items()}
                model.module.load_state_dict(checkpoint, initialise=True)

        elif params.finetune or params.getGops:
            device_id = params.gpu_list[0]
            location = 'cuda:' + str(device_id)
            checkpoint = torch.load(params.pretrained, map_location=location)
            if params.getGops:
                # Restore any pruning masks stored alongside the weights
                masks = [v for k, v in checkpoint.items() if 'mask' in k]
                if masks:
                    print('Setting pruning masks')
                    model.module.set_masks(masks)
            model.load_state_dict(checkpoint)

        elif params.evaluate or params.unprunedTestAcc:
            device_id = params.gpu_list[0]
            location = 'cuda:' + str(device_id)
            checkpoint = torch.load(params.pretrained, map_location=location)
            model.load_state_dict(checkpoint)

        elif params.noFtChannelsPruned:
            device_id = params.gpu_list[0]
            location = 'cuda:' + str(device_id)
            # Look up the pre-finetuning checkpoint path recorded in the logs file
            logs = params.logs
            with open(logs, 'r') as jFile:
                logs = json.load(jFile)
            preTrainedModel = logs[params.arch]['pre_ft_model']
            checkpoint = torch.load(preTrainedModel, map_location=location)
            model.load_state_dict(checkpoint)

        torch.backends.cudnn.benchmark = True
        print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))
        return model
    #}}}
#}}}
def get_model_size(model):
#{{{
    # Size of the model in MB, assuming 4 bytes (float32) per parameter
    params = 0
    for _, p in model.named_parameters():
        paramsInLayer = 1
        for dim in p.size():
            paramsInLayer *= dim
        params += (paramsInLayer * 4) / 1e6
    return params
#}}}
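# Minimal usage sketch (hypothetical; not part of the original pipeline). The
# SimpleNamespace stands in for the argparse-style `params` object, the field
# values are assumptions chosen so load_pretrained() takes no checkpoint branch,
# and the base class is assumed to take no constructor arguments.
if __name__ == '__main__':
    from types import SimpleNamespace

    params = SimpleNamespace(
        dataset='cifar10', arch='resnet', depth=20,
        optimiser='sgd', lr=0.1, momentum=0.9, weight_decay=1e-4,
        # mode flags read by load_pretrained(); all off for a fresh run
        resume=False, branch=False, entropy=False, pruneFilters=False,
        binSearch=False, fbsPruning=False, finetune=False, getGops=False,
        evaluate=False, unprunedTestAcc=False, noFtChannelsPruned=False,
        pretrained=None, gpu_list=[0], logs=None,
    )

    creator = ModelCreator()
    model = creator.read_model(params)
    model = creator.load_pretrained(params, model)
    optimiser = creator.setup_optimiser(params, model)
    print('Approx. model size: %.2fMB' % get_model_size(model))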