Skip to content

Commit

Permalink
[transformer] try to fix mga in onnxruntime (#2519)
Browse files Browse the repository at this point in the history
* [transformer] try to fix mga in onnxruntime

* fix v shape
  • Loading branch information
Mddct authored May 8, 2024
1 parent 2258c72 commit f2372ae
Showing 1 changed file with 25 additions and 10 deletions.
35 changes: 25 additions & 10 deletions wenet/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,16 +204,31 @@ def _update_kv_and_cache(
new_cache = (k, v)
# for multi query or multi group attention
if self.h_kv != self.h and self.h_kv != 1:
k = torch.repeat_interleave(
k,
self.h // self.h_kv,
dim=-3,
)
v = torch.repeat_interleave(
v,
self.h // self.h_kv,
dim=-3,
)
# NOTE: onnxruntime issues:
# https://github.com/wenet-e2e/wenet/issues/2517
# k = torch.repeat_interleave(
# k,
# self.h // self.h_kv,
# dim=-3,
# )
# v = torch.repeat_interleave(
# v,
# self.h // self.h_kv,
# dim=-3,
# )
n_repeat = self.h // self.h_kv
k_shape = k.size()
k = k.unsqueeze(-3).expand(
k_shape[:-2] + torch.Size([n_repeat]) +
k_shape[-2:]).reshape(k_shape[:-3] +
torch.Size([self.h_kv * n_repeat]) +
k_shape[-2:])
v_shape = v.size()
v = v.unsqueeze(-3).expand(
v_shape[:-2] + torch.Size([n_repeat]) +
v_shape[-2:]).reshape(v_shape[:-3] +
torch.Size([self.h_kv * n_repeat]) +
v_shape[-2:])
return k, v, new_cache

def forward(
Expand Down

0 comments on commit f2372ae

Please sign in to comment.