From 30e9d677cd5fde57de519f6d336537a430aba0aa Mon Sep 17 00:00:00 2001 From: MananPatel6902 Date: Fri, 28 Jun 2024 15:40:03 +0530 Subject: [PATCH 1/3] hubconf.py to streamline model loading processes and improve configuration management. --- hubconf.py | 84 +++++++++++++++++++++++------------------------------- 1 file changed, 36 insertions(+), 48 deletions(-) diff --git a/hubconf.py b/hubconf.py index 8ea6a85..655b3a5 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,69 +1,57 @@ -dependencies = ['torch', 'torchaudio', 'numpy', 'vocos', 'safetensors'] - import logging import os from pathlib import Path from safetensors import safe_open - import torch from inference import Mars5TTS, InferenceConfig -ar_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.pt" -nar_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.pt" - -ar_sf_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.safetensors" -nar_sf_url = "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.safetensors" - -def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safetensors', - ar_path=None, nar_path=None) -> Mars5TTS: - """ Load mars5 english model on `device`, optionally show `progress`. """ - if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' - - assert ckpt_format in ['safetensors', 'pt'], "checkpoint format must be 'safetensors' or 'pt'" - - logging.info(f"Using device: {device}") - if pretrained == False: raise AssertionError('Only pretrained model currently supported.') - logging.info("Loading AR checkpoint...") - - if ar_path is None: - if ckpt_format == 'safetensors': - ar_ckpt = _load_safetensors_ckpt(ar_sf_url, progress=progress) - elif ckpt_format == 'pt': - ar_ckpt = torch.hub.load_state_dict_from_url( - ar_url, progress=progress, check_hash=False, map_location='cpu' - ) - else: ar_ckpt = torch.load(str(ar_path), map_location='cpu') - - logging.info("Loading NAR checkpoint...") - if nar_path is None: - if ckpt_format == 'safetensors': - nar_ckpt = _load_safetensors_ckpt(nar_sf_url, progress=progress) - elif ckpt_format == 'pt': - nar_ckpt = torch.hub.load_state_dict_from_url( - nar_url, progress=progress, check_hash=False, map_location='cpu' - ) - else: nar_ckpt = torch.load(str(nar_path), map_location='cpu') - logging.info("Initializing modules...") - mars5 = Mars5TTS(ar_ckpt, nar_ckpt, device=device) - return mars5, InferenceConfig +dependencies = ['torch', 'torchaudio', 'numpy', 'vocos', 'safetensors'] +# Centralized checkpoint URLs for easy management and updates +CHECKPOINT_URLS = { + "ar": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.pt", + "nar": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.pt", + "ar_sf": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.safetensors", + "nar_sf": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.safetensors" +} -def _load_safetensors_ckpt(url, progress): - """ Loads checkpoint from a safetensors file """ +def load_checkpoint(url, progress=True, ckpt_format='pt'): + """ Helper function to download and load a checkpoint, reducing duplication """ hub_dir = torch.hub.get_dir() model_dir = os.path.join(hub_dir, 'checkpoints') os.makedirs(model_dir, exist_ok=True) parts = torch.hub.urlparse(url) filename = os.path.basename(parts.path) cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): - # download it torch.hub.download_url_to_file(url, cached_file, None, progress=progress) - # load checkpoint + + if ckpt_format == 'safetensors': + return _load_safetensors_ckpt(cached_file) + else: + return torch.load(cached_file, map_location='cpu') + +def _load_safetensors_ckpt(file_path): + """ Loads a safetensors checkpoint file """ ckpt = {} - with safe_open(cached_file, framework='pt', device='cpu') as f: + with safe_open(file_path, framework='pt', device='cpu') as f: metadata = f.metadata() ckpt['vocab'] = {'texttok.model': metadata['texttok.model'], 'speechtok.model': metadata['speechtok.model']} - ckpt['model'] = {} - for k in f.keys(): ckpt['model'][k] = f.get_tensor(k) + ckpt['model'] = {k: f.get_tensor(k) for k in f.keys()} return ckpt + +def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safetensors', ar_path=None, nar_path=None): + """ Load Mars5 English model on `device`, optionally show `progress`. """ + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + logging.info(f"Using device: {device}") + + if not pretrained: + raise ValueError('Only pretrained models are currently supported.') + + ar_ckpt = load_checkpoint(CHECKPOINT_URLS[f'ar_{ckpt_format}'], progress, ckpt_format) if ar_path is None else torch.load(ar_path, map_location='cpu') + nar_ckpt = load_checkpoint(CHECKPOINT_URLS[f'nar_{ckpt_format}'], progress, ckpt_format) if nar_path is None else torch.load(nar_path, map_location='cpu') + + logging.info("Initializing models...") + return Mars5TTS(ar_ckpt, nar_ckpt, device=device), InferenceConfig From e83524b4fbdad980bae7670ae008c6ec1b0660ee Mon Sep 17 00:00:00 2001 From: MananPatel6902 Date: Mon, 1 Jul 2024 14:31:29 +0530 Subject: [PATCH 2/3] Made changes which commented --- hubconf.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/hubconf.py b/hubconf.py index 655b3a5..0bf1555 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,3 +1,5 @@ +dependencies = ['torch', 'torchaudio', 'numpy', 'vocos', 'safetensors'] + import logging import os from pathlib import Path @@ -5,7 +7,6 @@ import torch from inference import Mars5TTS, InferenceConfig -dependencies = ['torch', 'torchaudio', 'numpy', 'vocos', 'safetensors'] # Centralized checkpoint URLs for easy management and updates CHECKPOINT_URLS = { @@ -42,7 +43,11 @@ def _load_safetensors_ckpt(file_path): return ckpt def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safetensors', ar_path=None, nar_path=None): - """ Load Mars5 English model on `device`, optionally show `progress`. """ + + # Load Mars5 English model on `device`, optionally showing progress. + # This function also handles user-provided paths for model checkpoints, + # supporting both .pt and .safetensors formats. + if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' logging.info(f"Using device: {device}") @@ -50,8 +55,26 @@ def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safe if not pretrained: raise ValueError('Only pretrained models are currently supported.') - ar_ckpt = load_checkpoint(CHECKPOINT_URLS[f'ar_{ckpt_format}'], progress, ckpt_format) if ar_path is None else torch.load(ar_path, map_location='cpu') - nar_ckpt = load_checkpoint(CHECKPOINT_URLS[f'nar_{ckpt_format}'], progress, ckpt_format) if nar_path is None else torch.load(nar_path, map_location='cpu') + # Determine the format of the checkpoint based on the file extension if paths are provided + if ar_path is not None: + if ar_path.endswith('.pt'): + ar_ckpt = load_checkpoint(None, progress, 'pt', ar_path) + elif ar_path.endswith('.safetensors'): + ar_ckpt = load_checkpoint(None, progress, 'safetensors', ar_path) + else: + raise NotImplementedError("Unsupported file format for ar_path. Please provide a .pt or .safetensors file.") + else: + ar_ckpt = load_checkpoint(CHECKPOINT_URLS[f'ar_{ckpt_format}'], progress, ckpt_format) + + if nar_path is not None: + if nar_path.endswith('.pt'): + nar_ckpt = load_checkpoint(None, progress, 'pt', nar_path) + elif nar_path.endswith('.safetensors'): + nar_ckpt = load_checkpoint(None, progress, 'safetensors', nar_path) + else: + raise NotImplementedError("Unsupported file format for nar_path. Please provide a .pt or .safetensors file.") + else: + nar_ckpt = load_checkpoint(CHECKPOINT_URLS[f'nar_{ckpt_format}'], progress, ckpt_format) logging.info("Initializing models...") - return Mars5TTS(ar_ckpt, nar_ckpt, device=device), InferenceConfig + return Mars5TTS(ar_ckpt, nar_ckpt, device=device), InferenceConfig \ No newline at end of file From c30a7aa40a8c63ade8b22e1e18f9c95761e5f7ee Mon Sep 17 00:00:00 2001 From: MananPatel6902 Date: Fri, 12 Jul 2024 20:30:19 +0530 Subject: [PATCH 3/3] update with latest urls from the hubconf.py --- hubconf.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/hubconf.py b/hubconf.py index 0bf1555..583d33a 100644 --- a/hubconf.py +++ b/hubconf.py @@ -10,9 +10,9 @@ # Centralized checkpoint URLs for easy management and updates CHECKPOINT_URLS = { - "ar": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.pt", + "ar": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.4/mars5_en_checkpoints_ar-3000000.pt", "nar": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.pt", - "ar_sf": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_ar-2000000.safetensors", + "ar_sf": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.4/mars5_en_checkpoints_ar-3000000.safetensors", "nar_sf": "https://github.com/Camb-ai/MARS5-TTS/releases/download/v0.3/mars5_en_checkpoints_nar-1980000.safetensors" } @@ -42,11 +42,14 @@ def _load_safetensors_ckpt(file_path): ckpt['model'] = {k: f.get_tensor(k) for k in f.keys()} return ckpt + +# Load Mars5 English model on `device`, optionally showing progress. +# This function also handles user-provided path for model checkpoints, +# supporting both .pt and .safetensors formats. + def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safetensors', ar_path=None, nar_path=None): - # Load Mars5 English model on `device`, optionally showing progress. - # This function also handles user-provided paths for model checkpoints, - # supporting both .pt and .safetensors formats. + if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -77,4 +80,5 @@ def mars5_english(pretrained=True, progress=True, device=None, ckpt_format='safe nar_ckpt = load_checkpoint(CHECKPOINT_URLS[f'nar_{ckpt_format}'], progress, ckpt_format) logging.info("Initializing models...") - return Mars5TTS(ar_ckpt, nar_ckpt, device=device), InferenceConfig \ No newline at end of file + return Mars5TTS(ar_ckpt, nar_ckpt, device=device), InferenceConfig +