Skip to content

Commit

Permalink
fix resampling and model loading minor errors
Browse files Browse the repository at this point in the history
  • Loading branch information
genisplaja committed Nov 28, 2024
1 parent 0c7806d commit d2e65ef
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 23 deletions.
13 changes: 10 additions & 3 deletions compiam/melody/pattern/sancara_search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,16 @@ def load_model(self, model_path, conf_path, spec_path):
setattr(self, tp, v)

self.model = self._build_model()
self.model.load_state_dict(
torch.load(model_path, weights_only=True), strict=False
)
try:
self.model.load_state_dict(
torch.load(model_path, weights_only=True, map_location=self.device),
strict=False
)
except:
self.model.load_state_dict(
torch.load(model_path, map_location=self.device),
strict=False,
)
self.trained = True

def download_model(self, model_path=None, force_overwrite=False):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ def load_model(self, model_path):
"""Load pre-trained model weights."""
if not os.path.exists(model_path):
self.download_model(model_path) # Downloading model weights
self.model.load_state_dict(torch.load(model_path, weights_only=True))
try: # Loading model weights
self.model.load_state_dict(torch.load(model_path, weights_only=True, map_location=self.device))
except:
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
self.model_path = model_path
self.trained = True

Expand Down
6 changes: 4 additions & 2 deletions compiam/melody/pitch_extraction/melodia.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from compiam.utils.pitch import normalisation, resampling
from compiam.io import write_csv
from compiam.utils import get_logger
from compiam.utils import get_logger, stereo_to_mono

logger = get_logger(__name__)

Expand Down Expand Up @@ -87,12 +87,14 @@ def extract(self, input_data, input_sr=44100, out_step=None):
filename=input_data, sampleRate=self.sample_rate
)()
elif isinstance(input_data, np.ndarray):
input_data = stereo_to_mono(input_data)
# Apply Eqloudness filter
logger.warning(
f"Resampling... (input sampling rate is {input_sr}Hz, make sure this is correct)"
)
resample_audio = estd.Resample(
inputSampleRate=input_sr, outputSampleRate=self.sample_rate
)()
)
input_data = resample_audio(input_data)
audio = estd.EqualLoudness(signal=input_data)()
else:
Expand Down
22 changes: 13 additions & 9 deletions compiam/melody/raga_recognition/deepsrgm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import librosa

import numpy as np

Expand All @@ -8,7 +9,7 @@
ModelNotTrainedError,
DatasetNotLoadedError,
)
from compiam.utils import get_logger, WORKDIR
from compiam.utils import get_logger, stereo_to_mono, WORKDIR
from compiam.utils.download import download_remote_model

logger = get_logger(__name__)
Expand Down Expand Up @@ -122,7 +123,10 @@ def load_model(self, model_path, rnn="lstm"):
self.model = self._build_model(rnn="gru")

