Skip to content

Commit

Permalink
[wenet] add dropout for subsampling.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct authored Sep 28, 2023
1 parent 6158f9e commit 915d27c
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions wenet/transformer/subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ def __init__(self, idim: int, odim: int, dropout_rate: float,
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
torch.nn.Dropout(dropout_rate),
)
self.pos_enc = pos_enc_class
# The right context for every conv layer is computed by:
# (kernel_size - 1) * frame_rate_of_this_layer
Expand Down Expand Up @@ -188,8 +190,10 @@ def __init__(self, idim: int, odim: int, dropout_rate: float,
torch.nn.Conv2d(odim, odim, 5, 3),
torch.nn.ReLU(),
)
self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
odim)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
torch.nn.Dropout(dropout_rate),
)
self.pos_enc = pos_enc_class
# 10 = (3 - 1) * 1 + (5 - 1) * 2
self.subsampling_rate = 6
Expand All @@ -216,7 +220,7 @@ def forward(
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]

Expand All @@ -243,8 +247,11 @@ def __init__(self, idim: int, odim: int, dropout_rate: float,
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU(),
)
self.linear = torch.nn.Linear(
odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
self.out = torch.nn.Sequential(
torch.nn.Linear(
odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
torch.nn.Dropout(dropout_rate),
)
self.pos_enc = pos_enc_class
self.subsampling_rate = 8
# 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
Expand Down Expand Up @@ -272,6 +279,6 @@ def forward(
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]

0 comments on commit 915d27c

Please sign in to comment.