Skip to content

Commit

Permalink
Merge pull request #45 from NREL/bnb/fno
Browse files Browse the repository at this point in the history
Bnb/fno
  • Loading branch information
bnb32 authored Nov 21, 2023
2 parents 42c914a + c422acf commit a7ca20d
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 8 deletions.
120 changes: 120 additions & 0 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,126 @@ def call(self, x):
return x


class FNO(tf.keras.layers.Layer):
"""Custom layer for fourier neural operator block
Note that this is only set up to take a channels-last input
References
----------
1. FourCastNet: A Global Data-driven High-resolution Weather Model using
Adaptive Fourier Neural Operators. http://arxiv.org/abs/2202.11214
2. Adaptive Fourier Neural Operators: Efficient Token Mixers for
Transformers. http://arxiv.org/abs/2111.13587
"""

def __init__(self, filters, sparsity_threshold=0.5, activation='relu'):
"""
Parameters
----------
filters : int
Number of dense connections in the FNO block.
sparsity_threshold : float
Parameter to control sparsity and shrinkage in the softshrink
activation function following the MLP layers.
activation : str
Activation function used in MLP layers.
"""

super().__init__()
self._filters = filters
self._fft_layer = None
self._ifft_layer = None
self._mlp_layers = None
self._activation = activation
self._n_channels = None
self._perms_in = None
self._perms_out = None
self._lambd = sparsity_threshold

def _softshrink(self, x):
"""Softshrink activation function
https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html
"""
values_below_lower = tf.where(x < -self._lambd, x + self._lambd, 0)
values_above_upper = tf.where(self._lambd < x, x - self._lambd, 0)
return values_below_lower + values_above_upper

def _fft(self, x):
"""Apply needed transpositions and fft operation."""
x = tf.transpose(x, perm=self._perms_in)
x = self._fft_layer(tf.cast(x, tf.complex64))
x = tf.transpose(x, perm=self._perms_out)
return x

def _ifft(self, x):
"""Apply needed transpositions and ifft operation."""
x = tf.transpose(x, perm=self._perms_in)
x = self._ifft_layer(tf.cast(x, tf.complex64))
x = tf.transpose(x, perm=self._perms_out)
return x

def build(self, input_shape):
"""Build the FNO layer based on an input shape
Parameters
----------
input_shape : tuple
Shape tuple of the input tensor
"""
self._n_channels = input_shape[-1]
dims = list(range(len(input_shape)))
self._perms_in = [dims[-1], *dims[:-1]]
self._perms_out = [*dims[1:], dims[0]]

if len(input_shape) == 4:
self._fft_layer = tf.signal.fft2d
self._ifft_layer = tf.signal.ifft2d
elif len(input_shape) == 5:
self._fft_layer = tf.signal.fft3d
self._ifft_layer = tf.signal.ifft3d
else:
msg = ('FNO layer can only accept 4D or 5D data '
'for image or video input but received input shape: {}'
.format(input_shape))
logger.error(msg)
raise RuntimeError(msg)

self._mlp_layers = [
tf.keras.layers.Dense(self._filters, activation=self._activation),
tf.keras.layers.Dense(self._n_channels)]

def _mlp_block(self, x):
"""Run mlp layers on input"""
for layer in self._mlp_layers:
x = layer(x)
return x

def call(self, x):
"""Call the custom FourierNeuralOperator layer
Parameters
----------
x : tf.Tensor
Input tensor.
Returns
-------
x : tf.Tensor
Output tensor, this is the FNO weights added to the original input
tensor.
"""
t_in = x
x = self._fft(x)
x = self._mlp_block(x)
x = self._softshrink(x)
x = self._ifft(x)
x = tf.cast(x, dtype=t_in.dtype)

return x + t_in


class Sup3rAdder(tf.keras.layers.Layer):
"""Layer to add high-resolution data to a sup3r model in the middle of a
super resolution forward pass."""
Expand Down
52 changes: 44 additions & 8 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import pytest
import tensorflow as tf

from phygnn.layers.custom_layers import (SkipConnection,
SpatioTemporalExpansion,
FlattenAxis,
ExpandDims,
TileLayer,
GaussianNoiseAxis)
from phygnn.layers.handlers import Layers, HiddenLayers
from phygnn.layers.custom_layers import (
ExpandDims,
FlattenAxis,
GaussianNoiseAxis,
SkipConnection,
SpatioTemporalExpansion,
TileLayer,
)
from phygnn.layers.handlers import HiddenLayers, Layers


@pytest.mark.parametrize(
Expand Down Expand Up @@ -208,7 +210,7 @@ def test_temporal_depth_to_time(t_mult, s_mult, t_roll):
n_filters = 2 * s_mult**2 * t_mult
shape = (1, 4, 4, 3, n_filters)
n = np.product(shape)
x = np.arange(n).reshape((shape))
x = np.arange(n).reshape(shape)
y = layer(x)
assert y.shape[0] == x.shape[0]
assert y.shape[1] == s_mult * x.shape[1]
Expand Down Expand Up @@ -387,3 +389,37 @@ def test_squeeze_excite_3d():
x = layer(x)
with pytest.raises(tf.errors.InvalidArgumentError):
tf.assert_equal(x_in, x)


def test_fno_2d():
"""Test the FNO layer with 2D data (4D tensor input)"""
hidden_layers = [
{'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01,
'activation': 'relu'}]
layers = HiddenLayers(hidden_layers)
assert len(layers.layers) == 1

x = np.random.normal(0, 1, size=(1, 4, 4, 3))

for layer in layers:
x_in = x
x = layer(x)
with pytest.raises(tf.errors.InvalidArgumentError):
tf.assert_equal(x_in, x)


def test_fno_3d():
"""Test the FNO layer with 3D data (5D tensor input)"""
hidden_layers = [
{'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01,
'activation': 'relu'}]
layers = HiddenLayers(hidden_layers)
assert len(layers.layers) == 1

x = np.random.normal(0, 1, size=(1, 4, 4, 6, 3))

for layer in layers:
x_in = x
x = layer(x)
with pytest.raises(tf.errors.InvalidArgumentError):
tf.assert_equal(x_in, x)

0 comments on commit a7ca20d

Please sign in to comment.