-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgrad.py
76 lines (66 loc) · 2.71 KB
/
grad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#start#
import os, glob, sys, time, torch
from torch.optim import lr_scheduler
# from torch.cuda import amp
torch.set_printoptions(precision=3)
class GradUtil(object):
    """Bundle a loss criterion, an optimizer and an LR scheduler for training.

    The instance accumulates the primary loss of every ``backward_seg`` call in
    ``total_loss``; ``update_scheduler`` logs that sum, feeds it to the
    scheduler and resets it.

    NOTE(review): ``get_loss``, ``RAdamW`` and ``ReduceLR`` are not defined in
    this part of the file; they are assumed to be provided elsewhere in the
    module.
    """

    # Weight applied to the loss of every output after the first one when the
    # model returns a list/tuple of outputs (see ``calcGradient``).
    coff_ds = 0.5

    # Running sum of the primary loss since the last ``update_scheduler`` call.
    total_loss = 0

    def __init__(self, model, loss='ce', lr=0.01, wd=2e-4, root='.'):
        """Create the criterion, optimizer and scheduler for ``model``.

        Args:
            model: object exposing ``parameters()``; only parameters whose
                ``requires_grad`` is true are handed to the optimizer.
            loss: loss name, passed to ``get_loss`` and used as the log label.
            lr: initial learning rate.
            wd: weight decay passed to the optimizer.
            root: directory for the checkpoint path; created if missing.
        """
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(root, exist_ok=True)
        self.path_checkpoint = os.path.join(root, 'super_params.tar')

        self.lossName = loss
        self.criterion = get_loss(loss)

        params = filter(lambda p: p.requires_grad, model.parameters())
        # Honour the caller's ``wd``; it used to be ignored in favour of a
        # hard-coded 2e-4 (which is still the default value).
        self.optimizer = RAdamW(params=params, lr=lr, weight_decay=wd)
        self.scheduler = ReduceLR(name=loss, optimizer=self.optimizer,
                                  mode='min', factor=0.7, patience=2,
                                  verbose=True, threshold=0.0001, threshold_mode='rel',
                                  cooldown=2, min_lr=1e-5, eps=1e-9)

    def isLrLowest(self, thresh=1e-5):
        """Return True when the first param group's learning rate is below ``thresh``."""
        return self.optimizer.param_groups[0]['lr'] < thresh

    def calcGradient(self, criterion, outs, true, fov=None):
        """Compute the (optionally multi-output) loss of ``outs`` against ``true``.

        When ``outs`` is a list or tuple, element 0 is the primary output and
        gets full weight; every later element contributes its loss scaled by
        ``coff_ds``. Any other ``outs`` is passed to ``criterion`` directly.

        Args:
            criterion: callable ``criterion(output, target)`` returning a loss.
            outs: a single output or a list/tuple of outputs (the original
                author's note says the first element is the largest one).
            true: target passed unchanged to every ``criterion`` call.
            fov: accepted for interface compatibility; currently unused.

        Returns:
            The summed loss, of whatever type ``criterion`` returns.
        """
        lossSum = 0
        if isinstance(outs, (list, tuple)):
            # Walk the secondary outputs from last to second, as before, so the
            # floating-point accumulation order is unchanged.
            for aux in reversed(outs[1:]):
                lossSum = lossSum + criterion(aux, true) * self.coff_ds
            outs = outs[0]
        return lossSum + criterion(outs, true)

    def backward_seg(self, pred, true, fov=None, model=None, requires_grad=True, losInit=None):
        """Run one optimisation step on the loss of ``pred`` against ``true``.

        Args:
            pred: model output(s), forwarded to ``calcGradient``.
            true: target, forwarded to ``calcGradient``.
            fov: forwarded to ``calcGradient`` (unused there).
            model: object exposing ``parameters()`` whose gradients are clipped.
                When omitted, the parameters held by the optimizer are clipped
                instead (previously this case raised ``AttributeError``).
            requires_grad: when false, the loss is only computed and recorded;
                no backward pass, clipping or optimizer step happens.
            losInit: optional list of extra loss terms (each exposing
                ``item()``) added to the primary loss before the backward pass.

        Returns:
            ``(total, text)`` where ``total`` is the summed loss as a Python
            number and ``text`` is every individual term formatted with four
            decimals and joined by commas.
        """
        self.optimizer.zero_grad()

        los = self.calcGradient(self.criterion, pred, true, fov)
        # Only the primary loss feeds the running total used by the scheduler.
        self.total_loss += los.item()
        costList = [los]
        del pred, true, los  # drop local references as early as possible

        if isinstance(losInit, list) and losInit:
            costList.extend(losInit)

        losSum = sum(costList)
        losStr = ','.join('{:.4f}'.format(cost.item()) for cost in costList)

        if requires_grad:
            losSum.backward()
            if model is not None:
                clip_params = model.parameters()
            else:
                # Fall back to the parameters the optimizer is about to update.
                clip_params = [p
                               for group in self.optimizer.param_groups
                               for p in group['params']]
            # Gradient clipping: cap the total L2 norm at 20.
            torch.nn.utils.clip_grad_norm_(clip_params, max_norm=20, norm_type=2)
            self.optimizer.step()
        return losSum.item(), losStr

    def update_scheduler(self, i=0):
        """Log the accumulated loss, step the scheduler and reset the total.

        Args:
            i: index printed at the start of the log line (zero-padded to 3).
        """
        logStr = '\r{:03}# {}={:.4f},'.format(i, self.lossName, self.total_loss)
        print(logStr, end='')
        if isinstance(self.scheduler, ReduceLR):
            # ReduceLR is stepped with the monitored value.
            self.scheduler.step(self.total_loss)
        else:
            self.scheduler.step()
        self.total_loss = 0
#end#