diff --git a/compiam/melody/pattern/sancara_search/__init__.py b/compiam/melody/pattern/sancara_search/__init__.py
index f3ee2f61..39f88206 100644
--- a/compiam/melody/pattern/sancara_search/__init__.py
+++ b/compiam/melody/pattern/sancara_search/__init__.py
@@ -180,9 +180,16 @@ def load_model(self, model_path, conf_path, spec_path):
             setattr(self, tp, v)
 
         self.model = self._build_model()
-        self.model.load_state_dict(
-            torch.load(model_path, weights_only=True), strict=False
-        )
+        try:
+            self.model.load_state_dict(
+                torch.load(model_path, weights_only=True, map_location=self.device),
+                strict=False,
+            )
+        except Exception:
+            self.model.load_state_dict(
+                torch.load(model_path, map_location=self.device),
+                strict=False,
+            )
         self.trained = True
 
     def download_model(self, model_path=None, force_overwrite=False):
diff --git a/compiam/melody/pitch_extraction/ftaresnet_carnatic/__init__.py b/compiam/melody/pitch_extraction/ftaresnet_carnatic/__init__.py
index eb34be79..68c094aa 100644
--- a/compiam/melody/pitch_extraction/ftaresnet_carnatic/__init__.py
+++ b/compiam/melody/pitch_extraction/ftaresnet_carnatic/__init__.py
@@ -80,7 +80,10 @@
     def load_model(self, model_path):
         """Load pre-trained model weights."""
         if not os.path.exists(model_path):
             self.download_model(model_path)  # Downloading model weights
-        self.model.load_state_dict(torch.load(model_path, weights_only=True))
+        try:  # Loading model weights
+            self.model.load_state_dict(torch.load(model_path, weights_only=True, map_location=self.device))
+        except Exception:
+            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
         self.model_path = model_path
         self.trained = True
 
diff --git a/compiam/melody/pitch_extraction/melodia.py b/compiam/melody/pitch_extraction/melodia.py
index 247a85f9..9bebb128 100644
--- a/compiam/melody/pitch_extraction/melodia.py
+++ b/compiam/melody/pitch_extraction/melodia.py
@@ -4,7 +4,7 @@
 
 from compiam.utils.pitch import normalisation, resampling
 from compiam.io import write_csv
-from compiam.utils import get_logger
+from compiam.utils import get_logger, stereo_to_mono
 
 logger = get_logger(__name__)
 
@@ -87,12 +87,14 @@ def extract(self, input_data, input_sr=44100, out_step=None):
                 filename=input_data, sampleRate=self.sample_rate
             )()
         elif isinstance(input_data, np.ndarray):
+            input_data = stereo_to_mono(input_data)
+            # Resample and apply equal-loudness filter
             logger.warning(
                 f"Resampling... (input sampling rate is {input_sr}Hz, make sure this is correct)"
             )
             resample_audio = estd.Resample(
                 inputSampleRate=input_sr, outputSampleRate=self.sample_rate
-            )()
+            )
             input_data = resample_audio(input_data)
             audio = estd.EqualLoudness(signal=input_data)()
         else:
diff --git a/compiam/melody/raga_recognition/deepsrgm/__init__.py b/compiam/melody/raga_recognition/deepsrgm/__init__.py
index 2323fde3..1d79ea7c 100644
--- a/compiam/melody/raga_recognition/deepsrgm/__init__.py
+++ b/compiam/melody/raga_recognition/deepsrgm/__init__.py
@@ -1,4 +1,5 @@
 import os
+import librosa
 
 import numpy as np
 
@@ -8,7 +9,7 @@
     ModelNotTrainedError,
     DatasetNotLoadedError,
 )
-from compiam.utils import get_logger, WORKDIR
+from compiam.utils import get_logger, stereo_to_mono, WORKDIR
 from compiam.utils.download import download_remote_model
 
 logger = get_logger(__name__)
 
@@ -122,7 +123,10 @@ def load_model(self, model_path, rnn="lstm"):
             self.model = self._build_model(rnn="gru")
 
         self.model_path = model_path
