Skip to content

Commit

Permalink
refactored a gaussian kernel initializer instead of a gaussian averag…
Browse files Browse the repository at this point in the history
…e pooling layer for easier integration with tensorflow models
  • Loading branch information
grantbuster committed May 6, 2024
1 parent eadb42b commit 61eb6a4
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 119 deletions.
138 changes: 45 additions & 93 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,99 +130,6 @@ def call(self, x):
return tf.tile(x, self._mult)


class GaussianAveragePooling2D(tf.keras.layers.Layer):
"""Custom layer to implement tensorflow average pooling layer but with a
gaussian kernel. This is basically a gaussian smoothing layer with a fixed
convolution window that limits the area of effect"""

def __init__(self, pool_size, strides=None, padding='valid', sigma=1):
"""
Parameters
----------
pool_size: integer
Pooling window size. This sets the number of pixels in each
dimension that will be averaged into an output pixel. Only one
integer is specified, the same window length will be used for both
dimensions. For example, if ``pool_size=2`` and ``strides=2`` then
the output dimension will be half of the input.
strides: Integer, tuple of 2 integers, or None.
Strides values. If None, it will default to `pool_size`.
padding: One of `"valid"` or `"same"` (case-insensitive).
`"valid"` means no padding. `"same"` results in padding evenly to
the left/right or up/down of the input such that output has the
same height/width dimension as the input.
sigma : float
Sigma parameter for gaussian distribution
"""

super().__init__()
assert isinstance(pool_size, int), 'pool_size must be int!'
self._pool_size = pool_size
self._strides = strides
self._padding = padding.upper()
self._sigma = sigma
self._kernel = None

@staticmethod
def _make_2D_gaussian_kernel(edge_len, sigma=1.):
"""Creates 2D gaussian kernel with side length `edge_len` and a sigma
of `sigma`
Parameters
----------
edge_len : int
Edge size of the kernel
sigma : float
Sigma parameter for gaussian distribution
Returns
-------
kernel : np.ndarray
2D kernel with shape (edge_len, edge_len)
"""
ax = np.linspace(-(edge_len - 1) / 2., (edge_len - 1) / 2., edge_len)
gauss = np.exp(-0.5 * np.square(ax) / np.square(sigma))
kernel = np.outer(gauss, gauss)
kernel = kernel / np.sum(kernel)
return kernel.astype(np.float32)

def build(self, input_shape):
"""Custom implementation of the tf layer build method.
Sets the shape of the gaussian kernel
Parameters
----------
input_shape : tuple
Shape tuple of the input
"""
target_shape = (self._pool_size, self._pool_size, 1, input_shape[-1])
self._kernel = self._make_2D_gaussian_kernel(self._pool_size,
self._sigma)
self._kernel = [self._kernel for _ in range(input_shape[-1])]
self._kernel = np.dstack(self._kernel)
self._kernel = np.expand_dims(self._kernel, 2)
assert self._kernel.shape == target_shape
self._kernel = tf.convert_to_tensor(self._kernel, dtype=tf.float32)

def call(self, x):
"""Operates on x with the specified function
Parameters
----------
x : tf.Tensor
Input tensor
Returns
-------
x : tf.Tensor
Output tensor operated on by the specified function
"""
out = tf.nn.convolution(x, self._kernel, strides=self._strides,
padding=self._padding)
return out


class GaussianNoiseAxis(tf.keras.layers.Layer):
"""Layer to apply random noise along a given axis."""

Expand Down Expand Up @@ -282,6 +189,51 @@ def call(self, x):
return x * rand_tensor


class GaussianKernelInit2D(tf.keras.initializers.Initializer):
"""Convolutional kernel initializer that creates a symmetric 2D array with
a gaussian distribution. This can be used with Conv2D as a gaussian average
pooling layer if trainable=False
"""

def __init__(self, stdev=1):
"""
Parameters
----------
stdev : float
Standard deviation of the gaussian distribution defining the kernel
values
"""
self.stdev = stdev

