diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 83812f377f..8e9d6bd382 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -756,11 +756,13 @@ def load_checkpoint( model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth") vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json") - speaker_file_path = speaker_file_path or os.path.join(checkpoint_dir, "speakers_xtts.pth") + + if speaker_file_path is None and checkpoint_dir is not None: + speaker_file_path = os.path.join(checkpoint_dir, "speakers_xtts.pth") self.language_manager = LanguageManager(config) self.speaker_manager = None - if os.path.exists(speaker_file_path): + if speaker_file_path is not None and os.path.exists(speaker_file_path): self.speaker_manager = SpeakerManager(speaker_file_path) if os.path.exists(vocab_path):