diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index ababf28dd4..53dd806582 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -599,11 +599,12 @@ def inference( mode="linear" ).transpose(1, 2) + gpt_latents_list.append(gpt_latents.cpu()) wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze()) return { "wav": torch.cat(wavs, dim=0).numpy(), - "gpt_latents": torch.cat(gpt_latents_list, dim=0).numpy(), + "gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(), "speaker_embedding": speaker_embedding, }