# op_util.py — forked from sseung0703/KD_methods_with_TF
# (GitHub page-scrape residue removed: site navigation text and the
#  copied line-number gutter carried no file content.)
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
def Optimizer_w_Distillation(class_loss, LR, epoch, init_epoch, global_step, Distillation):
    """Build the training operator, optionally combined with a distillation loss.

    Args:
        class_loss:   scalar tensor, the main-task (classification) loss.
        LR:           learning rate (python float or scalar tensor).
        epoch:        scalar tensor, current epoch; used by methods that switch
                      between an initialization phase and fine-tuning.
        init_epoch:   number of initialization epochs for two-phase methods.
        global_step:  step counter incremented by apply_gradients.
        Distillation: one of None, 'Soft_logits', 'AT', 'RKD', 'FitNet', 'FSP',
                      'AB', 'KD-SVD'.

    Returns:
        train_op: tensor that, when evaluated, runs one optimization step
        (plus any collected UPDATE_OPS) and yields total_loss.

    Raises:
        ValueError: if Distillation is not one of the recognized method names.
    """
    with tf.variable_scope('Optimizer_w_Distillation'):
        # Train only the student: drop variables registered under 'Teacher'.
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        teacher_variables = tf.get_collection('Teacher')
        variables = list(set(variables) - set(teacher_variables))

        # Nesterov momentum; LR scheduling is handled by the caller via `LR`.
        optimize = tf.train.MomentumOptimizer(LR, 0.9, use_nesterov=True)

        if Distillation is None:
            # Plain training: classification loss + weight decay.
            total_loss = class_loss + tf.add_n(tf.losses.get_regularization_losses())
            tf.summary.scalar('loss/total_loss', total_loss)
            gradients = optimize.compute_gradients(total_loss, var_list=variables)

        elif Distillation == 'Soft_logits':
            # Hinton-style soft-logit KD: fixed multi-task weighting (alpha = 0.3).
            total_loss = (class_loss * 0.7
                          + tf.add_n(tf.losses.get_regularization_losses())
                          + tf.get_collection('dist')[0] * 0.3)
            tf.summary.scalar('loss/total_loss', total_loss)
            gradients = optimize.compute_gradients(total_loss, var_list=variables)

        elif Distillation in ('AT', 'RKD'):
            # Simple multi-task learning: distillation loss added with weight 1.
            total_loss = (class_loss
                          + tf.add_n(tf.losses.get_regularization_losses())
                          + tf.get_collection('dist')[0])
            tf.summary.scalar('loss/total_loss', total_loss)
            gradients = optimize.compute_gradients(total_loss, var_list=variables)

        elif Distillation in ('FitNet', 'FSP'):
            # Two-phase training: distillation-only initialization, then fine-tuning.
            # During initialization, weight decay must be off for variables that are
            # not trained by distillation, so gradients are combined per-variable.
            reg_loss = tf.add_n(tf.losses.get_regularization_losses())
            distillation_loss = tf.get_collection('dist')[0]
            cond = epoch < init_epoch
            total_loss = tf.cond(cond,
                                 lambda: distillation_loss + reg_loss,
                                 lambda: class_loss + reg_loss)
            tf.summary.scalar('loss/total_loss', total_loss)

            gradients = optimize.compute_gradients(class_loss, var_list=variables)
            gradient_wdecay = optimize.compute_gradients(reg_loss, var_list=variables)
            gradient_dist = optimize.compute_gradients(distillation_loss, var_list=variables)
            with tf.variable_scope('clip_grad'):
                for i, (gc, gw, gd) in enumerate(zip(gradients, gradient_wdecay, gradient_dist)):
                    gw = 0. if gw[0] is None else gw[0]
                    if gd[0] is not None:
                        # Init phase: weight decay + distillation; fine-tune: decay + task.
                        gradients[i] = (tf.cond(cond,
                                                lambda: gw + gd[0],
                                                lambda: gw + gc[0]), gc[1])
                    else:
                        # Variable untouched by distillation: frozen during init.
                        gradients[i] = (tf.cond(cond,
                                                lambda: tf.zeros_like(gc[0]),
                                                lambda: gw + gc[0]), gc[1])

        elif Distillation == 'AB':
            # Activation Boundaries: distillation-only initialization, then
            # fine-tuning on 0.7 * task loss + 0.3 * logits-KD loss.
            reg_loss = tf.add_n(tf.losses.get_regularization_losses())
            distillation_loss = tf.get_collection('dist')[0]
            KD_loss = tf.get_collection('Logits_KD')[0]
            cond = epoch < init_epoch
            total_loss = tf.cond(cond,
                                 lambda: distillation_loss + reg_loss,
                                 lambda: 0.7 * class_loss + 0.3 * KD_loss + reg_loss)
            tf.summary.scalar('loss/total_loss', total_loss)

            gradients = optimize.compute_gradients(0.7 * class_loss + 0.3 * KD_loss,
                                                   var_list=variables)
            gradient_wdecay = optimize.compute_gradients(reg_loss, var_list=variables)
            gradient_dist = optimize.compute_gradients(distillation_loss, var_list=variables)
            with tf.variable_scope('clip_grad'):
                for i, (gc, gw, gd) in enumerate(zip(gradients, gradient_wdecay, gradient_dist)):
                    gw = 0. if gw[0] is None else gw[0]
                    if gd[0] is None:
                        # No distillation gradient: frozen during init.
                        gradients[i] = (tf.cond(cond,
                                                lambda: tf.zeros_like(gc[0]),
                                                lambda: gw + gc[0]), gc[1])
                    elif gc[0] is None:
                        # No task gradient: trained only during init.
                        gradients[i] = (tf.cond(cond,
                                                lambda: gw + gd[0],
                                                lambda: tf.zeros_like(gd[0])), gd[1])
                    else:
                        gradients[i] = (tf.cond(cond,
                                                lambda: gw + gd[0],
                                                lambda: gw + gc[0]), gc[1])

        elif Distillation == 'KD-SVD':
            # Multi-task learning with distillation-gradient clipping: each
            # distillation gradient is clipped by the norm of the corresponding
            # main-task gradient, scaled by a sigmoid schedule over epochs.
            def sigmoid(x, k, d=1):
                # Smooth gate in (0, 1); snaps to exactly 1.0 once saturated
                # to avoid needless clipping late in training.
                s = 1 / (1 + tf.exp(-(x - k) / d))
                return tf.cond(tf.greater(s, 1 - 1e-8), lambda: 1.0, lambda: s)

            reg_loss = tf.add_n(tf.losses.get_regularization_losses())
            distillation_loss = tf.get_collection('dist')[0]
            total_loss = class_loss + reg_loss + distillation_loss
            tf.summary.scalar('loss/total_loss', total_loss)
            tf.summary.scalar('loss/distillation_loss', distillation_loss)

            gradients = optimize.compute_gradients(class_loss, var_list=variables)
            gradient_wdecay = optimize.compute_gradients(reg_loss, var_list=variables)
            gradient_dist = optimize.compute_gradients(distillation_loss, var_list=variables)
            with tf.variable_scope('clip_grad'):
                for i, (gc, gw, gd) in enumerate(zip(gradients, gradient_wdecay, gradient_dist)):
                    gw = 0. if gw[0] is None else gw[0]
                    if gd[0] is not None:
                        norm = tf.sqrt(tf.reduce_sum(tf.square(gc[0]))) * sigmoid(epoch, 0)
                        gradients[i] = (gc[0] + gw + tf.clip_by_norm(gd[0], norm), gc[1])
                    else:
                        gradients[i] = (gc[0] + gw, gc[1])

        else:
            # Previously an unknown method fell through silently and crashed
            # later with NameError on total_loss; fail fast instead.
            raise ValueError('Unknown Distillation method: %s' % Distillation)

        # Merge batch-norm (and other) update ops with the gradient application,
        # and make the train op yield total_loss when run.
        update_ops.append(optimize.apply_gradients(gradients, global_step=global_step))
        update_op = tf.group(*update_ops)
        train_op = control_flow_ops.with_dependencies([update_op], total_loss, name='train_op')
    return train_op