-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
blockformer #1504
Open
LeonWlw
wants to merge
7
commits into
wenet-e2e:main
Choose a base branch
from
LeonWlw:ml_blockformer
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
blockformer #1504
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
bcc219a
add se_layer for blockformer
LeonWlw 9fdb662
add se_layer for blockformer
LeonWlw f85d5ef
fix formatting issues
LeonWlw 3987af4
Merge branch 'wenet-e2e:main' into ml_blockformer
LeonWlw 058cbfe
Merge branch 'ml_blockformer' of github.com:LeonWlw/wenet into ml_blo…
LeonWlw 6e6f759
fix formatting issues
LeonWlw 89e17d4
add aishell results
LeonWlw File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# network architecture | ||
# encoder related | ||
encoder: conformer | ||
encoder_conf: | ||
output_size: 256 # dimension of attention | ||
attention_heads: 4 | ||
linear_units: 2048 # the number of units of position-wise feed forward | ||
num_blocks: 12 # the number of encoder blocks | ||
dropout_rate: 0.1 | ||
positional_dropout_rate: 0.1 | ||
attention_dropout_rate: 0.0 | ||
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 | ||
normalize_before: true | ||
cnn_module_kernel: 15 | ||
use_cnn_module: True | ||
activation_type: 'swish' | ||
pos_enc_layer_type: 'rel_pos' | ||
selfattention_layer_type: 'rel_selfattn' | ||
use_se_module: true | ||
se_module_channel: 12 # the same number with encoder blocks | ||
|
||
# decoder related | ||
decoder: transformer | ||
decoder_conf: | ||
attention_heads: 4 | ||
linear_units: 2048 | ||
num_blocks: 6 | ||
dropout_rate: 0.1 | ||
positional_dropout_rate: 0.1 | ||
input_layer: 'rel_embed' | ||
self_attention_dropout_rate: 0.0 | ||
src_attention_dropout_rate: 0.0 | ||
use_se_module: true | ||
se_module_channel: 6 # the same number with decoder blocks | ||
|
||
# hybrid CTC/attention | ||
model_conf: | ||
ctc_weight: 0.3 | ||
lsm_weight: 0.1 # label smoothing option | ||
length_normalized_loss: false | ||
|
||
dataset_conf: | ||
filter_conf: | ||
max_length: 40960 | ||
min_length: 0 | ||
token_max_length: 200 | ||
token_min_length: 1 | ||
resample_conf: | ||
resample_rate: 16000 | ||
speed_perturb: true | ||
fbank_conf: | ||
num_mel_bins: 80 | ||
frame_shift: 10 | ||
frame_length: 25 | ||
dither: 0.1 | ||
spec_aug: true | ||
spec_aug_conf: | ||
num_t_mask: 2 | ||
num_f_mask: 2 | ||
max_t: 50 | ||
max_f: 10 | ||
shuffle: true | ||
shuffle_conf: | ||
shuffle_size: 1500 | ||
sort: false | ||
sort_conf: | ||
sort_size: 500 # sort_size should be less than shuffle_size | ||
batch_conf: | ||
batch_type: 'static' # static or dynamic | ||
batch_size: 16 | ||
|
||
grad_clip: 5 | ||
accum_grad: 4 | ||
max_epoch: 360 | ||
log_interval: 100 | ||
|
||
optim: adam | ||
optim_conf: | ||
lr: 0.002 | ||
scheduler: warmuplr # pytorch v1.1.0+ required | ||
scheduler_conf: | ||
warmup_steps: 50000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright (c) 2022 Mininglamp Com (Liuwei Wei, Xiaoming Ren) | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
"""Squeeze-and-Excitation layer definition.""" | ||
|
||
import torch | ||
|
||
|
||
class SELayer(torch.nn.Module): | ||
def __init__(self, channel: int, reduction: int = 1): | ||
super().__init__() | ||
self.avg_pool = torch.nn.AdaptiveAvgPool2d(1) | ||
self.fc = torch.nn.Sequential( | ||
torch.nn.Linear(channel, channel // reduction, bias=False), | ||
torch.nn.ReLU(inplace=True), | ||
torch.nn.Linear(channel // reduction, channel, bias=False), | ||
torch.nn.Sigmoid() | ||
) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
b, c, _, _ = x.size() | ||
y = self.avg_pool(x).view(b, c) | ||
y = self.fc(y).view(b, c, 1, 1) | ||
return x * y.expand_as(x) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
avg_pool over T and D dim should consider pad_mask ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for your remind, we will update pad_mask to the code and retrain it .