def __call__(self, shape, dtype=tf.float32):
"""
Parameters
---------
shape : tuple
Shape of the input tensor, typically (y, x, n_features, n_obs)
dtype : None | tf.DType
Tensorflow datatype e.g., tf.float32
Returns
-------
kernel : tf.Tensor
Kernel tensor of shape (y, x, n_features, n_obs) for use in a
Conv2D layer.
"""

ax = np.linspace(-(shape[0] - 1) / 2., (shape[0] - 1) / 2., shape[0])
kernel = np.exp(-0.5 * np.square(ax) / np.square(self.stdev))
kernel = np.outer(kernel, kernel)
kernel = kernel / np.sum(kernel)

kernel = np.expand_dims(kernel, (2, 3))
kernel = np.repeat(kernel, shape[2], axis=2)
kernel = np.repeat(kernel, shape[3], axis=3)

kernel = tf.convert_to_tensor(kernel, dtype=dtype)
return kernel


class FlattenAxis(tf.keras.layers.Layer):
"""Layer to flatten an axis from a 5D spatiotemporal Tensor into axis-0
observations."""
Expand Down
52 changes: 26 additions & 26 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
SpatioTemporalExpansion,
TileLayer,
FunctionalLayer,
GaussianAveragePooling2D,
GaussianKernelInit2D,
)
from phygnn.layers.handlers import HiddenLayers, Layers

Expand Down Expand Up @@ -447,28 +447,28 @@ def test_functional_layer():
assert "must be one of" in str(excinfo.value)


def test_gaussian_pool():
"""Test the gaussian average pooling layer"""

layer = GaussianAveragePooling2D(pool_size=2, strides=2,
padding='valid', sigma=1)
x = np.zeros((2, 6, 6, 2))
x[:, 1, 1, :] = 1
x[:, 2, 2, :] = 1
y = layer(x)
print(y[0, :, :, 0])
assert y.shape == (2, 3, 3, 2)
assert (y[0, :, :, 0] == y[-1, :, :, -1]).numpy().all()
assert (y[:, 0, 0, :].numpy() == 0.25).all()
assert (y[:, 1, 1, :].numpy() == 0.25).all()

layer = GaussianAveragePooling2D(pool_size=3, strides=1,
padding='valid', sigma=1)
x = np.zeros((2, 6, 6, 2))
x[:, 2, 2, :] = 2
y = layer(x)
assert y.shape == (2, 4, 4, 2)
assert (y[0, :, :, 0] == y[-1, :, :, -1]).numpy().all()
assert y[0, 1, 1, 0].numpy() == y.numpy().max()
assert (y[0, -1, 0, 0].numpy() == 0).all()
assert (y[0, :, -1, 0].numpy() == 0).all()
def test_gaussian_kernel():
"""Test the gaussian kernel initializer for gaussian average pooling"""

kernels = []
biases = []
for stdev in [1, 2]:
kinit = GaussianKernelInit2D(stdev=stdev)
layer = tf.keras.layers.Conv2D(filters=16, kernel_size=5, strides=1,
padding='valid',
kernel_initializer=kinit)
_ = layer(np.ones((24, 100, 100, 35)))
kernel = layer.weights[0].numpy()
bias = layer.weights[1].numpy()
kernels.append(kernel)
biases.append(bias)

assert (kernel[:, :, 0, 0] == kernel[:, :, -1, -1]).all()
assert kernel[:, :, 0, 0].sum() == 1
assert (bias == 0).all()
assert kernel[2, 2, 0, 0] == kernel.max()
assert kernel[0, 0, 0, 0] == kernel.min()
assert kernel[-1, -1, 0, 0] == kernel.min()

assert kernels[1].max() < kernels[0].max()
assert kernels[1].min() > kernels[0].min()

0 comments on commit 61eb6a4

Please sign in to comment.