from keras.models import Model
from keras.layers import Dense, Conv1D, Input, Reshape, Permute, Add, Flatten, BatchNormalization, Activation
from keras.regularizers import l2

# Make a residual tower of convolutional blocks.
def make_model(input_size, num_filters=32, num_outputs=1, d1=64, d2=64, word_size=16, ks=3, depth=1, reg_param=10**-5, final_activation='sigmoid'):
    # Input and preprocessing layers: reshape the flat input vector into
    # words, then transpose so that each bit position becomes a channel.
    inp = Input(shape=(input_size,))
    rs = Reshape((input_size // word_size, word_size))(inp)
    perm = Permute((2, 1))(rs)
    # Add a single 1x1 convolutional layer that expands the data to
    # num_filters channels; this is a bit-sliced layer.
    conv0 = Conv1D(num_filters, kernel_size=1, padding='same', kernel_regularizer=l2(reg_param))(perm)
    conv0 = BatchNormalization()(conv0)
    conv0 = Activation('relu')(conv0)
    # Add residual blocks: two conv-BN-relu stages per block, with a
    # skip connection added back onto the block input.
    shortcut = conv0
    for i in range(depth):
        conv1 = Conv1D(num_filters, kernel_size=ks, padding='same', kernel_regularizer=l2(reg_param))(shortcut)
        conv1 = BatchNormalization()(conv1)
        conv1 = Activation('relu')(conv1)
        conv2 = Conv1D(num_filters, kernel_size=ks, padding='same', kernel_regularizer=l2(reg_param))(conv1)
        conv2 = BatchNormalization()(conv2)
        conv2 = Activation('relu')(conv2)
        shortcut = Add()([shortcut, conv2])
    # Add prediction head: two densely connected hidden layers, then the output unit.
    flat1 = Flatten()(shortcut)
    dense1 = Dense(d1, kernel_regularizer=l2(reg_param))(flat1)
    dense1 = BatchNormalization()(dense1)
    dense1 = Activation('relu')(dense1)
    dense2 = Dense(d2, kernel_regularizer=l2(reg_param))(dense1)
    dense2 = BatchNormalization()(dense2)
    dense2 = Activation('relu')(dense2)
    out = Dense(num_outputs, activation=final_activation, kernel_regularizer=l2(reg_param))(dense2)
    model = Model(inputs=inp, outputs=out)
    return model
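
# A minimal usage sketch. The parameter choices below (a 64-bit input,
# i.e. a pair of 32-bit ciphertext words, and the adam/mse compile
# settings) are illustrative assumptions, not necessarily this
# repository's training configuration.
if __name__ == '__main__':
    net = make_model(64, depth=1)
    net.compile(optimizer='adam', loss='mse', metrics=['acc'])
    net.summary()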