-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlayers.py
132 lines (120 loc) · 4.16 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import tensorflow as tf
from tensorflow.keras import layers
class DenseBlock(layers.Layer):
    """Dense block used in the generator (ESRGAN-style residual dense block).

    Each of the first four convolutions sees the concatenation of the
    block input and every previous convolution's activation, so the
    channel count grows by `dense_growth_channels` per step. The fifth
    convolution projects back to `input_channels`, the result is scaled
    by `scale`, and added to the block input as a residual.

    Arguments:
        input_channels: number of channels of the block input (and output).
        dense_growth_channels: how many channels are added after each
            conv + leaky_relu + concatenate step.
        scale: residual scaling factor applied before the skip addition.
    """
    def __init__(self, input_channels, dense_growth_channels, scale):
        super(DenseBlock, self).__init__()
        self.input_channels = input_channels
        self.dense_growth_channels = dense_growth_channels
        self.scale = scale
        self.conv_1 = tf.keras.layers.Conv2D(
            filters = self.dense_growth_channels,
            kernel_size = 3,
            strides = 1,
            padding = "same"
        )
        self.concatenate_1 = tf.keras.layers.Concatenate()
        self.conv_2 = tf.keras.layers.Conv2D(
            filters = self.dense_growth_channels,
            kernel_size = 3,
            strides = 1,
            padding = "same"
        )
        self.concatenate_2 = tf.keras.layers.Concatenate()
        self.conv_3 = tf.keras.layers.Conv2D(
            filters = self.dense_growth_channels,
            kernel_size = 3,
            strides = 1,
            padding = "same"
        )
        self.concatenate_3 = tf.keras.layers.Concatenate()
        self.conv_4 = tf.keras.layers.Conv2D(
            filters = self.dense_growth_channels,
            kernel_size = 3,
            strides = 1,
            padding = "same"
        )
        self.concatenate_4 = tf.keras.layers.Concatenate()
        # final conv maps back to the block's input channel count so the
        # residual addition below is shape-compatible
        self.conv_5 = tf.keras.layers.Conv2D(
            filters = self.input_channels,
            kernel_size = 3,
            strides = 1,
            padding = "same"
        )
        self.residual_scale = tf.keras.layers.Lambda(lambda x: x * scale)
        self.add = tf.keras.layers.Add()
    def call(self, inputs):
        x_0 = inputs
        x = tf.nn.leaky_relu(self.conv_1(x_0))
        # Bug fix: x_k already contains x_0 and all earlier activations,
        # so only the newest activation is appended at each step. The
        # original code concatenated x_0 (and earlier x_k) again, which
        # duplicated feature maps and roughly doubled the channel count
        # per step instead of growing it by dense_growth_channels.
        x_1 = self.concatenate_1([x_0, x])
        x = tf.nn.leaky_relu(self.conv_2(x_1))
        x_2 = self.concatenate_2([x_1, x])
        x = tf.nn.leaky_relu(self.conv_3(x_2))
        x_3 = self.concatenate_3([x_2, x])
        x = tf.nn.leaky_relu(self.conv_4(x_3))
        x_4 = self.concatenate_4([x_3, x])
        x = self.conv_5(x_4)
        x = self.residual_scale(x)
        x = self.add([x, x_0])
        return x
class ConvBlock(tf.keras.layers.Layer):
    """Convolution block used in the discriminator.

    Conv2D (3x3, "same" padding) -> leaky ReLU -> optional batch norm.

    Arguments:
        filters: number of filters used in the convolution.
        strides: convolution strides.
        bn: bool, whether to apply batch normalization after the activation.
    """
    def __init__(self, filters, strides = 1, bn = True):
        super(ConvBlock, self).__init__()
        # Bug fix: the original hard-coded `self.has_bn = True`, so the
        # `bn` constructor argument was silently ignored and batch norm
        # always ran.
        self.has_bn = bn
        self.conv = tf.keras.layers.Conv2D(
            filters = filters,
            kernel_size = 3,
            strides = strides,
            padding = "same"
        )
        # created unconditionally (as in the original) so the layer
        # structure is stable; only applied when has_bn is set
        self.bn = tf.keras.layers.BatchNormalization(momentum = 0.8)
    def call(self, inputs):
        x = self.conv(inputs)
        x = tf.nn.leaky_relu(x)
        if self.has_bn:
            x = self.bn(x)
        return x
class ResidualBlock(tf.keras.layers.Layer):
    """Residual block from DSen2, matching the paper exactly.

    conv -> ReLU -> conv, scaled by 0.1, then added back to the block
    input through an identity skip connection.

    Arguments:
        filters: number of filters for both 3x3 convolutions.
    """
    def __init__(self, filters = 128, **kwargs):
        super(ResidualBlock, self).__init__(**kwargs)
        conv_kwargs = dict(kernel_size = 3, strides = 1, padding = "same")
        self.conv1 = layers.Conv2D(filters = filters, **conv_kwargs)
        self.conv2 = layers.Conv2D(filters = filters, **conv_kwargs)
        # 0.1 is the residual scaling used in the paper
        self.residual_scale = layers.Lambda(lambda t: t * 0.1)
        self.add = layers.Add()
    def call(self, inputs):
        residual = self.conv2(tf.nn.relu(self.conv1(inputs)))
        return self.add([inputs, self.residual_scale(residual)])