Skip to content

Commit

Permalink
Replaced normalization_enabled with normalization_threshold, set safer default of 0.9 to prevent clipping
Browse files Browse the repository at this point in the history
  • Loading branch information
beveradb committed Jan 9, 2024
1 parent 2ba4fa5 commit 330dc92
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 34 deletions.
41 changes: 23 additions & 18 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
secondary_stem_path=None,
output_format="WAV",
output_subtype=None,
normalization_enabled=False,
normalization_threshold=0.9,
denoise_enabled=False,
output_single_stem=None,
invert_using_spec=False,
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
# Filter out noisy warnings from PyTorch for users who don't care about them
if log_level > logging.DEBUG:
warnings.filterwarnings("ignore")

package_version = pkg_resources.get_distribution("audio-separator").version

self.logger.info(f"Separator version {package_version} instantiating with output_dir: {output_dir}, output_format: {output_format}")
Expand All @@ -79,11 +79,10 @@ def __init__(
if self.output_subtype is None and output_format == "WAV":
self.output_subtype = "PCM_16"

self.normalization_enabled = normalization_enabled
if self.normalization_enabled:
self.logger.debug(f"Normalization enabled, waveform will be normalized to max amplitude of 1.0 to avoid clipping.")
else:
self.logger.debug(f"Normalization disabled, waveform will not be normalized.")
self.normalization_threshold = normalization_threshold
self.logger.debug(
f"Normalization threshold set to {normalization_threshold}, waveform will lowered to this max amplitude to avoid clipping."
)

self.denoise_enabled = denoise_enabled
if self.denoise_enabled:
Expand Down Expand Up @@ -124,36 +123,37 @@ def setup_inferencing_device(self):
self.logger.info(f"Operating System: {os_name} {os_version}")

system_info = platform.uname()
self.logger.info(f"System: {system_info.system} Node: {system_info.node} Release: {system_info.release} Machine: {system_info.machine} Proc: {system_info.processor}")
self.logger.info(
f"System: {system_info.system} Node: {system_info.node} Release: {system_info.release} Machine: {system_info.machine} Proc: {system_info.processor}"
)

python_version = platform.python_version()
self.logger.info(f"Python Version: {python_version}")

onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
if onnxruntime_gpu_package is not None:
self.logger.info(f"ONNX Runtime GPU package installed with version: {onnxruntime_gpu_package.version}")

onnxruntime_silicon_package = self.get_package_distribution("onnxruntime-silicon")
if onnxruntime_silicon_package is not None:
self.logger.info(f"ONNX Runtime Silicon package installed with version: {onnxruntime_silicon_package.version}")

onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")
if onnxruntime_cpu_package is not None:
self.logger.info(f"ONNX Runtime CPU package installed with version: {onnxruntime_cpu_package.version}")

torch_package = self.get_package_distribution("torch")
if torch_package is not None:
self.logger.info(f"Torch package installed with version: {torch_package.version}")

torchvision_package = self.get_package_distribution("torchvision")
if torchvision_package is not None:
self.logger.info(f"Torchvision package installed with version: {torchvision_package.version}")

torchaudio_package = self.get_package_distribution("torchaudio")
if torchaudio_package is not None:
self.logger.info(f"Torchaudio package installed with version: {torchaudio_package.version}")


ort_device = ort.get_device()
ort_providers = ort.get_available_providers()

Expand Down Expand Up @@ -336,14 +336,17 @@ def separate(self, audio_file_path):

self.primary_source = None
self.secondary_source = None

self.audio_file_path = audio_file_path
self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

# Prepare the mix for processing
self.logger.debug("Preparing mix...")
mix = self.prepare_mix(self.audio_file_path)

self.logger.debug("Normalizing mix before demixing...")
mix = spec_utils.normalize(self.logger, wave=mix, max_peak=self.normalization_threshold)

# Start the demixing process
source = self.demix(mix)

Expand All @@ -356,7 +359,7 @@ def separate(self, audio_file_path):
# Normalize and transpose the primary source if it's not already an array
if not isinstance(self.primary_source, np.ndarray):
self.logger.debug("Normalizing primary source...")
self.primary_source = spec_utils.normalize(self.logger, source, self.normalization_enabled).T
self.primary_source = spec_utils.normalize(self.logger, wave=source, max_peak=self.normalization_threshold).T

# Process the secondary source if not already an array
if not isinstance(self.secondary_source, np.ndarray):
Expand Down Expand Up @@ -391,7 +394,9 @@ def separate(self, audio_file_path):
)
if not isinstance(self.primary_source, np.ndarray):
self.primary_source = source.T
self.primary_source_map = self.final_process(self.primary_stem_path, self.primary_source, self.model_primary_stem, self.sample_rate)
self.primary_source_map = self.final_process(
self.primary_stem_path, self.primary_source, self.model_primary_stem, self.sample_rate
)
output_files.append(self.primary_stem_path)

