Skip to content

Commit

Permalink
🥀 Correct hifigan resblock logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
dathudeptrai committed Feb 28, 2021
1 parent b7b0288 commit 9a107d9
Showing 1 changed file with 38 additions and 16 deletions.
54 changes: 38 additions & 16 deletions tensorflow_tts/models/hifigan.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
self._apply_weightnorm(self.blocks_1)
self._apply_weightnorm(self.blocks_2)

def call(self, x):
def call(self, x, training=False):
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, T, C).
Expand Down Expand Up @@ -116,6 +116,23 @@ def _apply_weightnorm(self, list_layers):
pass


class TFMultiHifiResBlock(tf.keras.layers.Layer):
"""Tensorflow Multi Hifigan resblock 1 module."""

def __init__(self, list_resblock, **kwargs):
super().__init__(**kwargs)
self.list_resblock = list_resblock

def call(self, x, training=False):
xs = None
for resblock in self.list_resblock:
if xs is None:
xs = resblock(x, training=training)
else:
xs += resblock(x, training=training)
return xs / len(self.list_resblock)


class TFHifiGANGenerator(tf.keras.Model):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
Expand Down Expand Up @@ -158,21 +175,26 @@ def __init__(self, config, **kwargs):
),
]

# ad residual stack layer
for j in range(config.stacks):
layers += [
TFHifiResBlock(
kernel_size=config.stack_kernel_size[j],
filters=config.filters // (2 ** (i + 1)),
dilation_rate=config.stack_dilation_rate[j],
use_bias=config.use_bias,
nonlinear_activation=config.nonlinear_activation,
nonlinear_activation_params=config.nonlinear_activation_params,
is_weight_norm=config.is_weight_norm,
initializer_seed=config.initializer_seed,
name="hifigan_resblock_._{}._._{}".format(i, j),
)
]
# add residual stack layer
layers += [
TFMultiHifiResBlock(
list_resblock=[
TFHifiResBlock(
kernel_size=config.stack_kernel_size[j],
filters=config.filters // (2 ** (i + 1)),
dilation_rate=config.stack_dilation_rate[j],
use_bias=config.use_bias,
nonlinear_activation=config.nonlinear_activation,
nonlinear_activation_params=config.nonlinear_activation_params,
is_weight_norm=config.is_weight_norm,
initializer_seed=config.initializer_seed,
name="hifigan_resblock_._{}".format(j),
)
for j in range(config.stacks)
],
name="multi_hifigan_resblock_._{}".format(i),
)
]
# add final layer
layers += [
getattr(tf.keras.layers, config.nonlinear_activation)(
Expand Down

0 comments on commit 9a107d9

Please sign in to comment.