From 6a4f6f4314ff723fcd4540a849ba73c4af3b81a9 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 17 Nov 2023 17:44:20 -0700 Subject: [PATCH 1/6] fno layer --- phygnn/layers/custom_layers.py | 95 ++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/phygnn/layers/custom_layers.py b/phygnn/layers/custom_layers.py index d9b03ff..8e2789a 100644 --- a/phygnn/layers/custom_layers.py +++ b/phygnn/layers/custom_layers.py @@ -597,6 +597,101 @@ def call(self, x): return x +class FNO(tf.keras.layers.Layer): + """Custom layer for fourier neural operator block + + Note that this is only set up to take a channels-last input + + References + ---------- + 1. FourCastNet: A Global Data-driven High-resolution Weather Model using + Adaptive Fourier Neural Operators. http://arxiv.org/abs/2202.11214 + """ + + def __init__(self, ratio=16, sparsity_threshold=0.5): + """ + Parameters + ---------- + ratio : int + Number of channels/filters divided by the number of + dense connections in the FNO block. + sparsity_threshold : float + Parameter to control sparsity and shrinkage in the softshrink + activation function. + """ + + super().__init__() + self._ratio = ratio + self.fft_layer = None + self.ifft_layer = None + self.mlp_layers = None + self._n_channels = None + self._dense_units = None + self.sparsity_threshold = sparsity_threshold + + def softshrink(self, x, lambd=0.5): + """Softshrink activation function + + https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html + """ + x = tf.convert_to_tensor(x) + values_below_lower = tf.where(x < -lambd, x + lambd, 0) + values_above_upper = tf.where(lambd < x, x - lambd, 0) + return values_below_lower + values_above_upper + + def build(self, input_shape): + """Build the FNO layer based on an input shape + + Parameters + ---------- + input_shape : tuple + Shape tuple of the input tensor + """ + self._n_channels = input_shape[-1] + self._dense_units = int(np.ceil(self._n_channels / self._ratio)) + + if len(input_shape) == 4: + self.fft_layer = tf.signal.fft2d + self.ifft_layer = tf.signal.ifft2d + elif len(input_shape) == 5: + self.fft_layer = tf.signal.fft3d + self.ifft_layer = tf.signal.ifft3d + else: + msg = ('FourierNeuralOperator layer can only accept 4D or 5D data ' + 'for image or video input but received input shape: {}' + .format(input_shape)) + logger.error(msg) + raise RuntimeError(msg) + + self.mlp_layers = [ + tf.keras.layers.Dense(self._dense_units, activation='relu'), + tf.keras.layers.Dense(self._n_channels)] + + def call(self, x): + """Call the custom FourierNeuralOperator layer + + Parameters + ---------- + x : tf.Tensor + Input tensor. + + Returns + ------- + x : tf.Tensor + Output tensor, this is the FNO weights added to the original input + tensor. + """ + + t_in = x + x = self.fft_layer(x) + for layer in self.mlp_layers: + x = layer(x) + x = self.softshrink(x, lambd=self.sparsity_threshold) + x = self.ifft_layer(x) + + return x + t_in + + class Sup3rAdder(tf.keras.layers.Layer): """Layer to add high-resolution data to a sup3r model in the middle of a super resolution forward pass.""" From e6d3849ecb3b57987a856843d2e3e6e9bca9a38e Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 17 Nov 2023 18:44:54 -0700 Subject: [PATCH 2/6] fno shape fixes --- phygnn/layers/custom_layers.py | 49 +++++++++++++++++----------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/phygnn/layers/custom_layers.py b/phygnn/layers/custom_layers.py index 8e2789a..8af6938 100644 --- a/phygnn/layers/custom_layers.py +++ b/phygnn/layers/custom_layers.py @@ -608,33 +608,34 @@ class FNO(tf.keras.layers.Layer): Adaptive Fourier Neural Operators. http://arxiv.org/abs/2202.11214 """ - def __init__(self, ratio=16, sparsity_threshold=0.5): + def __init__(self, filters, sparsity_threshold=0.5, activation='relu'): """ Parameters ---------- - ratio : int - Number of channels/filters divided by the number of - dense connections in the FNO block. + filters : int + Number of dense connections in the FNO block. sparsity_threshold : float Parameter to control sparsity and shrinkage in the softshrink activation function. + activation : str + Activation function used in the dense layer of the FNO block. """ super().__init__() - self._ratio = ratio - self.fft_layer = None - self.ifft_layer = None - self.mlp_layers = None + self._filters = filters + self._fft_layer = None + self._ifft_layer = None + self._mlp_layers = None + self._activation = activation self._n_channels = None self._dense_units = None - self.sparsity_threshold = sparsity_threshold + self._sparsity_threshold = sparsity_threshold - def softshrink(self, x, lambd=0.5): + def _softshrink(self, x, lambd=0.5): """Softshrink activation function https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html """ - x = tf.convert_to_tensor(x) values_below_lower = tf.where(x < -lambd, x + lambd, 0) values_above_upper = tf.where(lambd < x, x - lambd, 0) return values_below_lower + values_above_upper @@ -647,24 +648,23 @@ def build(self, input_shape): input_shape : tuple Shape tuple of the input tensor """ - self._n_channels = input_shape[-1] - self._dense_units = int(np.ceil(self._n_channels / self._ratio)) + self._n_channels = input_shape[-1] // 2 + 1 if len(input_shape) == 4: - self.fft_layer = tf.signal.fft2d - self.ifft_layer = tf.signal.ifft2d + self._fft_layer = tf.signal.rfft2d + self._ifft_layer = tf.signal.irfft2d elif len(input_shape) == 5: - self.fft_layer = tf.signal.fft3d - self.ifft_layer = tf.signal.ifft3d + self._fft_layer = tf.signal.rfft3d + self._ifft_layer = tf.signal.irfft3d else: - msg = ('FourierNeuralOperator layer can only accept 4D or 5D data ' + msg = ('FNO layer can only accept 4D or 5D data ' 'for image or video input but received input shape: {}' .format(input_shape)) logger.error(msg) raise RuntimeError(msg) - self.mlp_layers = [ - tf.keras.layers.Dense(self._dense_units, activation='relu'), + self._mlp_layers = [ + tf.keras.layers.Dense(self._filters, activation=self._activation), tf.keras.layers.Dense(self._n_channels)] def call(self, x): @@ -683,11 +683,12 @@ def call(self, x): """ t_in = x - x = self.fft_layer(x) - for layer in self.mlp_layers: + x = self._fft_layer(x) + for layer in self._mlp_layers: x = layer(x) - x = self.softshrink(x, lambd=self.sparsity_threshold) - x = self.ifft_layer(x) + x = self._softshrink(x, lambd=self._sparsity_threshold) + x = tf.cast(x, dtype=tf.complex64) + x = self._ifft_layer(x) return x + t_in From 1c791fa3549abe4bbedebbd42a2056af123a0dc1 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 17 Nov 2023 20:19:46 -0700 Subject: [PATCH 3/6] irfft has no gradient so -> ifft --- phygnn/layers/custom_layers.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/phygnn/layers/custom_layers.py b/phygnn/layers/custom_layers.py index 8af6938..a985bbc 100644 --- a/phygnn/layers/custom_layers.py +++ b/phygnn/layers/custom_layers.py @@ -648,14 +648,14 @@ def build(self, input_shape): input_shape : tuple Shape tuple of the input tensor """ - self._n_channels = input_shape[-1] // 2 + 1 + self._n_channels = input_shape[-1] if len(input_shape) == 4: - self._fft_layer = tf.signal.rfft2d - self._ifft_layer = tf.signal.irfft2d + self._fft_layer = tf.signal.fft2d + self._ifft_layer = tf.signal.ifft2d elif len(input_shape) == 5: - self._fft_layer = tf.signal.rfft3d - self._ifft_layer = tf.signal.irfft3d + self._fft_layer = tf.signal.fft3d + self._ifft_layer = tf.signal.ifft3d else: msg = ('FNO layer can only accept 4D or 5D data ' 'for image or video input but received input shape: {}' @@ -681,14 +681,13 @@ def call(self, x): Output tensor, this is the FNO weights added to the original input tensor. """ - t_in = x - x = self._fft_layer(x) + x = self._fft_layer(tf.cast(x, tf.complex64)) for layer in self._mlp_layers: x = layer(x) x = self._softshrink(x, lambd=self._sparsity_threshold) - x = tf.cast(x, dtype=tf.complex64) - x = self._ifft_layer(x) + x = self._ifft_layer(tf.cast(x, tf.complex64)) + x = tf.cast(x, dtype=t_in.dtype) return x + t_in From 1df0d28bcc5adbff999e8eeb8514193d078e08a5 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 18 Nov 2023 10:14:53 -0700 Subject: [PATCH 4/6] added transposes to fft the correct dims --- phygnn/layers/custom_layers.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/phygnn/layers/custom_layers.py b/phygnn/layers/custom_layers.py index a985bbc..6f7f32a 100644 --- a/phygnn/layers/custom_layers.py +++ b/phygnn/layers/custom_layers.py @@ -629,17 +629,31 @@ def __init__(self, filters, sparsity_threshold=0.5, activation='relu'): self._activation = activation self._n_channels = None self._dense_units = None - self._sparsity_threshold = sparsity_threshold + self._lambd = sparsity_threshold - def _softshrink(self, x, lambd=0.5): + def _softshrink(self, x): """Softshrink activation function https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html """ - values_below_lower = tf.where(x < -lambd, x + lambd, 0) - values_above_upper = tf.where(lambd < x, x - lambd, 0) + values_below_lower = tf.where(x < -self._lambd, x + self._lambd, 0) + values_above_upper = tf.where(self._lambd < x, x - self._lambd, 0) return values_below_lower + values_above_upper + def _fft(self, x): + """Apply needed transpositions and fft operation.""" + x = tf.transpose(x, perm=self._perms_in) + x = self._fft_layer(tf.cast(x, tf.complex64)) + x = tf.transpose(x, perm=self._perms_out) + return x + + def _ifft(self, x): + """Apply needed transpositions and ifft operation.""" + x = tf.transpose(x, perm=self._perms_in) + x = self._ifft_layer(tf.cast(x, tf.complex64)) + x = tf.transpose(x, perm=self._perms_out) + return x + def build(self, input_shape): """Build the FNO layer based on an input shape @@ -649,6 +663,9 @@ def build(self, input_shape): Shape tuple of the input tensor """ self._n_channels = input_shape[-1] + dims = list(range(len(input_shape))) + self._perms_in = [dims[-1], *dims[:-1]] + self._perms_out = [*dims[1:], dims[0]] if len(input_shape) == 4: self._fft_layer = tf.signal.fft2d @@ -682,11 +699,11 @@ def call(self, x): tensor. """ t_in = x - x = self._fft_layer(tf.cast(x, tf.complex64)) + x = self._fft(x) for layer in self._mlp_layers: x = layer(x) - x = self._softshrink(x, lambd=self._sparsity_threshold) - x = self._ifft_layer(tf.cast(x, tf.complex64)) + x = self._softshrink(x) + x = self._ifft(x) x = tf.cast(x, dtype=t_in.dtype) return x + t_in From ef61e16d0f805800bbd98f6388591be33ce80666 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 19 Nov 2023 09:14:21 -0700 Subject: [PATCH 5/6] minor fno update --- phygnn/layers/custom_layers.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/phygnn/layers/custom_layers.py b/phygnn/layers/custom_layers.py index 6f7f32a..e7358ca 100644 --- a/phygnn/layers/custom_layers.py +++ b/phygnn/layers/custom_layers.py @@ -606,6 +606,8 @@ class FNO(tf.keras.layers.Layer): ---------- 1. FourCastNet: A Global Data-driven High-resolution Weather Model using Adaptive Fourier Neural Operators. http://arxiv.org/abs/2202.11214 + 2. Adaptive Fourier Neural Operators: Efficient Token Mixers for + Transformers. http://arxiv.org/abs/2111.13587 """ def __init__(self, filters, sparsity_threshold=0.5, activation='relu'): @@ -616,9 +618,9 @@ def __init__(self, filters, sparsity_threshold=0.5, activation='relu'): Number of dense connections in the FNO block. sparsity_threshold : float Parameter to control sparsity and shrinkage in the softshrink - activation function. + activation function following the MLP layers. activation : str - Activation function used in the dense layer of the FNO block. + Activation function used in MLP layers. """ super().__init__() @@ -628,7 +630,8 @@ def __init__(self, filters, sparsity_threshold=0.5, activation='relu'): self._mlp_layers = None self._activation = activation self._n_channels = None - self._dense_units = None + self._perms_in = None + self._perms_out = None self._lambd = sparsity_threshold def _softshrink(self, x): @@ -684,6 +687,12 @@ def build(self, input_shape): tf.keras.layers.Dense(self._filters, activation=self._activation), tf.keras.layers.Dense(self._n_channels)] + def _mlp_block(self, x): + """Run mlp layers on input""" + for layer in self._mlp_layers: + x = layer(x) + return x + def call(self, x): """Call the custom FourierNeuralOperator layer @@ -700,8 +709,7 @@ def call(self, x): """ t_in = x x = self._fft(x) - for layer in self._mlp_layers: - x = layer(x) + x = self._mlp_block(x) x = self._softshrink(x) x = self._ifft(x) x = tf.cast(x, dtype=t_in.dtype) From c422acfd89664066d287b74f64634fdd54d22d22 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 20 Nov 2023 08:46:51 -0700 Subject: [PATCH 6/6] cant have codecov going down --- tests/test_layers.py | 52 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index 64a7a31..54e4ded 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -5,13 +5,15 @@ import pytest import tensorflow as tf -from phygnn.layers.custom_layers import (SkipConnection, - SpatioTemporalExpansion, - FlattenAxis, - ExpandDims, - TileLayer, - GaussianNoiseAxis) -from phygnn.layers.handlers import Layers, HiddenLayers +from phygnn.layers.custom_layers import ( + ExpandDims, + FlattenAxis, + GaussianNoiseAxis, + SkipConnection, + SpatioTemporalExpansion, + TileLayer, +) +from phygnn.layers.handlers import HiddenLayers, Layers @pytest.mark.parametrize( @@ -208,7 +210,7 @@ def test_temporal_depth_to_time(t_mult, s_mult, t_roll): n_filters = 2 * s_mult**2 * t_mult shape = (1, 4, 4, 3, n_filters) n = np.product(shape) - x = np.arange(n).reshape((shape)) + x = np.arange(n).reshape(shape) y = layer(x) assert y.shape[0] == x.shape[0] assert y.shape[1] == s_mult * x.shape[1] @@ -387,3 +389,37 @@ def test_squeeze_excite_3d(): x = layer(x) with pytest.raises(tf.errors.InvalidArgumentError): tf.assert_equal(x_in, x) + + +def test_fno_2d(): + """Test the FNO layer with 2D data (4D tensor input)""" + hidden_layers = [ + {'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01, + 'activation': 'relu'}] + layers = HiddenLayers(hidden_layers) + assert len(layers.layers) == 1 + + x = np.random.normal(0, 1, size=(1, 4, 4, 3)) + + for layer in layers: + x_in = x + x = layer(x) + with pytest.raises(tf.errors.InvalidArgumentError): + tf.assert_equal(x_in, x) + + +def test_fno_3d(): + """Test the FNO layer with 3D data (5D tensor input)""" + hidden_layers = [ + {'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01, + 'activation': 'relu'}] + layers = HiddenLayers(hidden_layers) + assert len(layers.layers) == 1 + + x = np.random.normal(0, 1, size=(1, 4, 4, 6, 3)) + + for layer in layers: + x_in = x + x = layer(x) + with pytest.raises(tf.errors.InvalidArgumentError): + tf.assert_equal(x_in, x)