# Clear GPU cache to free up memory
Expand Down Expand Up @@ -422,7 +427,7 @@ def separate(self, audio_file_path):
def write_audio(self, stem_path: str, stem_source, sample_rate, stem_name=None):
self.logger.debug(f"Entering write_audio with stem_name: {stem_name} and stem_path: {stem_path}")

stem_source = spec_utils.normalize(self.logger, stem_source, self.normalization_enabled)
stem_source = spec_utils.normalize(self.logger, wave=stem_source, max_peak=self.normalization_threshold)

# Check if the numpy array is empty or contains very low values
if np.max(np.abs(stem_source)) < 1e-6:
Expand Down Expand Up @@ -648,9 +653,9 @@ def demix(self, mix, is_match_mix=False):
result[..., start:end] += tar_waves[..., : end - start]

# Normalizes the results by the divider to account for overlap.
self.logger.debug("Normalizing result by dividing result by divider.")
tar_waves = result / divider
tar_waves_.append(tar_waves)
self.logger.debug("Result normalized by divider.")

# Reshapes the results to match the original dimensions.
tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim : -self.trim]
Expand Down
25 changes: 15 additions & 10 deletions audio_separator/separator/spec_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,23 @@ def run_thread(**kwargs):
return spec


def normalize(logger: logging.Logger, wave, max_peak=1.0):
    """Scale an audio waveform down so its peak amplitude does not exceed max_peak.

    The waveform is only rescaled when its absolute peak is above max_peak;
    quieter signals are returned unchanged (no upward normalization).

    Args:
        logger (logging.Logger): Logger for debugging information.
        wave (array-like): Audio waveform samples (numpy array).
        max_peak (float): Maximum allowed peak amplitude after normalization.

    Returns:
        array-like: The waveform, scaled down if its peak exceeded max_peak,
        otherwise the original waveform.
    """
    maxv = np.abs(wave).max()
    if maxv > max_peak:
        logger.debug(f"Maximum peak amplitude above clipping threshold, normalizing from {maxv} to max peak {max_peak}.")
        # Out-of-place multiply: in-place `wave *= ...` raises on integer-dtype
        # arrays (unsafe float-into-int casting) and silently mutates the
        # caller's buffer; all visible call sites use the return value anyway.
        wave = wave * (max_peak / maxv)
    else:
        logger.debug(f"Maximum peak amplitude not above clipping threshold, no need to normalize: {maxv}")

    return wave

Expand Down
10 changes: 5 additions & 5 deletions audio_separator/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def main():
)

parser.add_argument(
"--normalize",
type=lambda x: (str(x).lower() == "true"),
default=False,
help="Optional: enable or disable normalization during separation (default: %(default)s). Example: --normalize=True",
"--normalization_threshold",
type=float,
default=0.9,
help="Optional: max peak amplitude to normalize input and output audio to (default: %(default)s). Example: --normalization_threshold=0.7",
)

parser.add_argument(
Expand Down Expand Up @@ -134,7 +134,7 @@ def main():
output_dir=args.output_dir,
output_format=args.output_format,
denoise_enabled=args.denoise,
normalization_enabled=args.normalize,
normalization_threshold=args.normalization_threshold,
output_single_stem=args.single_stem,
invert_using_spec=args.invert_spect,
sample_rate=args.sample_rate,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "audio-separator"
version = "0.12.3"
version = "0.13.0"
description = "Easy to use vocal separation, using MDX-Net models from UVR trained by @Anjok07"
authors = ["Andrew Beveridge <[email protected]>"]
license = "MIT"
Expand Down

0 comments on commit 330dc92

Please sign in to comment.