-        weights = torch.load(model_path, weights_only=True, map_location=self.device)
+        try:
+            weights = torch.load(model_path, weights_only=True, map_location=self.device)
+        except Exception:
+            weights = torch.load(model_path, map_location=self.device)
         new_weights = weights.copy()
         keys_to_fix = [
             ".weight_ih_l0",
@@ -232,20 +236,20 @@
             "Install compIAM with essentia support: pip install 'compiam[essentia]'"
         )
 
+        # Loading and resampling audio
         if isinstance(input_data, str):
             if not os.path.exists(input_data):
                 raise FileNotFoundError("Target audio not found.")
-            audio = estd.MonoLoader(
-                filename=input_data, sampleRate=self.sample_rate
-            )()
+            audio, _ = librosa.load(input_data, sr=self.sample_rate)
         elif isinstance(input_data, np.ndarray):
+            input_data = stereo_to_mono(input_data)
             logger.warning(
-                "Resampling... (input sampling rate is {input_sr}Hz, make sure this is correct)"
+                f"Resampling... (input sampling rate is assumed {input_sr}Hz, "
+                "make sure this is correct and change input_sr otherwise)"
             )
-            resampling = estd.Resample(
-                inputSampleRate=input_sr, outputSampleRate=self.sample_rate
+            audio = librosa.resample(
+                input_data, orig_sr=input_sr, target_sr=self.sample_rate
             )
-            audio = resampling(input_data)
         else:
             raise ValueError("Input must be path to audio signal or an audio array")
diff --git a/compiam/melody/tonic_identification/tonic_multipitch.py b/compiam/melody/tonic_identification/tonic_multipitch.py
index 49a6d429..2c3ed099 100644
--- a/compiam/melody/tonic_identification/tonic_multipitch.py
+++ b/compiam/melody/tonic_identification/tonic_multipitch.py
@@ -2,7 +2,7 @@
 
 import numpy as np
 
-from compiam.utils import get_logger
+from compiam.utils import get_logger, stereo_to_mono
 
 logger = get_logger(__name__)
 
@@ -65,12 +65,16 @@ def extract(self, input_data, input_sr=44100):
                 raise FileNotFoundError("Target audio not found.")
             audio = estd.MonoLoader(filename=input_data, sampleRate=self.sample_rate)()
         elif isinstance(input_data, np.ndarray):
+            if len(input_data.shape) == 2:
+                input_data = stereo_to_mono(input_data)
+            if len(input_data.shape) > 2:
+                raise ValueError("Input must be an unbatched audio signal")
             logger.warning(
                 f"Resampling... (input sampling rate is {input_sr}Hz, make sure this is correct)"
             )
             resampling = estd.Resample(
                 inputSampleRate=input_sr, outputSampleRate=self.sample_rate
-            )()
+            )
             audio = resampling(input_data)
         else:
             raise ValueError("Input must be path to audio signal or an audio array")
diff --git a/compiam/separation/music_source_separation/mixer_model/__init__.py b/compiam/separation/music_source_separation/mixer_model/__init__.py
index b44f748e..b6d74ef1 100644
--- a/compiam/separation/music_source_separation/mixer_model/__init__.py
+++ b/compiam/separation/music_source_separation/mixer_model/__init__.py
@@ -83,7 +83,11 @@ def _build_model(self):
     def load_model(self, model_path):
         if not os.path.exists(model_path):
             self.download_model(model_path)  # Downloading model weights
-        self.model.load_state_dict(torch.load(model_path, weights_only=True))
+        try:
+            weights = torch.load(model_path, weights_only=True, map_location=self.device)
+        except Exception:
+            weights = torch.load(model_path, map_location=self.device)
+        self.model.load_state_dict(weights)
         self.model_path = model_path
         self.trained = True
 
diff --git a/compiam/structure/segmentation/dhrupad_bandish_segmentation/__init__.py b/compiam/structure/segmentation/dhrupad_bandish_segmentation/__init__.py
index f67c7347..9debf911 100644
--- a/compiam/structure/segmentation/dhrupad_bandish_segmentation/__init__.py
+++ b/compiam/structure/segmentation/dhrupad_bandish_segmentation/__init__.py
@@ -176,16 +176,20 @@
             self.download_model(model_path)
 
         self.model = self._build_model()
