diff --git a/audio_separator/separator/separator.py b/audio_separator/separator/separator.py index 0345b04..f0f69f6 100644 --- a/audio_separator/separator/separator.py +++ b/audio_separator/separator/separator.py @@ -29,7 +29,7 @@ def __init__( secondary_stem_path=None, output_format="WAV", output_subtype=None, - normalization_enabled=False, + normalization_threshold=0.9, denoise_enabled=False, output_single_stem=None, invert_using_spec=False, @@ -57,7 +57,7 @@ def __init__( # Filter out noisy warnings from PyTorch for users who don't care about them if log_level > logging.DEBUG: warnings.filterwarnings("ignore") - + package_version = pkg_resources.get_distribution("audio-separator").version self.logger.info(f"Separator version {package_version} instantiating with output_dir: {output_dir}, output_format: {output_format}") @@ -79,11 +79,10 @@ def __init__( if self.output_subtype is None and output_format == "WAV": self.output_subtype = "PCM_16" - self.normalization_enabled = normalization_enabled - if self.normalization_enabled: - self.logger.debug(f"Normalization enabled, waveform will be normalized to max amplitude of 1.0 to avoid clipping.") - else: - self.logger.debug(f"Normalization disabled, waveform will not be normalized.") + self.normalization_threshold = normalization_threshold + self.logger.debug( + f"Normalization threshold set to {normalization_threshold}, waveform will be lowered to this max amplitude to avoid clipping." 
+ ) self.denoise_enabled = denoise_enabled if self.denoise_enabled: @@ -124,7 +123,9 @@ def setup_inferencing_device(self): self.logger.info(f"Operating System: {os_name} {os_version}") system_info = platform.uname() - self.logger.info(f"System: {system_info.system} Node: {system_info.node} Release: {system_info.release} Machine: {system_info.machine} Proc: {system_info.processor}") + self.logger.info( + f"System: {system_info.system} Node: {system_info.node} Release: {system_info.release} Machine: {system_info.machine} Proc: {system_info.processor}" + ) python_version = platform.python_version() self.logger.info(f"Python Version: {python_version}") @@ -132,11 +133,11 @@ def setup_inferencing_device(self): onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu") if onnxruntime_gpu_package is not None: self.logger.info(f"ONNX Runtime GPU package installed with version: {onnxruntime_gpu_package.version}") - + onnxruntime_silicon_package = self.get_package_distribution("onnxruntime-silicon") if onnxruntime_silicon_package is not None: self.logger.info(f"ONNX Runtime Silicon package installed with version: {onnxruntime_silicon_package.version}") - + onnxruntime_cpu_package = self.get_package_distribution("onnxruntime") if onnxruntime_cpu_package is not None: self.logger.info(f"ONNX Runtime CPU package installed with version: {onnxruntime_cpu_package.version}") @@ -144,16 +145,15 @@ def setup_inferencing_device(self): torch_package = self.get_package_distribution("torch") if torch_package is not None: self.logger.info(f"Torch package installed with version: {torch_package.version}") - + torchvision_package = self.get_package_distribution("torchvision") if torchvision_package is not None: self.logger.info(f"Torchvision package installed with version: {torchvision_package.version}") - + torchaudio_package = self.get_package_distribution("torchaudio") if torchaudio_package is not None: self.logger.info(f"Torchaudio package installed with version: 
{torchaudio_package.version}") - ort_device = ort.get_device() ort_providers = ort.get_available_providers() @@ -336,7 +336,7 @@ def separate(self, audio_file_path): self.primary_source = None self.secondary_source = None - + self.audio_file_path = audio_file_path self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0] @@ -344,6 +344,9 @@ def separate(self, audio_file_path): self.logger.debug("Preparing mix...") mix = self.prepare_mix(self.audio_file_path) + self.logger.debug("Normalizing mix before demixing...") + mix = spec_utils.normalize(self.logger, wave=mix, max_peak=self.normalization_threshold) + # Start the demixing process source = self.demix(mix) @@ -356,7 +359,7 @@ def separate(self, audio_file_path): # Normalize and transpose the primary source if it's not already an array if not isinstance(self.primary_source, np.ndarray): self.logger.debug("Normalizing primary source...") - self.primary_source = spec_utils.normalize(self.logger, source, self.normalization_enabled).T + self.primary_source = spec_utils.normalize(self.logger, wave=source, max_peak=self.normalization_threshold).T # Process the secondary source if not already an array if not isinstance(self.secondary_source, np.ndarray): @@ -391,7 +394,9 @@ def separate(self, audio_file_path): ) if not isinstance(self.primary_source, np.ndarray): self.primary_source = source.T - self.primary_source_map = self.final_process(self.primary_stem_path, self.primary_source, self.model_primary_stem, self.sample_rate) + self.primary_source_map = self.final_process( + self.primary_stem_path, self.primary_source, self.model_primary_stem, self.sample_rate + ) output_files.append(self.primary_stem_path) # Clear GPU cache to free up memory @@ -422,7 +427,7 @@ def separate(self, audio_file_path): def write_audio(self, stem_path: str, stem_source, sample_rate, stem_name=None): self.logger.debug(f"Entering write_audio with stem_name: {stem_name} and stem_path: {stem_path}") - stem_source = 
spec_utils.normalize(self.logger, stem_source, self.normalization_enabled) + stem_source = spec_utils.normalize(self.logger, wave=stem_source, max_peak=self.normalization_threshold) # Check if the numpy array is empty or contains very low values if np.max(np.abs(stem_source)) < 1e-6: @@ -648,9 +653,9 @@ def demix(self, mix, is_match_mix=False): result[..., start:end] += tar_waves[..., : end - start] # Normalizes the results by the divider to account for overlap. + self.logger.debug("Normalizing result by dividing result by divider.") tar_waves = result / divider tar_waves_.append(tar_waves) - self.logger.debug("Result normalized by divider.") # Reshapes the results to match the original dimensions. tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim : -self.trim] diff --git a/audio_separator/separator/spec_utils.py b/audio_separator/separator/spec_utils.py index 0018075..8fb68f2 100644 --- a/audio_separator/separator/spec_utils.py +++ b/audio_separator/separator/spec_utils.py @@ -107,18 +107,23 @@ def run_thread(**kwargs): return spec -def normalize(logger: logging.Logger, wave, is_normalize=False): - """Save output music files""" +def normalize(logger: logging.Logger, wave, max_peak=1.0): + """Normalize audio waveform to a specified peak value. + + Args: + logger (logging.Logger): Logger for debugging information. + wave (array-like): Audio waveform. + max_peak (float): Maximum peak value for normalization. + + Returns: + array-like: Normalized or original waveform. + """ maxv = np.abs(wave).max() - if maxv > 1.0: - logger.debug(f"Normalization Set {is_normalize}: Input above threshold for clipping. 
Max:{maxv}") - if is_normalize: - logger.debug(f"The result was normalized.") - wave /= maxv - else: - logger.debug(f"The result was not normalized.") + if maxv > max_peak: + logger.debug(f"Maximum peak amplitude above clipping threshold, normalizing from {maxv} to max peak {max_peak}.") + wave *= max_peak / maxv else: - logger.debug(f"Normalization Set {is_normalize}: Input not above threshold for clipping. Max:{maxv}") + logger.debug(f"Maximum peak amplitude not above clipping threshold, no need to normalize: {maxv}") return wave diff --git a/audio_separator/utils/cli.py b/audio_separator/utils/cli.py index b8b8248..da6aa22 100755 --- a/audio_separator/utils/cli.py +++ b/audio_separator/utils/cli.py @@ -59,10 +59,10 @@ def main(): ) parser.add_argument( - "--normalize", - type=lambda x: (str(x).lower() == "true"), - default=False, - help="Optional: enable or disable normalization during separation (default: %(default)s). Example: --normalize=True", + "--normalization_threshold", + type=float, + default=0.9, + help="Optional: max peak amplitude to normalize input and output audio to (default: %(default)s). Example: --normalization_threshold=0.7", ) parser.add_argument( @@ -134,7 +134,7 @@ def main(): output_dir=args.output_dir, output_format=args.output_format, denoise_enabled=args.denoise, - normalization_enabled=args.normalize, + normalization_threshold=args.normalization_threshold, output_single_stem=args.single_stem, invert_using_spec=args.invert_spect, sample_rate=args.sample_rate, diff --git a/pyproject.toml b/pyproject.toml index eb177a8..fe31559 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "audio-separator" -version = "0.12.3" +version = "0.13.0" description = "Easy to use vocal separation, using MDX-Net models from UVR trained by @Anjok07" authors = ["Andrew Beveridge "] license = "MIT"