self.model_path = model_path
weights = torch.load(model_path, weights_only=True, map_location=self.device)
try:
weights = torch.load(model_path, weights_only=True, map_location=self.device)
except:
weights = torch.load(model_path, map_location=self.device)
new_weights = weights.copy()
keys_to_fix = [
".weight_ih_l0",
Expand Down Expand Up @@ -232,20 +236,20 @@ def get_features(
"Install compIAM with essentia support: pip install 'compiam[essentia]'"
)

# Loading and resampling audio
if isinstance(input_data, str):
if not os.path.exists(input_data):
raise FileNotFoundError("Target audio not found.")
audio = estd.MonoLoader(
filename=input_data, sampleRate=self.sample_rate
)()
audio, _ = librosa.load(input_data, sr=self.sample_rate)
elif isinstance(input_data, np.ndarray):
input_data = stereo_to_mono(input_data)
logger.warning(
"Resampling... (input sampling rate is {input_sr}Hz, make sure this is correct)"
f"Resampling... (input sampling rate is assumed {input_sr}Hz, \
make sure this is correct and change input_sr otherwise)"
)
resampling = estd.Resample(
inputSampleRate=input_sr, outputSampleRate=self.sample_rate
audio = librosa.resample(
input_data, orig_sr=input_sr, target_sr=self.sample_rate
)
audio = resampling(input_data)
else:
raise ValueError("Input must be path to audio signal or an audio array")

Expand Down
8 changes: 6 additions & 2 deletions compiam/melody/tonic_identification/tonic_multipitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from compiam.utils import get_logger
from compiam.utils import get_logger, stereo_to_mono

logger = get_logger(__name__)

Expand Down Expand Up @@ -65,12 +65,16 @@ def extract(self, input_data, input_sr=44100):
raise FileNotFoundError("Target audio not found.")
audio = estd.MonoLoader(filename=input_data, sampleRate=self.sample_rate)()
elif isinstance(input_data, np.ndarray):
if len(input_data.shape) == 2:
input_data = stereo_to_mono(input_data)
if len(input_data.shape) > 2:
raise ValueError("Input must be an unbatched audio signal")
logger.warning(
f"Resampling... (input sampling rate is {input_sr}Hz, make sure this is correct)"
)
resampling = estd.Resample(
inputSampleRate=input_sr, outputSampleRate=self.sample_rate
)()
)
audio = resampling(input_data)
else:
raise ValueError("Input must be path to audio signal or an audio array")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ def _build_model(self):
def load_model(self, model_path):
if not os.path.exists(model_path):
self.download_model(model_path) # Downloading model weights
self.model.load_state_dict(torch.load(model_path, weights_only=True))
try:
weights = torch.load(model_path, weights_only=True, map_location=self.device)
except:
weights = torch.load(model_path, map_location=self.device)
self.model.load_state_dict(weights)
self.model_path = model_path
self.trained = True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,20 @@ def load_model(self, model_path):
self.download_model(model_path)

self.model = self._build_model()
self.model.load_state_dict(
torch.load(model_path, weights_only=True, map_location=self.device)
)
try:
self.model.load_state_dict(
torch.load(model_path, weights_only=True, map_location=self.device)
)
except:
self.model.load_state_dict(
torch.load(model_path, map_location=self.device)
)
self.model.eval()
self.loaded_model_path = model_path
self.trained = True

def download_model(self, model_path=None, force_overwrite=False):
"""Download pre-trained model."""
print("modelpathhh", model_path)
download_path = (
os.sep + os.path.join(*model_path.split(os.sep)[:-4])
if model_path is not None
Expand Down
19 changes: 18 additions & 1 deletion compiam/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pathlib
import pickle
import difflib
import librosa

import IPython.display as ipd
import numpy as np
Expand Down Expand Up @@ -172,3 +171,21 @@ def add_center_to_mask(mask):
num_one = 0
indices = []
return mask


def stereo_to_mono(audio):
    """Collapse a mono or stereo audio array into a 1-D mono signal.

    Input with fewer than two dimensions is returned unchanged. For
    two-dimensional input the array is first oriented channels-first (if the
    first axis is longer than the second, the array is transposed); a single
    channel is then squeezed out and two channels are averaged.

    :param audio: numpy array with the audio samples, shaped ``(samples,)``,
        ``(channels, samples)`` or ``(samples, channels)``.
    :returns: numpy array with the mono signal.
    :raises ValueError: if the input has more than two dimensions, or has
        more than two channels.
    """
    num_dims = len(audio.shape)
    if num_dims > 2:
        raise ValueError("Input must be an unbatched audio signal")
    if num_dims < 2:
        # Already a plain (mono) signal, nothing to reduce
        return audio

    # Put channels first
    if audio.shape[0] > audio.shape[1]:
        audio = audio.T

    # Read the channel count once, BEFORE reducing: after averaging or
    # squeezing, shape[0] is the number of samples, and checking it against
    # the channel limit would reject every valid signal.
    num_channels = audio.shape[0]
    if num_channels > 2:
        raise ValueError("Expected mono or stereo audio, got multi-channel audio")
    if num_channels == 2:
        # If stereo, average the channels
        return np.mean(audio, axis=0)
    if num_channels == 1:
        return np.squeeze(audio, axis=0)
    return audio
4 changes: 4 additions & 0 deletions tests/melody/test_deepsrgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def _get_features():
feat = deepsrgm.get_features(
os.path.join(TESTDIR, "resources", "melody", "pitch_test.wav")
)
feat_1 = deepsrgm.get_features(np.zeros(44100))
feat_2 = deepsrgm.get_features(
os.path.join(TESTDIR, "resources", "melody", "pitch_test.wav")
)


@pytest.mark.torch
Expand Down
8 changes: 8 additions & 0 deletions tests/melody/test_essentia_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ def _predict_normalized_pitch():
pitch = melodia.extract(
os.path.join(TESTDIR, "resources", "melody", "pitch_test.wav")
)
pitch_2 = melodia.extract(np.zeros(44100))
pitch_3 = melodia.extract(np.zeros(2, 44100)) # Testing input array
pitch_4 = melodia.extract(np.zeros(44100, 2)) # Testing input array

assert isinstance(pitch, np.ndarray)
assert np.shape(pitch) == (699, 2)
Expand Down Expand Up @@ -67,6 +70,11 @@ def _predict_normalized_pitch():
tonic = tonic_multipitch.extract(
os.path.join(TESTDIR, "resources", "melody", "pitch_test.wav")
)
tonic_2 = tonic_multipitch.extract(np.zeros(44100)) # Testing input array
tonic_3 = tonic_multipitch.extract(np.zeros(2, 44100)) # Testing input array
tonic_4 = tonic_multipitch.extract(np.zeros(44100, 2)) # Testing input array



assert isinstance(tonic, float)
assert tonic == 157.64892578125
Expand Down

0 comments on commit d2e65ef

Please sign in to comment.