-        self.model.load_state_dict(
-            torch.load(model_path, weights_only=True, map_location=self.device)
-        )
+        try:
+            self.model.load_state_dict(
+                torch.load(model_path, weights_only=True, map_location=self.device)
+            )
+        except Exception:
+            self.model.load_state_dict(
+                torch.load(model_path, map_location=self.device)
+            )
         self.model.eval()
         self.loaded_model_path = model_path
         self.trained = True
 
     def download_model(self, model_path=None, force_overwrite=False):
         """Download pre-trained model."""
-        print("modelpathhh", model_path)
         download_path = (
             os.sep + os.path.join(*model_path.split(os.sep)[:-4])
             if model_path is not None
diff --git a/compiam/utils/__init__.py b/compiam/utils/__init__.py
index 8a06ac3b..fa49733c 100644
--- a/compiam/utils/__init__.py
+++ b/compiam/utils/__init__.py
@@ -4,7 +4,6 @@
 import pathlib
 import pickle
 import difflib
-import librosa
 
 import IPython.display as ipd
 import numpy as np
@@ -172,3 +171,21 @@ def add_center_to_mask(mask):
             num_one = 0
             indices = []
     return mask
+
+
+def stereo_to_mono(audio):
+    """Convert an input numpy array to mono, averaging channels if stereo."""
+    if len(audio.shape) == 2:
+        # Put channels first (assuming fewer channels than samples)
+        if audio.shape[0] > audio.shape[1]:
+            audio = audio.T
+        # If stereo, average the two channels
+        if audio.shape[0] == 2:
+            audio = np.mean(audio, axis=0)
+        elif audio.shape[0] == 1:
+            audio = np.squeeze(audio, axis=0)
+        elif audio.shape[0] > 2:
+            raise ValueError("Expected mono or stereo audio, got multi-channel audio")
+    if len(audio.shape) > 2:
+        raise ValueError("Input must be an unbatched audio signal")
+    return audio
diff --git a/tests/melody/test_deepsrgm.py b/tests/melody/test_deepsrgm.py
index 69c2769e..e01dc73f 100644
--- a/tests/melody/test_deepsrgm.py
+++ b/tests/melody/test_deepsrgm.py
@@ -43,6 +43,10 @@ def _get_features():
     feat = deepsrgm.get_features(
         os.path.join(TESTDIR, "resources", "melody", "pitch_test.wav")
     )
+    feat_1 = deepsrgm.get_features(np.zeros(44100))
+    feat_2 = deepsrgm.get_features(
+        os.path.join(TESTDIR, "resources", "melody", "pitch_test.wav")
"pitch_test.wav") + ) @pytest.mark.torch diff --git a/tests/melody/test_essentia_extractors.py b/tests/melody/test_essentia_extractors.py index ef75a61f..b2db2904 100644 --- a/tests/melody/test_essentia_extractors.py +++ b/tests/melody/test_essentia_extractors.py @@ -15,6 +15,9 @@ def _predict_normalized_pitch(): pitch = melodia.extract( os.path.join(TESTDIR, "resources", "melody", "pitch_test.wav") ) + pitch_2 = melodia.extract(np.zeros(44100)) + pitch_3 = melodia.extract(np.zeros(2, 44100)) # Testing input array + pitch_4 = melodia.extract(np.zeros(44100, 2)) # Testing input array assert isinstance(pitch, np.ndarray) assert np.shape(pitch) == (699, 2) @@ -67,6 +70,11 @@ def _predict_normalized_pitch(): tonic = tonic_multipitch.extract( os.path.join(TESTDIR, "resources", "melody", "pitch_test.wav") ) + tonic_2 = tonic_multipitch.extract(np.zeros(44100)) # Testing input array + tonic_3 = tonic_multipitch.extract(np.zeros(2, 44100)) # Testing input array + tonic_4 = tonic_multipitch.extract(np.zeros(44100, 2)) # Testing input array + + assert isinstance(tonic, float) assert tonic == 157.64892578125