Skip to content

Commit

Permalink
Modified gpool layer to have trainable distribution parameter sigma
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Sep 13, 2024
1 parent 2499daa commit 11a0d86
Showing 1 changed file with 47 additions and 34 deletions.
81 changes: 47 additions & 34 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class GaussianAveragePooling2D(tf.keras.layers.Layer):
convolution window that limits the area of effect"""

def __init__(self, pool_size, strides=None, padding='valid', sigma=1,
**kwargs):
trainable=True, **kwargs):
"""
Parameters
----------
Expand All @@ -154,47 +154,56 @@ def __init__(self, pool_size, strides=None, padding='valid', sigma=1,
same height/width dimension as the input.
sigma : float
Sigma parameter for gaussian distribution
trainable : bool
Flag for whether sigma is trainable weight or not.
kwargs : dict
Extra kwargs for tf.keras.layers.Layer
"""

super().__init__(**kwargs)
assert isinstance(pool_size, int), 'pool_size must be int!'
self._pool_size = pool_size
self._strides = strides
self._padding = padding.upper()
self._sigma = sigma

target_shape = (self._pool_size, self._pool_size, 1, 1)
self._kernel = self._make_2D_gaussian_kernel(self._pool_size,
self._sigma)
self._kernel = np.expand_dims(self._kernel, -1)
self._kernel = np.expand_dims(self._kernel, -1)
assert self._kernel.shape == target_shape
self._kernel = tf.convert_to_tensor(self._kernel, dtype=tf.float32)
self.pool_size = pool_size
self.strides = strides
self.padding = padding.upper()
self.trainable = trainable
self.sigma = sigma

@staticmethod
def _make_2D_gaussian_kernel(edge_len, sigma=1.):
"""Creates 2D gaussian kernel with side length `edge_len` and a sigma
of `sigma`
def build(self, input_shape):
"""Custom implementation of the tf layer build method.
Initializes the trainable sigma variable
Parameters
----------
edge_len : int
Edge size of the kernel
sigma : float
Sigma parameter for gaussian distribution
input_shape : tuple
Shape tuple of the input
"""
if not any(self.weights):
init = tf.keras.initializers.Constant(value=self.sigma)
self.sigma = self.add_weight("sigma", shape=[1],
trainable=self.trainable,
dtype=tf.float32,
initializer=init)

def make_kernel(self):
"""Creates 2D gaussian kernel with side length `self.pool_size` and a
sigma of `sigma`
Returns
-------
kernel : np.ndarray
2D kernel with shape (edge_len, edge_len)
2D kernel with shape (self.pool_size, self.pool_size)
"""
ax = np.linspace(-(edge_len - 1) / 2., (edge_len - 1) / 2., edge_len)
gauss = np.exp(-0.5 * np.square(ax) / np.square(sigma))
kernel = np.outer(gauss, gauss)
kernel = kernel / np.sum(kernel)
return kernel.astype(np.float32)
ax = tf.linspace(-(self.pool_size - 1) / 2.,
(self.pool_size - 1) / 2.,
self.pool_size)
gauss = tf.math.exp(-0.5 * tf.math.square(ax)
/ tf.math.square(self.sigma))
kernel = tf.expand_dims(gauss, 0) * tf.expand_dims(gauss, -1)
kernel = kernel / tf.math.reduce_sum(kernel)
kernel = tf.expand_dims(kernel, -1)
kernel = tf.expand_dims(kernel, -1)
return kernel

def get_config(self):
"""Implementation of get_config method from tf.keras.layers.Layer for
Expand All @@ -206,10 +215,11 @@ def get_config(self):
"""
config = super().get_config().copy()
config.update({
'pool_size': self._pool_size,
'strides': self._strides,
'padding': self._padding,
'sigma': self._sigma,
'pool_size': self.pool_size,
'strides': self.strides,
'padding': self.padding,
'trainable': self.trainable,
'sigma': float(self.sigma),
})
return config

Expand All @@ -226,12 +236,15 @@ def call(self, x):
x : tf.Tensor
Output tensor operated on by the specified function
"""

kernel = self.make_kernel()

out = []
for idf in range(x.shape[-1]):
fslice = slice(idf, idf + 1)
iout = tf.nn.convolution(x[..., fslice], self._kernel,
strides=self._strides,
padding=self._padding)
iout = tf.nn.convolution(x[..., fslice], kernel,
strides=self.strides,
padding=self.padding)
out.append(iout)
out = tf.concat(out, -1, name='concat')
return out
Expand Down

0 comments on commit 11a0d86

Please sign in to comment.