diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index 43a7c8091..b8550bccf 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -204,16 +204,31 @@ def _update_kv_and_cache(
             new_cache = (k, v)
         # for multi query or multi group attention
         if self.h_kv != self.h and self.h_kv != 1:
-            k = torch.repeat_interleave(
-                k,
-                self.h // self.h_kv,
-                dim=-3,
-            )
-            v = torch.repeat_interleave(
-                v,
-                self.h // self.h_kv,
-                dim=-3,
-            )
+            # NOTE: onnxruntime issues:
+            #     https://github.com/wenet-e2e/wenet/issues/2517
+            # k = torch.repeat_interleave(
+            #     k,
+            #     self.h // self.h_kv,
+            #     dim=-3,
+            # )
+            # v = torch.repeat_interleave(
+            #     v,
+            #     self.h // self.h_kv,
+            #     dim=-3,
+            # )
+            n_repeat = self.h // self.h_kv
+            k_shape = k.size()
+            k = k.unsqueeze(-3).expand(
+                k_shape[:-2] + torch.Size([n_repeat]) +
+                k_shape[-2:]).reshape(k_shape[:-3] +
+                                      torch.Size([self.h_kv * n_repeat]) +
+                                      k_shape[-2:])
+            v_shape = v.size()
+            v = v.unsqueeze(-3).expand(
+                v_shape[:-2] + torch.Size([n_repeat]) +
+                v_shape[-2:]).reshape(v_shape[:-3] +
+                                      torch.Size([self.h_kv * n_repeat]) +
+                                      v_shape[-2:])
         return k, v, new_cache

     def forward(
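
For reference, here is a minimal standalone sketch of the equivalence between the old and new paths. The head counts (h=8, h_kv=2), batch size, and tensor shapes are hypothetical and not taken from the diff; the point is that the unsqueeze/expand/reshape sequence reproduces torch.repeat_interleave along the head dimension while avoiding the op that ONNX Runtime trips over (wenet-e2e/wenet#2517).

```python
import torch

# Hypothetical sizes for illustration only (not taken from the diff):
# 8 query heads sharing 2 key/value heads, i.e. 4 query heads per KV head.
h, h_kv = 8, 2
n_repeat = h // h_kv

# Dummy cached key tensor shaped (batch, h_kv, time, d_k).
k = torch.randn(3, h_kv, 10, 64)

# Old path: repeat each KV head n_repeat times along the head dim (-3).
k_ref = torch.repeat_interleave(k, n_repeat, dim=-3)

# New path from the diff: insert a broadcast dim after the head dim,
# expand it to n_repeat, then fold it back into the head dimension.
k_shape = k.size()
k_new = k.unsqueeze(-3).expand(
    k_shape[:-2] + torch.Size([n_repeat]) +
    k_shape[-2:]).reshape(k_shape[:-3] +
                          torch.Size([h_kv * n_repeat]) +
                          k_shape[-2:])

# Both paths produce identical (batch, h, time, d_k) tensors.
assert k_new.shape == k_ref.shape == (3, h, 10, 64)
assert torch.equal(k_ref, k_new)
```

The expand only creates a broadcast view; the following reshape materializes the repeated heads, so the head order (all copies of KV head 0 first, then KV head 1, and so on) matches repeat_interleave exactly, and the graph typically exports as plain Unsqueeze/Expand/Reshape nodes instead of the repeat_interleave lowering reported as problematic in the issue above.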