[transformer] add norm eps (#2397)
Mddct authored Mar 8, 2024
1 parent d01715a commit a93af33
Showing 5 changed files with 63 additions and 37 deletions.
22 changes: 13 additions & 9 deletions wenet/transformer/convolution.py
@@ -25,13 +25,16 @@
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model."""

- def __init__(self,
- channels: int,
- kernel_size: int = 15,
- activation: nn.Module = nn.ReLU(),
- norm: str = "batch_norm",
- causal: bool = False,
- bias: bool = True):
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int = 15,
+ activation: nn.Module = nn.ReLU(),
+ norm: str = "batch_norm",
+ causal: bool = False,
+ bias: bool = True,
+ norm_eps: float = 1e-5,
+ ):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
@@ -73,10 +76,11 @@ def __init__(self,
assert norm in ['batch_norm', 'layer_norm', 'rms_norm']
if norm == "batch_norm":
self.use_layer_norm = False
- self.norm = WENET_NORM_CLASSES['batch_norm'](channels)
+ self.norm = WENET_NORM_CLASSES['batch_norm'](channels,
+ eps=norm_eps)
else:
self.use_layer_norm = True
- self.norm = WENET_NORM_CLASSES[norm](channels)
+ self.norm = WENET_NORM_CLASSES[norm](channels, eps=norm_eps)

self.pointwise_conv2 = nn.Conv1d(
channels,
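The change above only threads a configurable epsilon into whichever normalization class the registry returns. Below is a minimal, self-contained sketch of that dispatch pattern, using stock PyTorch norms as hypothetical stand-ins (WeNet's real WENET_NORM_CLASSES registry also maps 'rms_norm' to its own RMSNorm):

import torch.nn as nn

# Hypothetical registry; stands in for WENET_NORM_CLASSES in this sketch only.
NORM_CLASSES = {
    "batch_norm": nn.BatchNorm1d,
    "layer_norm": nn.LayerNorm,
}

def build_norm(norm: str, channels: int, norm_eps: float = 1e-5) -> nn.Module:
    # Both BatchNorm1d and LayerNorm accept an `eps` keyword, so the epsilon
    # can be forwarded uniformly, which is what the patch does.
    return NORM_CLASSES[norm](channels, eps=norm_eps)

print(build_norm("layer_norm", 256, norm_eps=1e-6).eps)  # 1e-06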
17 changes: 14 additions & 3 deletions wenet/transformer/decoder.py
@@ -83,6 +83,7 @@ def __init__(
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
+ norm_eps: float = 1e-5,
):
super().__init__()
attention_dim = encoder_output_size
@@ -98,7 +99,7 @@ def __init__(
assert layer_norm_type in ['layer_norm', 'rms_norm']
self.normalize_before = normalize_before
self.after_norm = WENET_NORM_CLASSES[layer_norm_type](attention_dim,
- eps=1e-5)
+ eps=norm_eps)
self.use_output_layer = use_output_layer
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
@@ -122,6 +123,8 @@ def __init__(
activation, mlp_bias),
dropout_rate,
normalize_before,
+ layer_norm_type,
+ norm_eps,
) for _ in range(self.num_blocks)
])

@@ -329,6 +332,8 @@ def __init__(
gradient_checkpointing: bool = False,
tie_word_embedding: bool = False,
use_sdpa: bool = False,
+ layer_norm_type: str = 'layer_norm',
+ norm_eps: float = 1e-5,
):

super().__init__()
@@ -352,7 +357,10 @@ def __init__(
value_bias=value_bias,
gradient_checkpointing=gradient_checkpointing,
tie_word_embedding=tie_word_embedding,
- use_sdpa=use_sdpa)
+ use_sdpa=use_sdpa,
+ layer_norm_type=layer_norm_type,
+ norm_eps=norm_eps,
+ )

self.right_decoder = TransformerDecoder(
vocab_size,
@@ -373,7 +381,10 @@ def __init__(
mlp_bias=mlp_bias,
gradient_checkpointing=gradient_checkpointing,
tie_word_embedding=tie_word_embedding,
- use_sdpa=use_sdpa)
+ use_sdpa=use_sdpa,
+ layer_norm_type=layer_norm_type,
+ norm_eps=norm_eps,
+ )

def forward(
self,
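As a usage sketch only (the vocabulary and model sizes are made up, and it assumes vocab_size and encoder_output_size are the constructor's leading parameters and that 'layer_norm' in WENET_NORM_CLASSES maps to torch.nn.LayerNorm), the new keyword would be passed like this:

from wenet.transformer.decoder import TransformerDecoder

decoder = TransformerDecoder(
    vocab_size=5002,          # hypothetical vocabulary size
    encoder_output_size=256,  # hypothetical encoder dimension
    norm_eps=1e-6,            # previously hard-coded to 1e-5 inside the module
)
# With the default layer_norm_type, after_norm is a LayerNorm instance,
# so the configured epsilon is visible on the module.
print(decoder.after_norm.eps)  # 1e-06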
7 changes: 4 additions & 3 deletions wenet/transformer/decoder_layer.py
@@ -49,6 +49,7 @@ def __init__(
dropout_rate: float,
normalize_before: bool = True,
layer_norm_type: str = 'layer_norm',
+ norm_eps: float = 1e-5,
):
"""Construct an DecoderLayer object."""
super().__init__()
@@ -57,9 +58,9 @@ def __init__(
self.src_attn = src_attn
self.feed_forward = feed_forward
assert layer_norm_type in ['layer_norm', 'rms_norm']
- self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
- self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
- self.norm3 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+ self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
+ self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
+ self.norm3 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before

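With the default 'layer_norm' type (and assuming that entry maps to torch.nn.LayerNorm), the three sub-layer norms built above are equivalent to this small sketch, all sharing one epsilon:

import torch.nn as nn

size, norm_eps = 256, 1e-6  # illustrative values
# norm1/norm2/norm3 sit around the self-attention, cross-attention and
# feed-forward sub-layers of the decoder block, respectively.
norm1, norm2, norm3 = (nn.LayerNorm(size, eps=norm_eps) for _ in range(3))
assert all(n.eps == norm_eps for n in (norm1, norm2, norm3))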
38 changes: 23 additions & 15 deletions wenet/transformer/encoder.py
@@ -57,6 +57,7 @@ def __init__(
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
+ norm_eps: float = 1e-5,
):
"""
Args:
@@ -107,7 +108,7 @@ def __init__(
assert layer_norm_type in ['layer_norm', 'rms_norm']
self.normalize_before = normalize_before
self.after_norm = WENET_NORM_CLASSES[layer_norm_type](output_size,
- eps=1e-5)
+ eps=norm_eps)
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
@@ -373,6 +374,7 @@ def __init__(
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
+ norm_eps: float = 1e-5,
):
""" Construct TransformerEncoder
@@ -384,22 +386,24 @@
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing,
- use_sdpa, layer_norm_type)
+ use_sdpa, layer_norm_type, norm_eps)
activation = WENET_ACTIVATION_CLASSES[activation_type]()
mlp_class = WENET_MLP_CLASSES[mlp_type]
self.encoders = torch.nn.ModuleList([
- TransformerEncoderLayer(output_size,
- WENET_ATTENTION_CLASSES["selfattn"](
- attention_heads, output_size,
- attention_dropout_rate, query_bias,
- key_bias, value_bias, use_sdpa),
- mlp_class(output_size, linear_units,
- dropout_rate, activation,
- mlp_bias),
- dropout_rate,
- normalize_before,
- layer_norm_type=layer_norm_type)
- for _ in range(num_blocks)
+ TransformerEncoderLayer(
+ output_size,
+ WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
+ output_size,
+ attention_dropout_rate,
+ query_bias, key_bias,
+ value_bias, use_sdpa),
+ mlp_class(output_size, linear_units, dropout_rate, activation,
+ mlp_bias),
+ dropout_rate,
+ normalize_before,
+ layer_norm_type=layer_norm_type,
+ norm_eps=norm_eps,
+ ) for _ in range(num_blocks)
])


@@ -439,6 +443,8 @@ def __init__(
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
+ layer_norm_type: str = 'layer_norm',
+ norm_eps: float = 1e-5,
):
"""Construct ConformerEncoder
@@ -463,7 +469,7 @@ def __init__(
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing,
- use_sdpa)
+ use_sdpa, layer_norm_type, norm_eps)
activation = WENET_ACTIVATION_CLASSES[activation_type]()

# self-attention module definition
@@ -500,5 +506,7 @@ def __init__(
*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
normalize_before,
+ layer_norm_type=layer_norm_type,
+ norm_eps=norm_eps,
) for _ in range(num_blocks)
])
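A hedged construction sketch for the encoder side (assuming input_size is the first positional parameter of the encoder constructors, as elsewhere in WeNet; the sizes are illustrative):

from wenet.transformer.encoder import ConformerEncoder

encoder = ConformerEncoder(
    input_size=80,                 # e.g. 80-dim fbank features
    output_size=256,
    layer_norm_type="layer_norm",
    norm_eps=1e-6,                 # forwarded to after_norm and the layer norms
)
print(encoder.after_norm.eps)      # 1e-06 instead of the old fixed 1e-5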
16 changes: 9 additions & 7 deletions wenet/transformer/encoder_layer.py
@@ -47,14 +47,15 @@ def __init__(
dropout_rate: float,
normalize_before: bool = True,
layer_norm_type: str = 'layer_norm',
+ norm_eps: float = 1e-5,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
assert layer_norm_type in ['layer_norm', 'rms_norm']
- self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
- self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=1e-5)
+ self.norm1 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
+ self.norm2 = WENET_NORM_CLASSES[layer_norm_type](size, eps=norm_eps)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
@@ -140,6 +141,7 @@ def __init__(
dropout_rate: float = 0.1,
normalize_before: bool = True,
layer_norm_type: str = 'layer_norm',
+ norm_eps: float = 1e-5,
):
"""Construct an EncoderLayer object."""
super().__init__()
@@ -149,20 +151,20 @@
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = WENET_NORM_CLASSES[layer_norm_type](
- size, eps=1e-5)  # for the FNN module
+ size, eps=norm_eps)  # for the FNN module
self.norm_mha = WENET_NORM_CLASSES[layer_norm_type](
- size, eps=1e-5)  # for the MHA module
+ size, eps=norm_eps)  # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = WENET_NORM_CLASSES[layer_norm_type](
- size, eps=1e-5)
+ size, eps=norm_eps)
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = WENET_NORM_CLASSES[layer_norm_type](
- size, eps=1e-5)  # for the CNN module
+ size, eps=norm_eps)  # for the CNN module
self.norm_final = WENET_NORM_CLASSES[layer_norm_type](
- size, eps=1e-5)  # for the final output of the block
+ size, eps=norm_eps)  # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
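In every case the epsilon ends up inside the square root of the normalization denominator, so norm_eps is primarily a numerical-stability knob (larger values are sometimes preferred for low-precision training). A tiny stand-alone PyTorch check, independent of WeNet, showing the effect on a nearly constant input:

import torch

x = torch.tensor([[1.0, 1.0, 1.0, 1.0 + 1e-3]])  # almost zero variance
for eps in (1e-5, 1e-1):
    y = torch.nn.LayerNorm(4, eps=eps)(x)
    # A larger eps damps the normalized deviations of a near-constant input.
    print(f"eps={eps:g}  max|y|={y.abs().max().item():.4f}")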
