Skip to content

Commit

Permalink
irfft has no gradient so -> ifft
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Nov 18, 2023
1 parent e6d3849 commit 1c791fa
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,14 +648,14 @@ def build(self, input_shape):
input_shape : tuple
Shape tuple of the input tensor
"""
self._n_channels = input_shape[-1] // 2 + 1
self._n_channels = input_shape[-1]

if len(input_shape) == 4:
self._fft_layer = tf.signal.rfft2d
self._ifft_layer = tf.signal.irfft2d
self._fft_layer = tf.signal.fft2d
self._ifft_layer = tf.signal.ifft2d
elif len(input_shape) == 5:
self._fft_layer = tf.signal.rfft3d
self._ifft_layer = tf.signal.irfft3d
self._fft_layer = tf.signal.fft3d
self._ifft_layer = tf.signal.ifft3d
else:
msg = ('FNO layer can only accept 4D or 5D data '
'for image or video input but received input shape: {}'
Expand All @@ -681,14 +681,13 @@ def call(self, x):
Output tensor, this is the FNO weights added to the original input
tensor.
"""

t_in = x
x = self._fft_layer(x)
x = self._fft_layer(tf.cast(x, tf.complex64))
for layer in self._mlp_layers:
x = layer(x)
x = self._softshrink(x, lambd=self._sparsity_threshold)
x = tf.cast(x, dtype=tf.complex64)
x = self._ifft_layer(x)
x = self._ifft_layer(tf.cast(x, tf.complex64))
x = tf.cast(x, dtype=t_in.dtype)

return x + t_in

Expand Down

0 comments on commit 1c791fa

Please sign in to comment.