Skip to content

Commit

Permalink
[wenet] support word embedding for encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
robin1001 committed Sep 12, 2023
1 parent 7529289 commit bfcc09b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 18 deletions.
3 changes: 3 additions & 0 deletions wenet/transformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from wenet.transformer.subsampling import Conv2dSubsampling4
from wenet.transformer.subsampling import Conv2dSubsampling6
from wenet.transformer.subsampling import Conv2dSubsampling8
from wenet.transformer.subsampling import EmbedinigNoSubsampling
from wenet.transformer.subsampling import LinearNoSubsampling
from wenet.utils.common import get_activation
from wenet.utils.mask import make_pad_mask
Expand Down Expand Up @@ -104,6 +105,8 @@ def __init__(
subsampling_class = Conv2dSubsampling6
elif input_layer == "conv2d8":
subsampling_class = Conv2dSubsampling8
elif input_layer == "embed":
subsampling_class = EmbedinigNoSubsampling
else:
raise ValueError("unknown input_layer: " + input_layer)

Expand Down
73 changes: 55 additions & 18 deletions wenet/transformer/subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)


"""Subsampling layer definition."""

from typing import Tuple, Union
Expand All @@ -22,6 +20,7 @@


class BaseSubsampling(torch.nn.Module):

def __init__(self):
super().__init__()
self.right_context = 0
Expand All @@ -32,6 +31,40 @@ def position_encoding(self, offset: Union[int, torch.Tensor],
return self.pos_enc.position_encoding(offset, size)


class EmbedinigNoSubsampling(BaseSubsampling):
"""Embedding input without subsampling
"""

def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
super().__init__()
self.embed = torch.nn.Embedding(idim, odim)
self.pos_enc = pos_enc_class

def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Input x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: linear input tensor (#batch, time', odim),
where time' = time .
torch.Tensor: linear input mask (#batch, 1, time'),
where time' = time .
"""
x = self.embed(x)
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask


class LinearNoSubsampling(BaseSubsampling):
"""Linear transform the input without subsampling
Expand All @@ -41,6 +74,7 @@ class LinearNoSubsampling(BaseSubsampling):
dropout_rate (float): Dropout rate.
"""

def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an linear object."""
Expand All @@ -55,10 +89,10 @@ def __init__(self, idim: int, odim: int, dropout_rate: float,
self.subsampling_rate = 1

def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Input x.
Expand Down Expand Up @@ -87,6 +121,7 @@ class Conv2dSubsampling4(BaseSubsampling):
dropout_rate (float): Dropout rate.
"""

def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling4 object."""
Expand All @@ -107,10 +142,10 @@ def __init__(self, idim: int, odim: int, dropout_rate: float,
self.right_context = 6

def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Expand Down Expand Up @@ -142,6 +177,7 @@ class Conv2dSubsampling6(BaseSubsampling):
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.
"""

def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling6 object."""
Expand All @@ -160,10 +196,10 @@ def __init__(self, idim: int, odim: int, dropout_rate: float,
self.right_context = 10

def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
Expand Down Expand Up @@ -194,6 +230,7 @@ class Conv2dSubsampling8(BaseSubsampling):
dropout_rate (float): Dropout rate.
"""

def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling8 object."""
Expand All @@ -214,10 +251,10 @@ def __init__(self, idim: int, odim: int, dropout_rate: float,
self.right_context = 14

def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Expand Down

0 comments on commit bfcc09b

Please sign in to comment.