From 452af90959b33ed30813ad708a50d4d2da70eef7 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 17 Nov 2023 15:41:20 -0700 Subject: [PATCH] fno layer --- phygnn/layers/custom_layers.py | 101 ++++++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 3 deletions(-) diff --git a/phygnn/layers/custom_layers.py b/phygnn/layers/custom_layers.py index d9b03ff..883b553 100644 --- a/phygnn/layers/custom_layers.py +++ b/phygnn/layers/custom_layers.py @@ -478,7 +478,8 @@ def __init__(self, name): Unique string identifier of the skip connection. The skip endpoint should have the same name. """ - super().__init__(name=name) + super().__init__() + self._name = name self._cache = None def call(self, x): @@ -597,6 +598,100 @@ def call(self, x): return x +class FourierNeuralOperator(tf.keras.layers.Layer): + """Custom layer for fourier neural operator block + + Note that this is only set up to take a channels-last input + + References + ---------- + 1. FourCastNet: A Global Data-driven High-resolution Weather Model using + Adaptive Fourier Neural Operators. http://arxiv.org/abs/2202.11214 + """ + + def __init__(self, ratio=16, sparsity_threshold=0.5): + """ + Parameters + ---------- + ratio : int + Number of channels/filters divided by the number of + dense connections in the FNO block. + sparsity_threshold : float + Parameter to control sparsity and shrinkage in the softshrink + activation function. + """ + + super().__init__() + self._ratio = ratio + self.fft_layer = None + self.ifft_layer = None + self.mlp_layers = None + self.sparsity_threshold = sparsity_threshold + + def softshrink(self, x, lambd=0.5): + """Softshrink activation function + + https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html + """ + x = tf.convert_to_tensor(x) + values_below_lower = tf.where(x < -lambd, x + lambd, 0) + values_above_upper = tf.where(lambd < x, x - lambd, 0) + return values_below_lower + values_above_upper + + def build(self, input_shape): + """Build the FNO layer based on an input shape + + Parameters + ---------- + input_shape : tuple + Shape tuple of the input tensor + """ + + self._n_channels = input_shape[-1] + self._dense_units = int(np.ceil(self._n_channels / self._ratio)) + + if len(input_shape) == 4: + self.fft_layer = tf.signal.fft2d + self.ifft_layer = tf.signal.ifft2d + elif len(input_shape) == 5: + self.fft_layer = tf.signal.fft3d + self.ifft_layer = tf.signal.ifft3d + else: + msg = ('FourierNeuralOperator layer can only accept 4D or 5D data ' + 'for image or video input but received input shape: {}' + .format(input_shape)) + logger.error(msg) + raise RuntimeError(msg) + + self.mlp_layers = [ + tf.keras.layers.Dense(self._dense_units, activation='relu'), + tf.keras.layers.Dense(self._n_channels)] + + def call(self, x): + """Call the custom FourierNeuralOperator layer + + Parameters + ---------- + x : tf.Tensor + Input tensor. + + Returns + ------- + x : tf.Tensor + Output tensor, this is the FNO weights added to the original input + tensor. + """ + + t_in = x + x = self.fft_layer(x) + for layer in self.mlp_layers: + x = layer(x) + x = self.softshrink(x, lambd=self.sparsity_threshold) + x = self.ifft_layer(x) + + return x + t_in + + class Sup3rAdder(tf.keras.layers.Layer): """Layer to add high-resolution data to a sup3r model in the middle of a super resolution forward pass.""" @@ -609,7 +704,7 @@ def __init__(self, name=None): Unique str identifier of the adder layer. Usually the name of the hi-resolution feature used in the addition. """ - super().__init__(name=name) + self.name = name def call(self, x, hi_res_adder): """Adds hi-resolution data to the input tensor x in the middle of a @@ -644,7 +739,7 @@ def __init__(self, name=None): Unique str identifier for the concat layer. Usually the name of the hi-resolution feature used in the concatenation. """ - super().__init__(name=name) + self.name = name def call(self, x, hi_res_feature): """Concatenates a hi-resolution feature to the input tensor x in the