diff --git a/tensorflow_tts/models/hifigan.py b/tensorflow_tts/models/hifigan.py index 155c508f..0647a3e5 100644 --- a/tensorflow_tts/models/hifigan.py +++ b/tensorflow_tts/models/hifigan.py @@ -88,7 +88,7 @@ def __init__( self._apply_weightnorm(self.blocks_1) self._apply_weightnorm(self.blocks_2) - def call(self, x): + def call(self, x, training=False): """Calculate forward propagation. Args: x (Tensor): Input tensor (B, T, C). @@ -116,6 +116,23 @@ def _apply_weightnorm(self, list_layers): pass +class TFMultiHifiResBlock(tf.keras.layers.Layer): + """Tensorflow Multi Hifigan resblock 1 module.""" + + def __init__(self, list_resblock, **kwargs): + super().__init__(**kwargs) + self.list_resblock = list_resblock + + def call(self, x, training=False): + xs = None + for resblock in self.list_resblock: + if xs is None: + xs = resblock(x, training=training) + else: + xs += resblock(x, training=training) + return xs / len(self.list_resblock) + + class TFHifiGANGenerator(tf.keras.Model): def __init__(self, config, **kwargs): super().__init__(**kwargs) @@ -158,21 +175,26 @@ def __init__(self, config, **kwargs): ), ] - # ad residual stack layer - for j in range(config.stacks): - layers += [ - TFHifiResBlock( - kernel_size=config.stack_kernel_size[j], - filters=config.filters // (2 ** (i + 1)), - dilation_rate=config.stack_dilation_rate[j], - use_bias=config.use_bias, - nonlinear_activation=config.nonlinear_activation, - nonlinear_activation_params=config.nonlinear_activation_params, - is_weight_norm=config.is_weight_norm, - initializer_seed=config.initializer_seed, - name="hifigan_resblock_._{}._._{}".format(i, j), - ) - ] + # add residual stack layer + layers += [ + TFMultiHifiResBlock( + list_resblock=[ + TFHifiResBlock( + kernel_size=config.stack_kernel_size[j], + filters=config.filters // (2 ** (i + 1)), + dilation_rate=config.stack_dilation_rate[j], + use_bias=config.use_bias, + nonlinear_activation=config.nonlinear_activation, + nonlinear_activation_params=config.nonlinear_activation_params, + is_weight_norm=config.is_weight_norm, + initializer_seed=config.initializer_seed, + name="hifigan_resblock_._{}".format(j), + ) + for j in range(config.stacks) + ], + name="multi_hifigan_resblock_._{}".format(i), + ) + ] # add final layer layers += [ getattr(tf.keras.layers, config.nonlinear_activation)(