-
Notifications
You must be signed in to change notification settings - Fork 0
/
unet.py
61 lines (48 loc) · 2.58 KB
/
unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from tensorflow.keras.layers import Input, concatenate, Dense, Flatten, BatchNormalization, Dropout, ReLU, Conv2D, Reshape, MaxPooling2D, Conv2DTranspose
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
import tensorflow as tf
from tensorflow.keras.regularizers import l2
# Default network input: (height, width, channels) -- a single-channel 256x256 image.
input_shape = (256, 256, 1)

# Number of keypoints; equals the channel count of the network's final output layer.
Nkeypoints = 15
def _downsample_block(x, block_num, n_filters, pooling_on=True):
    """Apply two 3x3 ReLU convolutions, optionally followed by 2x2 max pooling.

    Args:
        x: input tensor.
        block_num: integer used to build unique layer names ("Block<n>_...").
        n_filters: number of filters in both convolutions.
        pooling_on: when truthy, halve the spatial size with max pooling.

    Returns:
        (x, skip): ``x`` is the (possibly pooled) tensor fed to the next block;
        ``skip`` is the pre-pooling activation used as the U-Net skip connection.
    """
    prefix = f"Block{block_num}"
    x = Conv2D(n_filters, kernel_size=(3, 3), strides=1, padding="same",
               activation="relu", name=f"{prefix}_Conv1")(x)
    x = Conv2D(n_filters, kernel_size=(3, 3), strides=1, padding="same",
               activation="relu", name=f"{prefix}_Conv2")(x)
    skip = x
    if pooling_on:
        x = MaxPooling2D(pool_size=(2, 2), strides=2, padding="valid",
                         name=f"{prefix}_Pool1")(x)
    return x, skip


def _upsample_block(x, skip, block_num, n_filters):
    """Double the spatial size, concatenate the skip tensor, then apply two 3x3 ReLU convolutions.

    Args:
        x: tensor from the previous (coarser) block.
        skip: encoder activation with exactly twice the spatial size of ``x``.
        block_num: integer used to build unique layer names ("Block<n>_...").
        n_filters: filters for the transposed convolution and both convolutions.

    Returns:
        The decoded tensor at the spatial size of ``skip``.
    """
    prefix = f"Block{block_num}"
    x = Conv2DTranspose(n_filters, kernel_size=(2, 2), strides=2, padding="valid",
                        activation="relu", name=f"{prefix}_ConvT1")(x)
    # Channels-last feature maps: join along the channel axis.
    x = concatenate([x, skip], axis=-1, name=f"{prefix}_Concat1")
    x = Conv2D(n_filters, kernel_size=(3, 3), strides=1, padding="same",
               activation="relu", name=f"{prefix}_Conv1")(x)
    x = Conv2D(n_filters, kernel_size=(3, 3), strides=1, padding="same",
               activation="relu", name=f"{prefix}_Conv2")(x)
    return x


def model(input_shape, n_keypoints=None, *, show_summary=True):
    """Build a 4-level U-Net that outputs one linear channel per keypoint.

    Args:
        input_shape: ``(H, W, C)`` shape of a single input, without the batch axis.
            Known H and W must be multiples of 16: the encoder pools four times
            (factor 2 each) and every decoder stage concatenates its upsampled map
            with the matching encoder skip, so the sizes must round-trip exactly.
        n_keypoints: channels in the final 1x1 convolution. ``None`` (default)
            uses the module-level ``Nkeypoints``.
        show_summary: print ``Model.summary()`` after building (original behaviour).

    Returns:
        An uncompiled Keras ``Model`` mapping ``(batch, H, W, C)`` to
        ``(batch, H, W, n_keypoints)``.

    Raises:
        ValueError: if a known spatial dimension is not a multiple of 16.
    """
    if n_keypoints is None:
        n_keypoints = Nkeypoints

    # Four stride-2 poolings -> spatial dims must be divisible by 2**4.
    n_pooling_stages = 4
    divisor = 2 ** n_pooling_stages
    height, width = input_shape[0], input_shape[1]
    for axis_name, dim in (("height", height), ("width", width)):
        if dim is not None and dim % divisor != 0:
            raise ValueError(
                f"input {axis_name} must be a multiple of {divisor}, got {dim}"
            )

    inputs = Input(input_shape, name="Input")

    # Encoder (contracting path).
    x, skip1 = _downsample_block(inputs, 1, 64)
    x, skip2 = _downsample_block(x, 2, 128)
    x, skip3 = _downsample_block(x, 3, 256)
    x, skip4 = _downsample_block(x, 4, 512)
    # Bottleneck: no pooling, its skip output is unused.
    x, _ = _downsample_block(x, 5, 1024, pooling_on=False)

    # Decoder (expanding path).
    x = _upsample_block(x, skip4, 6, 512)
    x = _upsample_block(x, skip3, 7, 256)
    x = _upsample_block(x, skip2, 8, 128)
    x = _upsample_block(x, skip1, 9, 64)

    outputs = Conv2D(n_keypoints, kernel_size=(1, 1), strides=1, padding="valid",
                     activation="linear", name="output")(x)

    # Layer and model names are kept exactly as before so existing checkpoints still map.
    net = Model(inputs=inputs, outputs=outputs, name="Output")
    if show_summary:
        net.summary()
    return net
if __name__ == "__main__":
    # Script entry point: build the network once with the module-level default shape.
    _net = model(input_shape)