minor fno update
bnb32 committed Nov 19, 2023
1 parent 1df0d28 commit ef61e16
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions phygnn/layers/custom_layers.py
@@ -606,6 +606,8 @@ class FNO(tf.keras.layers.Layer):
     ----------
     1. FourCastNet: A Global Data-driven High-resolution Weather Model using
        Adaptive Fourier Neural Operators. http://arxiv.org/abs/2202.11214
+    2. Adaptive Fourier Neural Operators: Efficient Token Mixers for
+       Transformers. http://arxiv.org/abs/2111.13587
     """
 
     def __init__(self, filters, sparsity_threshold=0.5, activation='relu'):
@@ -616,9 +618,9 @@ def __init__(self, filters, sparsity_threshold=0.5, activation='relu'):
            Number of dense connections in the FNO block.
        sparsity_threshold : float
            Parameter to control sparsity and shrinkage in the softshrink
-           activation function.
+           activation function following the MLP layers.
        activation : str
-           Activation function used in the dense layer of the FNO block.
+           Activation function used in MLP layers.
        """

        super().__init__()
@@ -628,7 +630,8 @@ def __init__(self, filters, sparsity_threshold=0.5, activation='relu'):
        self._mlp_layers = None
        self._activation = activation
        self._n_channels = None
-       self._dense_units = None
+       self._perms_in = None
+       self._perms_out = None
        self._lambd = sparsity_threshold

    def _softshrink(self, x):
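For context on the sparsity_threshold parameter documented above: below is a minimal sketch of the textbook softshrink operation with threshold lambd. The standalone function name and real-valued form are illustrative assumptions; the layer's own _softshrink (shown as context here) operates on the FFT output and may use a complex-valued variant.

import tensorflow as tf

def softshrink(x, lambd=0.5):
    """Textbook softshrink: zero out values with magnitude below lambd and
    shrink the remaining values toward zero by lambd."""
    return tf.sign(x) * tf.maximum(tf.abs(x) - lambd, 0.0)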
@@ -684,6 +687,12 @@ def build(self, input_shape):
            tf.keras.layers.Dense(self._filters, activation=self._activation),
            tf.keras.layers.Dense(self._n_channels)]

+   def _mlp_block(self, x):
+       """Run mlp layers on input"""
+       for layer in self._mlp_layers:
+           x = layer(x)
+       return x
+
    def call(self, x):
        """Call the custom FourierNeuralOperator layer
@@ -700,8 +709,7 @@ def call(self, x):
        """
        t_in = x
        x = self._fft(x)
-       for layer in self._mlp_layers:
-           x = layer(x)
+       x = self._mlp_block(x)
        x = self._softshrink(x)
        x = self._ifft(x)
        x = tf.cast(x, dtype=t_in.dtype)
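A minimal usage sketch of the layer after this refactor. The import path, class name, and constructor signature come from this file; the example filter count and the 4D input shape are assumptions about how the layer is typically called.

import tensorflow as tf
from phygnn.layers.custom_layers import FNO

# Hypothetical (batch, height, width, channels) input; the rank the layer
# expects is an assumption here.
x = tf.random.normal((4, 16, 16, 8))

fno = FNO(filters=64, sparsity_threshold=0.5, activation='relu')
# call() now runs: fft -> _mlp_block -> softshrink -> ifft -> cast to input dtype
y = fno(x)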
