diff --git a/phygnn/layers/custom_layers.py b/phygnn/layers/custom_layers.py
index 77cacd6..ac0e55f 100644
--- a/phygnn/layers/custom_layers.py
+++ b/phygnn/layers/custom_layers.py
@@ -177,12 +177,14 @@ def __init__(self, pool_size, strides=None, padding='valid', sigma=1,
     def _make_2D_gaussian_kernel(edge_len, sigma=1.):
         """Creates 2D gaussian kernel with side length `edge_len` and a
         sigma of `sigma`
+
         Parameters
         ----------
         edge_len : int
             Edge size of the kernel
         sigma : float
             Sigma parameter for gaussian distribution
+
         Returns
         -------
         kernel : np.ndarray
@@ -213,10 +215,12 @@ def get_config(self):
 
     def call(self, x):
         """Operates on x with the specified function
+
         Parameters
         ----------
         x : tf.Tensor
             Input tensor
+
         Returns
         -------
         x : tf.Tensor
@@ -240,11 +244,12 @@ def __init__(self, axis, mean=1, stddev=0.1):
         """
         Parameters
         ----------
-        axis : int
-            Axis to apply random noise across. All other axis will have the
-            same noise. For example, for a 5D spatiotemporal tensor with axis=3
-            (the time axis), this layer will apply a single random number to
-            every unique index of axis=3.
+        axis : int | list | tuple
+            Axes to apply random noise across. All other axes will have the
+            same noise. For example, for a 5D spatiotemporal tensor with
+            axis=(1, 2, 3) (both spatial axes and the temporal axis), this
+            layer will apply a single random number to every unique index of
+            axis=(1, 2, 3).
         mean : float
             The mean of the normal distribution.
         stddev : float
@@ -252,11 +257,18 @@ def __init__(self, axis, mean=1, stddev=0.1):
         """
 
         super().__init__()
-        self._axis = axis
-        self._rand_shape = None
+        self.rank = None
+        self._axis = axis if isinstance(axis, (tuple, list)) else [axis]
         self._mean = tf.constant(mean, dtype=tf.dtypes.float32)
         self._stddev = tf.constant(stddev, dtype=tf.dtypes.float32)
 
+    def _get_rand_shape(self, x):
+        """Get shape of random noise along the specified axes."""
+        shape = np.ones(len(x.shape), dtype=np.int32)
+        for ax in self._axis:
+            shape[ax] = x.shape[ax]
+        return tf.constant(shape, dtype=tf.dtypes.int32)
+
     def build(self, input_shape):
         """Custom implementation of the tf layer build method.
 
@@ -267,9 +279,7 @@ def build(self, input_shape):
         input_shape : tuple
             Shape tuple of the input
         """
-        shape = np.ones(len(input_shape), dtype=np.int32)
-        shape[self._axis] = input_shape[self._axis]
-        self._rand_shape = tf.constant(shape, dtype=tf.dtypes.int32)
+        self.rank = len(input_shape)
 
     def call(self, x):
         """Calls the tile operation
@@ -285,11 +295,11 @@ def call(self, x):
             Output tensor with noise applied to the requested axis.
         """
 
-        rand_tensor = tf.random.normal(self._rand_shape,
+        rand_tensor = tf.random.normal(self._get_rand_shape(x),
                                        mean=self._mean,
                                        stddev=self._stddev,
                                        dtype=tf.dtypes.float32)
-        return x * rand_tensor
+        return x + rand_tensor
 
 
 class FlattenAxis(tf.keras.layers.Layer):
@@ -351,7 +361,7 @@ def __init__(self, spatial_mult=1):
         """
         Parameters
         ----------
-        spatial_multiplier : int
+        spatial_mult : int
             Number of times to multiply the spatial dimensions. Note that the
             spatial expansion is an un-packing of the feature dimension. For
             example, if the input layer has shape (123, 5, 5, 16) with
@@ -435,14 +445,14 @@ def __init__(self, spatial_mult=1, temporal_mult=1,
         """
         Parameters
         ----------
-        spatial_multiplier : int
+        spatial_mult : int
             Number of times to multiply the spatial dimensions. Note that the
             spatial expansion is an un-packing of the feature dimension. For
             example, if the input layer has shape (123, 5, 5, 24, 16) with
             multiplier=2 the output shape will be (123, 10, 10, 24, 4). The
             input feature dimension must be divisible by the spatial
             multiplier squared.
-        temporal_multiplier : int
+        temporal_mult : int
             Number of times to multiply the temporal dimension. For example,
             if the input layer has shape (123, 5, 5, 24, 2) with multiplier=2
             the output shape will be (123, 5, 5, 48, 2).
@@ -603,18 +613,17 @@ def call(self, x):
         if self._cache is None:
             self._cache = x
             return x
+        try:
+            out = tf.add(x, self._cache)
+        except Exception as e:
+            msg = ('Could not add SkipConnection "{}" data cache of '
+                   'shape {} to input of shape {}.'
+                   .format(self._name, self._cache.shape, x.shape))
+            logger.error(msg)
+            raise RuntimeError(msg) from e
         else:
-            try:
-                out = tf.add(x, self._cache)
-            except Exception as e:
-                msg = ('Could not add SkipConnection "{}" data cache of '
-                       'shape {} to input of shape {}.'
-                       .format(self._name, self._cache.shape, x.shape))
-                logger.error(msg)
-                raise RuntimeError(msg) from e
-            else:
-                self._cache = None
-                return out
+            self._cache = None
+            return out
 
 
 class SqueezeAndExcitation(tf.keras.layers.Layer):
@@ -834,7 +843,8 @@ def __init__(self, name=None):
         """
         super().__init__(name=name)
 
-    def call(self, x, hi_res_adder):
+    @staticmethod
+    def call(x, hi_res_adder):
         """Adds hi-resolution data to the input tensor x in the middle of
         a sup3r resolution network.
 
@@ -869,7 +879,8 @@ def __init__(self, name=None):
         """
         super().__init__(name=name)
 
-    def call(self, x, hi_res_feature):
+    @staticmethod
+    def call(x, hi_res_feature):
         """Concatenates a hi-resolution feature to the input tensor x in the
         middle of a sup3r resolution network.
 
@@ -940,7 +951,8 @@ class SigLin(tf.keras.layers.Layer):
     y = x + 0.5 where x>=0.5
     """
 
-    def call(self, x):
+    @staticmethod
+    def call(x):
         """Operates on x with SigLin
 
         Parameters
@@ -1002,8 +1014,7 @@ def build(self, input_shape):
     def _logt(self, x):
         if not self.inverse:
             return tf.math.log(x + self.adder) * self.scalar
-        else:
-            return tf.math.exp(x / self.scalar) - self.adder
+        return tf.math.exp(x / self.scalar) - self.adder
 
     def call(self, x):
         """Operates on x with (inverse) log transform
@@ -1021,16 +1032,15 @@ def call(self, x):
 
         if self.idf is None:
             return self._logt(x)
-        else:
-            out = []
-            for idf in range(x.shape[-1]):
-                if idf in self.idf:
-                    out.append(self._logt(x[..., idf:idf + 1]))
-                else:
-                    out.append(x[..., idf:idf + 1])
+        out = []
+        for idf in range(x.shape[-1]):
+            if idf in self.idf:
+                out.append(self._logt(x[..., idf:idf + 1]))
+            else:
+                out.append(x[..., idf:idf + 1])
 
-            out = tf.concat(out, -1, name='concat')
-            return out
+        out = tf.concat(out, -1, name='concat')
+        return out
 
 
 class UnitConversion(tf.keras.layers.Layer):
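
Usage sketch for the reworked noise layer. The class name (GaussianNoiseAxis)
and module path are assumed here since the hunks start mid-class; the axis
semantics and the switch from multiplicative to additive noise follow the
diff above.

import tensorflow as tf
from phygnn.layers.custom_layers import GaussianNoiseAxis  # name assumed

# 5D spatiotemporal tensor: (observations, space, space, time, features)
x = tf.ones((4, 5, 5, 24, 2))

# One random number per unique (space, space, time) index, broadcast over
# the observation and feature axes, and added to (not multiplied into) x.
layer = GaussianNoiseAxis(axis=(1, 2, 3), mean=1, stddev=0.1)
y = layer(x)
assert y.shape == x.shape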
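
The SkipConnection restructure is behavior-preserving; here is a sketch of
the two-call protocol it keeps (the constructor signature is assumed,
matching the self._name used in the error message):

import tensorflow as tf
from phygnn.layers.custom_layers import SkipConnection

skip = SkipConnection(name='skip_1')  # constructor args assumed

x = tf.ones((4, 8))
x_start = skip(x)       # empty cache: stores x and returns it unchanged
hidden = x_start * 2.0  # stand-in for intermediate layers
x_end = skip(hidden)    # cached: returns hidden + x and clears the cache
assert tf.reduce_all(x_end == 3.0)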
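
Likewise for the flattened log-transform flow: _logt applies
log(x + adder) * scalar forward and exp(x / scalar) - adder in inverse, and
idf picks which feature channels are transformed. A round-trip sketch, with
the class name (LogTransform) and constructor signature assumed from the
attributes in the diff:

import tensorflow as tf
from phygnn.layers.custom_layers import LogTransform  # name/signature assumed

x = tf.constant([[1.0, 10.0], [2.0, 20.0]])

# Transform only feature index 0; feature index 1 passes through untouched.
fwd = LogTransform(adder=1, scalar=1, idf=[0])
inv = LogTransform(adder=1, scalar=1, idf=[0], inverse=True)
assert tf.reduce_all(tf.abs(inv(fwd(x)) - x) < 1e-5)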