Skip to content

Commit

Permalink
Remove layernorm and peephole connections
Browse files Browse the repository at this point in the history
  • Loading branch information
carlthome authored Jan 4, 2022
1 parent 77e0e88 commit 1547958
Showing 1 changed file with 6 additions and 34 deletions.
40 changes: 6 additions & 34 deletions cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ class ConvLSTMCell(tf.nn.rnn_cell.RNNCell):
Xingjian, S. H. I., et al. "Convolutional LSTM network: A machine learning approach for precipitation nowcasting." Advances in Neural Information Processing Systems. 2015.
"""

def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh, normalize=True, peephole=True, data_format='channels_last', reuse=None):
def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh, data_format='channels_last', reuse=None):
super(ConvLSTMCell, self).__init__(_reuse=reuse)
self._kernel = kernel
self._filters = filters
self._forget_bias = forget_bias
self._activation = activation
self._normalize = normalize
self._peephole = peephole
if data_format == 'channels_last':
self._size = tf.TensorShape(shape + [self._filters])
self._feature_axis = self._size.ndims
Expand Down Expand Up @@ -42,30 +40,13 @@ def call(self, x, state):
m = 4 * self._filters if self._filters > 1 else 4
W = tf.get_variable('kernel', self._kernel + [n, m])
y = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format)
if not self._normalize:
y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
j, i, f, o = tf.split(y, 4, axis=self._feature_axis)

if self._peephole:
i += tf.get_variable('W_ci', c.shape[1:]) * c
f += tf.get_variable('W_cf', c.shape[1:]) * c

if self._normalize:
j = tf.contrib.layers.layer_norm(j)
i = tf.contrib.layers.layer_norm(i)
f = tf.contrib.layers.layer_norm(f)

f = tf.sigmoid(f + self._forget_bias)
i = tf.sigmoid(i)
c = c * f + i * self._activation(j)

if self._peephole:
o += tf.get_variable('W_co', c.shape[1:]) * c

if self._normalize:
o = tf.contrib.layers.layer_norm(o)
c = tf.contrib.layers.layer_norm(c)

o = tf.sigmoid(o)
h = o * self._activation(c)

Expand All @@ -77,12 +58,11 @@ def call(self, x, state):
class ConvGRUCell(tf.nn.rnn_cell.RNNCell):
"""A GRU cell with convolutions instead of multiplications."""

def __init__(self, shape, filters, kernel, activation=tf.tanh, normalize=True, data_format='channels_last', reuse=None):
def __init__(self, shape, filters, kernel, activation=tf.tanh, data_format='channels_last', reuse=None):
super(ConvGRUCell, self).__init__(_reuse=reuse)
self._filters = filters
self._kernel = kernel
self._activation = activation
self._normalize = normalize
if data_format == 'channels_last':
self._size = tf.TensorShape(shape + [self._filters])
self._feature_axis = self._size.ndims
Expand Down Expand Up @@ -111,13 +91,8 @@ def call(self, x, h):
m = 2 * self._filters if self._filters > 1 else 2
W = tf.get_variable('kernel', self._kernel + [n, m])
y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
if self._normalize:
r, u = tf.split(y, 2, axis=self._feature_axis)
r = tf.contrib.layers.layer_norm(r)
u = tf.contrib.layers.layer_norm(u)
else:
y += tf.get_variable('bias', [m], initializer=tf.ones_initializer())
r, u = tf.split(y, 2, axis=self._feature_axis)
y += tf.get_variable('bias', [m], initializer=tf.ones_initializer())
r, u = tf.split(y, 2, axis=self._feature_axis)
r, u = tf.sigmoid(r), tf.sigmoid(u)

with tf.variable_scope('candidate'):
Expand All @@ -126,10 +101,7 @@ def call(self, x, h):
m = self._filters
W = tf.get_variable('kernel', self._kernel + [n, m])
y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
if self._normalize:
y = tf.contrib.layers.layer_norm(y)
else:
y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
h = u * h + (1 - u) * self._activation(y)

return h, h

0 comments on commit 1547958

Please sign in to comment.