Skip to content

Commit

Permalink
added option for numpy padding in FlexiblePadding layer
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Oct 24, 2024
1 parent c5b0cdb commit 6f7b7b7
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class FlexiblePadding(tf.keras.layers.Layer):
"""Class to perform padding on tensors
"""

def __init__(self, paddings, mode='REFLECT'):
def __init__(self, paddings, mode='REFLECT', option='tf'):
"""
Parameters
----------
Expand All @@ -21,13 +21,29 @@ def __init__(self, paddings, mode='REFLECT'):
rank of the tensor and elements give the number
of leading and trailing pads
mode : str
tf.pad() padding mode. Can be REFLECT, CONSTANT,
tf.pad() / np.pad() padding mode. Can be REFLECT, CONSTANT,
or SYMMETRIC
option : str
Option for TensorFlow padding ("tf") or numpy ("np"). Default is tf
for tensorflow training. We have observed silent failures of
tf.pad() with larger array sizes, so "np" might be preferable at
inference time on large chunks.
"""
super().__init__()
self.paddings = tf.constant(paddings)
self.rank = len(paddings)
self.mode = mode
self.mode = mode.lower()
self.option = option.lower()

if self.option == 'tf':
self._pad_fun = tf.pad
elif self.option == 'np':
self._pad_fun = np.pad
else:
msg = ('FlexiblePadding option must be "tf" or "np" but '
f'received: {self.option}')
logger.error(msg)
raise KeyError(msg)

def compute_output_shape(self, input_shape):
"""Computes output shape after padding
Expand Down Expand Up @@ -62,8 +78,7 @@ def call(self, x):
by compute_output_shape
"""
return tf.pad(x, self.paddings,
mode=self.mode)
return self._pad_fun(x, self.paddings, mode=self.mode)


class ExpandDims(tf.keras.layers.Layer):
Expand Down

0 comments on commit 6f7b7b7

Please sign in to comment.