Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bnb/fno #45

Merged
merged 6 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,126 @@ def call(self, x):
return x


class FNO(tf.keras.layers.Layer):
"""Custom layer for fourier neural operator block

Note that this is only set up to take a channels-last input

References
----------
1. FourCastNet: A Global Data-driven High-resolution Weather Model using
Adaptive Fourier Neural Operators. http://arxiv.org/abs/2202.11214
2. Adaptive Fourier Neural Operators: Efficient Token Mixers for
Transformers. http://arxiv.org/abs/2111.13587
"""

def __init__(self, filters, sparsity_threshold=0.5, activation='relu'):
"""
Parameters
----------
filters : int
Number of dense connections in the FNO block.
sparsity_threshold : float
Parameter to control sparsity and shrinkage in the softshrink
activation function following the MLP layers.
activation : str
Activation function used in MLP layers.
"""

super().__init__()
self._filters = filters
self._fft_layer = None
self._ifft_layer = None
self._mlp_layers = None
self._activation = activation
self._n_channels = None
self._perms_in = None
self._perms_out = None
self._lambd = sparsity_threshold

def _softshrink(self, x):
"""Softshrink activation function

https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html
"""
values_below_lower = tf.where(x < -self._lambd, x + self._lambd, 0)
values_above_upper = tf.where(self._lambd < x, x - self._lambd, 0)
return values_below_lower + values_above_upper

def _fft(self, x):
"""Apply needed transpositions and fft operation."""
x = tf.transpose(x, perm=self._perms_in)
x = self._fft_layer(tf.cast(x, tf.complex64))
x = tf.transpose(x, perm=self._perms_out)
return x

def _ifft(self, x):
"""Apply needed transpositions and ifft operation."""
x = tf.transpose(x, perm=self._perms_in)
x = self._ifft_layer(tf.cast(x, tf.complex64))
x = tf.transpose(x, perm=self._perms_out)
return x

def build(self, input_shape):
"""Build the FNO layer based on an input shape

Parameters
----------
input_shape : tuple
Shape tuple of the input tensor
"""
self._n_channels = input_shape[-1]
dims = list(range(len(input_shape)))
self._perms_in = [dims[-1], *dims[:-1]]
self._perms_out = [*dims[1:], dims[0]]

if len(input_shape) == 4:
self._fft_layer = tf.signal.fft2d
self._ifft_layer = tf.signal.ifft2d
elif len(input_shape) == 5:
self._fft_layer = tf.signal.fft3d
self._ifft_layer = tf.signal.ifft3d
else:
msg = ('FNO layer can only accept 4D or 5D data '
'for image or video input but received input shape: {}'
.format(input_shape))
logger.error(msg)
raise RuntimeError(msg)

self._mlp_layers = [
tf.keras.layers.Dense(self._filters, activation=self._activation),
tf.keras.layers.Dense(self._n_channels)]

def _mlp_block(self, x):
"""Run mlp layers on input"""
for layer in self._mlp_layers:
x = layer(x)
return x

def call(self, x):
"""Call the custom FourierNeuralOperator layer

Parameters
----------
x : tf.Tensor
Input tensor.

Returns
-------
x : tf.Tensor
Output tensor, this is the FNO weights added to the original input
tensor.
"""
t_in = x
x = self._fft(x)
x = self._mlp_block(x)
x = self._softshrink(x)
x = self._ifft(x)
x = tf.cast(x, dtype=t_in.dtype)

return x + t_in


class Sup3rAdder(tf.keras.layers.Layer):
"""Layer to add high-resolution data to a sup3r model in the middle of a
super resolution forward pass."""
Expand Down
52 changes: 44 additions & 8 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import pytest
import tensorflow as tf

from phygnn.layers.custom_layers import (SkipConnection,
SpatioTemporalExpansion,
FlattenAxis,
ExpandDims,
TileLayer,
GaussianNoiseAxis)
from phygnn.layers.handlers import Layers, HiddenLayers
from phygnn.layers.custom_layers import (
ExpandDims,
FlattenAxis,
GaussianNoiseAxis,
SkipConnection,
SpatioTemporalExpansion,
TileLayer,
)
from phygnn.layers.handlers import HiddenLayers, Layers


@pytest.mark.parametrize(
Expand Down Expand Up @@ -208,7 +210,7 @@ def test_temporal_depth_to_time(t_mult, s_mult, t_roll):
n_filters = 2 * s_mult**2 * t_mult
shape = (1, 4, 4, 3, n_filters)
n = np.product(shape)
x = np.arange(n).reshape((shape))
x = np.arange(n).reshape(shape)
y = layer(x)
assert y.shape[0] == x.shape[0]
assert y.shape[1] == s_mult * x.shape[1]
Expand Down Expand Up @@ -387,3 +389,37 @@ def test_squeeze_excite_3d():
x = layer(x)
with pytest.raises(tf.errors.InvalidArgumentError):
tf.assert_equal(x_in, x)


def test_fno_2d():
"""Test the FNO layer with 2D data (4D tensor input)"""
hidden_layers = [
{'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01,
'activation': 'relu'}]
layers = HiddenLayers(hidden_layers)
assert len(layers.layers) == 1

x = np.random.normal(0, 1, size=(1, 4, 4, 3))

for layer in layers:
x_in = x
x = layer(x)
with pytest.raises(tf.errors.InvalidArgumentError):
tf.assert_equal(x_in, x)


def test_fno_3d():
"""Test the FNO layer with 3D data (5D tensor input)"""
hidden_layers = [
{'class': 'FNO', 'filters': 8, 'sparsity_threshold': 0.01,
'activation': 'relu'}]
layers = HiddenLayers(hidden_layers)
assert len(layers.layers) == 1

x = np.random.normal(0, 1, size=(1, 4, 4, 6, 3))

for layer in layers:
x_in = x
x = layer(x)
with pytest.raises(tf.errors.InvalidArgumentError):
tf.assert_equal(x_in, x)
Loading