def main():
    """
    Command line entry point of the TCNTempoDetector program.

    Builds the argument parser, selects an input processor (TCN network or
    previously saved activations) and an output processor (activation writer
    or tempo estimation) and hands both, wrapped in an `IOProcessor`, to the
    processing function chosen on the command line.

    """
    # define parser; the description is printed verbatim as help text
    p = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter, description='''
    The TCNTempoDetector program detects the tempo in an audio file according
    to the method described in:

    "Multi-Task learning of tempo and beat: learning one to improve the other"
    Sebastian Böck, Matthew Davies and Peter Knees.
    Proc. of the 20th International Society for Music Information Retrieval
    Conference (ISMIR), 2019.

    This program can be run in 'single' file mode to process a single audio
    file and write the detected tempi to STDOUT or the given output file.

      $ TCNTempoDetector single INFILE [-o OUTFILE]

    If multiple audio files should be processed, the program can also be run
    in 'batch' mode to save the detected tempi to files with the given suffix.

      $ TCNTempoDetector batch [-o OUTPUT_DIR] [-s OUTPUT_SUFFIX] FILES

    If no output directory is given, the program writes the files with the
    detected tempi to the same location as the audio files.

    The 'pickle' mode can be used to store the used parameters to be able to
    exactly reproduce experiments.

    ''')
    # version
    p.add_argument('--version', action='version',
                   version='TCNTempoDetector')
    # input/output options
    io_arguments(p, output_suffix='.bpm.txt', online=True)
    ActivationsProcessor.add_arguments(p)
    # signal processing arguments
    SignalProcessor.add_arguments(p, norm=False, gain=0)
    # tempo arguments
    TempoEstimationProcessor.add_arguments(p, hist_smooth=15)

    # parse arguments
    args = p.parse_args()

    # set immutable arguments
    # NOTE(review): presumably selects the tempo output of the multi-task
    #               TCN -- confirm against `TCNBeatProcessor`
    args.tasks = (1, )
    args.interpolate = True
    # `method` is deprecated in `TempoEstimationProcessor`; the histogram is
    # created by the dedicated histogram processor set up below instead
    args.method = None
    # no smoothing of the network output before histogram extraction
    args.act_smooth = None

    # print arguments
    if args.verbose:
        print(args)

    # input processor
    if args.load:
        # load the activations from file
        in_processor = ActivationsProcessor(mode='r', **vars(args))
    else:
        # use a TCN to predict beats and tempo
        in_processor = TCNBeatProcessor(**vars(args))

    # output processor
    if args.save:
        # save the TCN activations to file
        out_processor = ActivationsProcessor(mode='w', **vars(args))
    else:
        # extract the tempo histogram from the NN output; stored on `args`
        # so that it is passed on to the tempo estimator via `vars(args)`
        args.histogram_processor = TCNTempoHistogramProcessor(**vars(args))
        # estimate tempo
        tempo_estimator = TempoEstimationProcessor(**vars(args))
        # output handler
        output = write_tempo
        # sequentially process them
        out_processor = [tempo_estimator, output]

    # create an IOProcessor
    processor = IOProcessor(in_processor, out_processor)

    # and call the processing function
    args.func(processor, **vars(args))
class TCNTempoHistogramProcessor(TempoHistogramProcessor):
    """
    Derive a tempo histogram from (multi-task) TCN output.

    The tempo output of the network is treated as a histogram with a linear
    tempo axis, i.e. the bin with index `i` is assigned the tempo `i` bpm.
    Only the bins within [`min_bpm`, `max_bpm`] are returned.

    Parameters
    ----------
    min_bpm : float, optional
        Minimum tempo to detect [bpm].
    max_bpm : float, optional
        Maximum tempo to detect [bpm].

    References
    ----------
    .. [1] Sebastian Böck, Matthew Davies and Peter Knees,
           "Multi-Task learning of tempo and beat: learning one to improve the
           other",
           Proceedings of the 20th International Society for Music Information
           Retrieval Conference (ISMIR), 2019.

    """

    def __init__(self, min_bpm=MIN_BPM, max_bpm=MAX_BPM, **kwargs):
        # pylint: disable=unused-argument
        super(TCNTempoHistogramProcessor, self).__init__(
            min_bpm=min_bpm, max_bpm=max_bpm, **kwargs)

    def process(self, data, **kwargs):
        """
        Extract tempo histogram from (multi-task) TCN output.

        Parameters
        ----------
        data : numpy array or tuple of numpy arrays
            Tempo-task (numpy array) or multi-task (tuple) output of TCN.

        Returns
        -------
        histogram_bins : numpy array
            Bins of tempo histogram, i.e. tempo strengths.
        histogram_tempi : numpy array
            Corresponding tempi [bpm].

        Notes
        -----
        If the requested tempo range exceeds the range covered by `data`,
        the result is clipped to the available bins; if the ranges do not
        overlap at all, empty arrays are returned.

        """
        # pylint: disable=unused-argument
        # if data is a tuple, tempo is usually last item of TCN output
        if isinstance(data, tuple):
            data = data[-1]
        # use a linear tempo range, bin index == tempo [bpm]
        tempi = np.arange(len(data))
        # determine tempo range to consider; `tempi` is sorted, so a binary
        # search yields the first bin >= min_bpm and the first bin > max_bpm.
        # Unlike argmax/argmin on a boolean mask, this also gives the correct
        # index (0 or len(tempi)) if no or all bins satisfy the condition.
        min_idx = np.searchsorted(tempi, self.min_bpm, side='left')
        max_idx = np.searchsorted(tempi, self.max_bpm, side='right')
        # return only selected range
        return data[min_idx:max_idx], tempi[min_idx:max_idx]
