Skip to content

Commit

Permalink
Xlgeng fix (#2018)
Browse files Browse the repository at this point in the history
* [fix] 修复utils/common.py中pad_list未考虑time维度后可跟其他维度的情况

* [fix] 修复utils/common.py中pad_list未考虑time维度后可跟其他维度的情况(#2007)

* [fix] 修复jit报错,初步判断该爆错由`*(xs[0].shape[1:])`代码表示的动态张量引起,现修改common.py/pad_list的注释, 暂时不考虑time维度后可跟其他维度, 先让代码恢复可运行状态 (issue #2015)

* [fix] 完全修复jit报错,在jit要求条件下实现time维度后可跟其他维度(#2015)

* [fix] 完全修复jit报错,在jit要求条件下实现time维度后可跟其他维度(#2015)

* [fix] 完全修复jit报错,在jit要求条件下实现time维度后可跟其他维度(#2015)
  • Loading branch information
gengxuelong authored Sep 19, 2023
1 parent ee73151 commit fd186f3
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions wenet/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,19 @@ def pad_list(xs: List[torch.Tensor], pad_value: int):
"""
max_len = max([len(item) for item in xs])
batchs = len(xs)
pad_res = torch.zeros(batchs, max_len, *(xs[0].shape[1:]),
dtype=xs[0].dtype, device=xs[0].device)
ndim = xs[0].ndim
if ndim == 1:
pad_res = torch.zeros(batchs, max_len,
dtype=xs[0].dtype, device=xs[0].device)
elif ndim == 2:
pad_res = torch.zeros(batchs, max_len, xs[0].shape[1],
dtype=xs[0].dtype, device=xs[0].device)
elif ndim == 3:
pad_res = torch.zeros(batchs, max_len, xs[0].shape[1],
xs[0].shape[2], dtype=xs[0].dtype,
device=xs[0].device)
else:
raise ValueError(f"Unsupported ndim: {ndim}")
pad_res.fill_(pad_value)
for i in range(batchs):
pad_res[i, :len(xs[i])] = xs[i]
Expand Down

0 comments on commit fd186f3

Please sign in to comment.