-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunpool(1).py
39 lines (31 loc) · 1.33 KB
/
unpool(1).py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import tensorflow as tf
import numpy as np
# Example input: a single 6x6, single-channel tensor of ones (TensorFlow's
# default NHWC layout), max-pooled with a 2x2 window and stride 2 so that it
# can be fed to unpool2.
a = tf.ones(dtype=tf.float32, shape=[1, 6, 6, 1])
pool = tf.nn.max_pool(
    a,
    ksize=[1, 2, 2, 1],
    strides=[1, 2, 2, 1],
    padding='VALID',
)
def unpool2(pool, ksize, stride, padding='VALID'):
    """Simple zero-insertion unpooling for NHWC tensors.

    Each input value ``pool[b, i, j, c]`` is written to
    ``out[b, i * stride, j * stride, c]``; every other output position is 0.
    This restores the spatial size of a pooled tensor but does not recover
    the argmax location inside each pooling window.

    :param pool: 4-D tensor laid out as [batch, height, width, channels]
        (e.g. the output of ``tf.nn.max_pool`` in the default NHWC format).
        Height and width may differ; the batch dim may be unknown.
    :param ksize: integer window size of the original pool. Only used to
        size the output when ``padding`` is 'VALID'.
    :param stride: integer stride of the original pool (>= 1).
    :param padding: 'VALID' or 'SAME' (case-insensitive), the padding mode
        of the original pool. The output size per spatial axis is
        VALID -> (in - 1) * stride + ksize, SAME -> in * stride.
    :return: tensor [batch, out_height, out_width, channels] with the same
        dtype as ``pool``.
    :raises ValueError: if ``padding`` is neither 'VALID' nor 'SAME'.
    """
    mode = str(padding).upper()
    if mode not in ('VALID', 'SAME'):
        raise ValueError(
            "padding must be 'VALID' or 'SAME', got %r" % (padding,))
    pool = tf.convert_to_tensor(pool)

    # Prefer static dims so the result keeps a static shape; fall back to
    # the runtime shape for any dim that is unknown (e.g. a None batch).
    static_dims = pool.shape.as_list()
    runtime_dims = tf.shape(pool)
    batch, height, width, channels = [
        dim if dim is not None else runtime_dims[i]
        for i, dim in enumerate(static_dims)
    ]

    # Insert (stride - 1) zeros after every element along height and width:
    # [B, H, W, C] -> [B, H, 1, W, 1, C] -> pad -> [B, H, s, W, s, C]
    # -> reshape -> [B, H*s, W*s, C]. Element (i, j) lands at (i*s, j*s).
    # One vectorized op chain replaces the per-element SparseTensor loop.
    expanded = pool[:, :, tf.newaxis, :, tf.newaxis, :]
    expanded = tf.pad(expanded, [[0, 0], [0, 0], [0, stride - 1],
                                 [0, 0], [0, stride - 1], [0, 0]])
    unpool = tf.reshape(
        expanded, [batch, height * stride, width * stride, channels])

    # SAME: out = in * stride, which is exactly what the reshape produced.
    # VALID: out = (in - 1) * stride + ksize = in * stride + (ksize - stride),
    # so extend or trim the trailing edge by (ksize - stride). Trimming never
    # drops a data row/column because the last one sits at (in - 1) * stride.
    if mode == 'VALID':
        extra = ksize - stride
        if extra > 0:
            unpool = tf.pad(unpool, [[0, 0], [0, extra], [0, extra], [0, 0]])
        elif extra < 0:
            unpool = unpool[:, :extra, :extra, :]
    return unpool
# Example (uses the `pool` tensor defined above; result shape is [1, 6, 6, 1]):
# unpool2(pool, 2, 2, 'VALID')