class TestTCNTempoDetectorProgram(unittest.TestCase):
    """Run the TCNTempoDetector program and compare with reference data."""

    def setUp(self):
        """Locate the program and load reference activations/detections."""
        activations_file = pj(ACTIVATIONS_PATH, "sample.beats_tcn_tempo.npz")
        detections_file = pj(DETECTIONS_PATH, "sample.tcn_tempo_detector.txt")
        self.bin = pj(program_path, "TCNTempoDetector")
        self.activations = Activations(activations_file)
        self.result = np.loadtxt(detections_file)

    def test_help(self):
        """The help message can be displayed."""
        help_ok = run_help(self.bin)
        self.assertTrue(help_ok)

    def test_binary(self):
        """Saved activations match and lead to the reference result."""
        # save activations as binary file
        run_save(self.bin, sample_file, tmp_act)
        saved = Activations(tmp_act)
        same_activations = np.allclose(saved, self.activations, atol=1e-5)
        self.assertTrue(same_activations)
        # reload from file
        run_load(self.bin, tmp_act, tmp_result)
        detected = np.loadtxt(tmp_result)
        same_result = np.allclose(detected, self.result, atol=1e-5)
        self.assertTrue(same_result)

    def test_run(self):
        """Processing the sample file yields the reference result."""
        run_single(self.bin, sample_file, tmp_result)
        detected = np.loadtxt(tmp_result)
        same_result = np.allclose(detected, self.result, atol=1e-5)
        self.assertTrue(same_result)
class TestTCNTempoHistogramProcessorClass(unittest.TestCase):
    """Tests for `TCNTempoHistogramProcessor` using saved TCN tempo output."""

    def setUp(self):
        """Create a processor for 10..250 bpm and load the sample output."""
        self.processor = TCNTempoHistogramProcessor(min_bpm=10, max_bpm=250)
        self.act = Activations(pj(ACTIVATIONS_PATH,
                                  "sample.beats_tcn_tempo.npz"))

    def test_types(self):
        self.assertIsInstance(self.processor.min_bpm, float)
        self.assertIsInstance(self.processor.max_bpm, float)
        self.assertIsNone(self.processor.fps)

    def test_values(self):
        self.assertEqual(self.processor.min_bpm, 10)
        self.assertEqual(self.processor.max_bpm, 250)
        # compare the floating point sum with a tolerance instead of exact
        # equality, in line with the other sum checks of this class
        self.assertTrue(np.allclose(np.sum(self.act), 1))

    def test_process(self):
        hist, tempi = self.processor(self.act)
        self.assertTrue(np.allclose(tempi, np.arange(10, 251)))
        self.assertTrue(np.allclose(hist.max(), 0.1326968))
        self.assertTrue(np.allclose(hist.min(), 5.05e-09))
        # positions are integers, compare them exactly
        self.assertEqual(hist.argmax(), 77)
        self.assertEqual(hist.argmin(), 182)
        # hist sum is not 1, since we excluded tempi < 10
        self.assertTrue(np.allclose(np.sum(hist), 0.9999768))
        self.assertTrue(np.allclose(np.mean(hist), 0.0041492814))
        self.assertTrue(np.allclose(np.median(hist), 7.891e-06))

    def test_tempo(self):
        tempo_processor = TempoEstimationProcessor(
            histogram_processor=self.processor, act_smooth=None)
        tempi = tempo_processor(self.act)
        self.assertEqual(tempi.shape, (14, 2))
        self.assertTrue(np.allclose(tempi[:, 0],
                                    [87, 174, 58, 41, 115, 132, 100,
                                     32, 197, 246, 228, 23, 214, 14]))
        self.assertTrue(np.allclose(np.sum(tempi[:, 1]), 1))

    def test_tempo_hist_smooth(self):
        tempo_processor = TempoEstimationProcessor(
            histogram_processor=self.processor, act_smooth=None,
            hist_smooth=15)
        tempi = tempo_processor(self.act)
        self.assertEqual(tempi.shape, (9, 2))
        self.assertTrue(np.allclose(tempi[:5], TCN_TEMPI, atol=0.01))