diff --git a/madmom/evaluation/key.py b/madmom/evaluation/key.py index b728b7fc..18f573c8 100644 --- a/madmom/evaluation/key.py +++ b/madmom/evaluation/key.py @@ -59,19 +59,95 @@ def key_label_to_class(key_label): return key_class -def error_type(det_key, ann_key, strict_fifth=False): +def key_class_to_root_and_mode(key_class): """ - Compute the evaluation score and error category for a predicted key - compared to the annotated key. + Extract the root and mode from a key class id + :param key_class: number + :type key_class: int + :return: root id in terms of semi-tones apart from C and + mode id (0: major; 1: minor) + :rtype: tuple(int, int) + """ + if 0 <= key_class <= 23: + root = key_class % 12 + mode = key_class // 12 + else: + raise ValueError("{} is outside the [0; 23] range]".format(key_class)) + return root, mode + + +def _compute_root_distance(det_root, ann_root): + return (det_root - ann_root) % 12 + + +def _is_correct(det_key, ann_key): + det_root, det_mode = key_class_to_root_and_mode(det_key) + ann_root, ann_mode = key_class_to_root_and_mode(ann_key) + return det_mode == ann_mode and det_root == ann_root + + +def _is_fifth(det_key, ann_key, strict_fifth): + det_root, det_mode = key_class_to_root_and_mode(det_key) + ann_root, ann_mode = key_class_to_root_and_mode(ann_key) + root_distance = _compute_root_distance(det_root, ann_root) + return det_mode == ann_mode and (root_distance == 7 or + (root_distance == 5 and not strict_fifth)) + + +def _is_parallel(det_key, ann_key): + det_root, det_mode = key_class_to_root_and_mode(det_key) + ann_root, ann_mode = key_class_to_root_and_mode(ann_key) + return det_root == ann_root and det_mode != ann_mode + + +def _is_relative(det_key, ann_key, major, minor): + det_root, det_mode = key_class_to_root_and_mode(det_key) + ann_root, ann_mode = key_class_to_root_and_mode(ann_key) + root_distance = _compute_root_distance(det_root, ann_root) + ann_mode_is_major = (ann_mode == major and root_distance == 9) + ann_mode_is_minor = (ann_mode == minor and root_distance == 3) + return det_mode != ann_mode and (ann_mode_is_major or ann_mode_is_minor) + + +def _is_relative_of_fifth(det_key, ann_key, major, minor, strict_fifth, + relative_of_fifth): + det_root, det_mode = key_class_to_root_and_mode(det_key) + ann_root, ann_mode = key_class_to_root_and_mode(ann_key) + root_distance = _compute_root_distance(det_root, ann_root) + distance_criterion = _compute_relative_of_fifth_distance_criterion( + root_distance, ann_mode, strict_fifth, minor, major) + return ann_mode != det_mode and relative_of_fifth and distance_criterion + + +def _compute_relative_of_fifth_distance_criterion(root_distance, ann_mode, + strict_fifth, minor, major): + ann_mode_is_major = ann_mode == major and ((root_distance == 4) or + (root_distance == 2 and + not strict_fifth)) + ann_mode_is_minor = ann_mode == minor and ((root_distance == 10) or + (root_distance == 8 and + not strict_fifth)) + return ann_mode_is_major or ann_mode_is_minor - Categories and evaluation scores follow the evaluation strategy used + +def error_type(det_key, ann_key, strict_fifth=False, relative_of_fifth=False): + """ + Compute the error category for a predicted key compared to + the annotated key. + + Categories follow the evaluation strategy used for MIREX (see http://music-ir.org/mirex/wiki/2017:Audio_Key_Detection). + There are two evaluation modes for the 'fifth' category: by default, a detection falls into the 'fifth' category if it is the fifth of the annotation, or the annotation is the fifth of the detection. If `strict_fifth` is `True`, only the former case is considered. This is the mode used for MIREX. + There is an optional category: 'relative of fifth'. This allows to separate + keys that are closely related to the annotated key on the circle of fifth + from the 'other' error category. + Parameters ---------- det_key : int @@ -80,36 +156,69 @@ def error_type(det_key, ann_key, strict_fifth=False): Annotated key class. strict_fifth: bool Use strict interpretation of the 'fifth' category, as in MIREX. + relative_of_fifth: bool + Differentiate relative keys of the fifth wrt the annotated key. + Is coherent with strict_fifth in the sense that it only considers the + relative key of the strict fifth. Returns ------- - score, category : float, str - Evaluation score and error category. + category : str + Error category. + Examples + -------- + >>> from madmom.evaluation.key import error_type + + # annotated: 'C major' / detected: 'C major' + >>> error_type(0, 0) + 'correct' + + # annotated: 'C major' / detected: 'G major': +7 semitones + >>> error_type(7, 0) + 'fifth' + + # annotated: 'C major' / detected: 'F major': -7 semitones (modulo 12) + >>> error_type(5, 0) + 'fifth' + + # annotated: 'C major' / detected: 'F major': -7 semitones (modulo 12), + # the MIREX way + >>> error_type(5, 0, strict_fifth=True) + 'other' + + # annotated: 'C major' / detected: 'E minor': E minor is the relative key + # of G Major, which is the fifth of C Major + >>> error_type(16, 0, relative_of_fifth=True) + 'relative_of_fifth' + + # annotated: 'C major' / detected: 'D minor': D minor is the relative key + # of F Major, of which C Major is the fifth + >>> error_type(14, 0, relative_of_fifth=True) + 'relative_of_fifth' + + # annotated: 'C major' / detected: 'D minor': D minor is the relative key + # of F Major, of which C Major is the fifth + # - using MIREX definition of 'fifth' + >>> error_type(14, 0, relative_of_fifth=True, strict_fifth=True) + 'other' """ - ann_root = ann_key % 12 - ann_mode = ann_key // 12 - det_root = det_key % 12 - det_mode = det_key // 12 major, minor = 0, 1 - if det_root == ann_root and det_mode == ann_mode: - return 1.0, 'correct' - if det_mode == ann_mode and ((det_root - ann_root) % 12 == 7): - return 0.5, 'fifth' - if not strict_fifth and (det_mode == ann_mode and - ((det_root - ann_root) % 12 == 5)): - return 0.5, 'fifth' - if (ann_mode == major and det_mode != ann_mode and ( - (det_root - ann_root) % 12 == 9)): - return 0.3, 'relative' - if (ann_mode == minor and det_mode != ann_mode and ( - (det_root - ann_root) % 12 == 3)): - return 0.3, 'relative' - if det_mode != ann_mode and det_root == ann_root: - return 0.2, 'parallel' + if _is_correct(det_key, ann_key): + error_type = 'correct' + elif _is_fifth(det_key, ann_key, strict_fifth): + error_type = 'fifth' + elif _is_parallel(det_key, ann_key): + error_type = 'parallel' + elif _is_relative(det_key, ann_key, major, minor): + error_type = 'relative' + elif _is_relative_of_fifth(det_key, ann_key, major, minor, strict_fifth, + relative_of_fifth): + error_type = 'relative_of_fifth' else: - return 0.0, 'other' + error_type = 'other' + return error_type class KeyEvaluation(EvaluationMixin): @@ -119,11 +228,13 @@ class KeyEvaluation(EvaluationMixin): Parameters ---------- detection : str - File containing detected key + File containing detected key. annotation : str - File containing annotated key + File containing annotated key. strict_fifth : bool, optional Use strict interpretation of the 'fifth' category, as in MIREX. + relative_of_fifth: bool, optional + Consider relative of the fifth in the evaluation. name : str, optional Name of the evaluation object (e.g., the name of the song). @@ -134,14 +245,29 @@ class KeyEvaluation(EvaluationMixin): ('error_category', 'Error Category') ] - def __init__(self, detection, annotation, strict_fifth=False, name=None, + error_scores = {'correct': 1.0, + 'fifth': 0.5, + 'relative': 0.3, + 'relative_of_fifth': 0.0, + 'parallel': 0.2, + 'other': 0.0} + + def __init__(self, detection, + annotation, + strict_fifth=False, + name=None, + relative_of_fifth=False, **kwargs): self.name = name or '' self.detection = key_label_to_class(detection) self.annotation = key_label_to_class(annotation) - self.score, self.error_category = error_type( - self.detection, self.annotation, strict_fifth - ) + self.strict_fifth = strict_fifth + self.relative_of_fifth = relative_of_fifth + self.error_category = error_type(self.detection, + self.annotation, + self.strict_fifth, + self.relative_of_fifth) + self.score = self.error_scores[self.error_category] def tostring(self, **kwargs): """ @@ -175,6 +301,7 @@ class KeyMeanEvaluation(EvaluationMixin): ('correct', 'Correct'), ('fifth', 'Fifth'), ('relative', 'Relative'), + ('relative_of_fifth', 'Relative of Fifth'), ('parallel', 'Parallel'), ('other', 'Other'), ('weighted', 'Weighted'), @@ -182,22 +309,62 @@ class KeyMeanEvaluation(EvaluationMixin): def __init__(self, eval_objects, name=None): self.name = name or 'mean for {:d} files'.format(len(eval_objects)) + if _check_key_eval_objects(eval_objects): + self._count_evaluations(eval_objects) + else: + raise ValueError('The KeyEvaluation objects are not ' + 'all the same.') + def _count_evaluations(self, eval_objects): n = len(eval_objects) c = Counter(e.error_category for e in eval_objects) - self.correct = float(c['correct']) / n self.fifth = float(c['fifth']) / n self.relative = float(c['relative']) / n self.parallel = float(c['parallel']) / n self.other = float(c['other']) / n + if 'relative_of_fifth' in c.keys(): + self.relative_of_fifth = float(c['relative_of_fifth']) / n + else: + self.relative_of_fifth = None self.weighted = sum(e.score for e in eval_objects) / n def tostring(self, **kwargs): - return ('{}\n Weighted: {:.3f} Correct: {:.3f} Fifth: {:.3f} ' - 'Relative: {:.3f} Parallel: {:.3f} Other: {:.3f}'.format( - self.name, self.weighted, self.correct, self.fifth, - self.relative, self.parallel, self.other)) + ret = '' + spacing = ' ' * 2 + ret += '{}\n'.format(self.name) + spacing + ret += 'Weighted: {:.3f}'.format(self.weighted) + spacing + ret += 'Correct: {:.3f}'.format(self.correct) + spacing + ret += 'Fifth: {:.3f}'.format(self.fifth) + spacing + ret += 'Relative: {:.3f}'.format(self.relative) + spacing + if self.relative_of_fifth: + ret += 'Relative of fifth: {:.3f}'.format(self.relative_of_fifth) \ + + spacing + ret += 'Parallel: {:.3f}'.format(self.parallel) + spacing + ret += 'Other: {:.3f}'.format(self.other) + return ret + + +def _check_key_eval_objects(key_eval_objects): + """ + Check whether all the key evaluation objects in a list have the same way of + scoring errors + + Parameters: + ---------- + key_eval_objects: list + Key evaluation objects + """ + if len(key_eval_objects) > 0: + e = key_eval_objects[0] + for ke in key_eval_objects[1:]: + if (ke.error_scores != e.error_scores or + ke.strict_fifth != e.strict_fifth or + ke.relative_of_fifth != e.relative_of_fifth): + return False + return True + else: + raise ValueError('No KeyEvaluation objects to check.') def add_parser(parser): diff --git a/tests/data/detections/dummy.correct.key.txt b/tests/data/detections/dummy.correct.key.txt new file mode 100644 index 00000000..700bb93b --- /dev/null +++ b/tests/data/detections/dummy.correct.key.txt @@ -0,0 +1 @@ +f# minor \ No newline at end of file diff --git a/tests/data/detections/dummy.fifth.key.txt b/tests/data/detections/dummy.fifth.key.txt new file mode 100644 index 00000000..07413433 --- /dev/null +++ b/tests/data/detections/dummy.fifth.key.txt @@ -0,0 +1 @@ +B minor \ No newline at end of file diff --git a/tests/data/detections/dummy.other.key.txt b/tests/data/detections/dummy.other.key.txt new file mode 100644 index 00000000..45f4b694 --- /dev/null +++ b/tests/data/detections/dummy.other.key.txt @@ -0,0 +1 @@ +G# minor \ No newline at end of file diff --git a/tests/data/detections/dummy.parallel.key.txt b/tests/data/detections/dummy.parallel.key.txt new file mode 100644 index 00000000..d1199919 --- /dev/null +++ b/tests/data/detections/dummy.parallel.key.txt @@ -0,0 +1 @@ +F# maj \ No newline at end of file diff --git a/tests/data/detections/dummy.relative_of_fifth.key.txt b/tests/data/detections/dummy.relative_of_fifth.key.txt new file mode 100644 index 00000000..db05a7ec --- /dev/null +++ b/tests/data/detections/dummy.relative_of_fifth.key.txt @@ -0,0 +1 @@ +e maj \ No newline at end of file diff --git a/tests/test_evaluation_key.py b/tests/test_evaluation_key.py index f0dc0fb1..534e3930 100644 --- a/tests/test_evaluation_key.py +++ b/tests/test_evaluation_key.py @@ -59,69 +59,126 @@ def test_values(self): key_label_to_class('F# major')) +class TestKeyClassToRootAndModeFunction(unittest.TestCase): + + def test_values(self): + self.assertEqual(key_class_to_root_and_mode(0), (0, 0)) + self.assertEqual(key_class_to_root_and_mode(1), (1, 0)) + self.assertEqual(key_class_to_root_and_mode(2), (2, 0)) + self.assertEqual(key_class_to_root_and_mode(3), (3, 0)) + self.assertEqual(key_class_to_root_and_mode(4), (4, 0)) + self.assertEqual(key_class_to_root_and_mode(5), (5, 0)) + self.assertEqual(key_class_to_root_and_mode(6), (6, 0)) + self.assertEqual(key_class_to_root_and_mode(7), (7, 0)) + self.assertEqual(key_class_to_root_and_mode(8), (8, 0)) + self.assertEqual(key_class_to_root_and_mode(9), (9, 0)) + self.assertEqual(key_class_to_root_and_mode(10), (10, 0)) + self.assertEqual(key_class_to_root_and_mode(11), (11, 0)) + self.assertEqual(key_class_to_root_and_mode(12), (0, 1)) + self.assertEqual(key_class_to_root_and_mode(13), (1, 1)) + self.assertEqual(key_class_to_root_and_mode(14), (2, 1)) + self.assertEqual(key_class_to_root_and_mode(15), (3, 1)) + self.assertEqual(key_class_to_root_and_mode(16), (4, 1)) + self.assertEqual(key_class_to_root_and_mode(17), (5, 1)) + self.assertEqual(key_class_to_root_and_mode(18), (6, 1)) + self.assertEqual(key_class_to_root_and_mode(19), (7, 1)) + self.assertEqual(key_class_to_root_and_mode(20), (8, 1)) + self.assertEqual(key_class_to_root_and_mode(21), (9, 1)) + self.assertEqual(key_class_to_root_and_mode(22), (10, 1)) + self.assertEqual(key_class_to_root_and_mode(23), (11, 1)) + with self.assertRaises(ValueError): + key_class_to_root_and_mode(-4) + with self.assertRaises(ValueError): + key_class_to_root_and_mode(24) + + class TestErrorTypeFunction(unittest.TestCase): - def _compare_scores(self, correct, fifth_strict, fifth_lax, relative, - parallel): + def _compare_error_types(self, correct, fifth_strict, fifth_lax, relative, + relative_of_fifth_up, relative_of_fifth_down, + parallel): for det_key in range(24): - score, cat = error_type(det_key, correct) - score_st, cat_st = error_type(det_key, correct, strict_fifth=True) + cat = error_type(det_key, correct) + cat_st = error_type(det_key, correct, strict_fifth=True) + cat_rf = error_type(det_key, correct, relative_of_fifth=True) + cat_st_rf = error_type(det_key, + correct, + strict_fifth=True, + relative_of_fifth=True) if det_key == correct: self.assertEqual(cat, 'correct') - self.assertEqual(score, 1.0) self.assertEqual(cat_st, cat) - self.assertEqual(score_st, score) - if det_key == fifth_strict: + self.assertEqual(cat_rf, cat) + self.assertEqual(cat_st_rf, cat) + elif det_key == fifth_strict: self.assertEqual(cat, 'fifth') - self.assertEqual(score, 0.5) self.assertEqual(cat_st, cat) - self.assertEqual(score_st, score) - if det_key == fifth_lax: + self.assertEqual(cat_rf, cat) + self.assertEqual(cat_st_rf, cat) + elif det_key == fifth_lax: self.assertEqual(cat, 'fifth') - self.assertEqual(score, 0.5) self.assertEqual(cat_st, 'other') - self.assertEqual(score_st, 0.0) - if det_key == relative: + self.assertEqual(cat_rf, 'fifth') + self.assertEqual(cat_st_rf, 'other') + elif det_key == relative: self.assertEqual(cat, 'relative') - self.assertEqual(score, 0.3) self.assertEqual(cat_st, cat) - self.assertEqual(score_st, score) - if det_key == parallel: + self.assertEqual(cat_rf, cat) + self.assertEqual(cat_st_rf, cat) + elif det_key == relative_of_fifth_down: + self.assertEqual(cat, 'other') + self.assertEqual(cat_st, cat) + self.assertEqual(cat_rf, 'relative_of_fifth') + self.assertEqual(cat_st_rf, cat) + elif det_key == relative_of_fifth_up: + self.assertEqual(cat, 'other') + self.assertEqual(cat_st, cat) + self.assertEqual(cat_rf, 'relative_of_fifth') + self.assertEqual(cat_st_rf, 'relative_of_fifth') + elif det_key == parallel: self.assertEqual(cat, 'parallel') - self.assertEqual(score, 0.2) self.assertEqual(cat_st, cat) - self.assertEqual(score_st, score) + self.assertEqual(cat_rf, cat) + self.assertEqual(cat_st_rf, cat) def test_values(self): - self._compare_scores( + self._compare_error_types( correct=key_label_to_class('c maj'), fifth_strict=key_label_to_class('g maj'), fifth_lax=key_label_to_class('f maj'), relative=key_label_to_class('a min'), + relative_of_fifth_up=key_label_to_class('e min'), + relative_of_fifth_down=key_label_to_class('d min'), parallel=key_label_to_class('c min') ) - self._compare_scores( + self._compare_error_types( correct=key_label_to_class('eb maj'), fifth_strict=key_label_to_class('bb maj'), fifth_lax=key_label_to_class('ab maj'), relative=key_label_to_class('c min'), + relative_of_fifth_up=key_label_to_class('g min'), + relative_of_fifth_down=key_label_to_class('f min'), parallel=key_label_to_class('eb min') ) - self._compare_scores( + self._compare_error_types( correct=key_label_to_class('a min'), fifth_strict=key_label_to_class('e min'), fifth_lax=key_label_to_class('d min'), relative=key_label_to_class('c maj'), + relative_of_fifth_up=key_label_to_class('g maj'), + relative_of_fifth_down=key_label_to_class('f maj'), parallel=key_label_to_class('a maj') ) - self._compare_scores( + self._compare_error_types( correct=key_label_to_class('b min'), fifth_strict=key_label_to_class('gb min'), fifth_lax=key_label_to_class('e min'), relative=key_label_to_class('d maj'), + relative_of_fifth_up=key_label_to_class('a maj'), + relative_of_fifth_down=key_label_to_class('g maj'), parallel=key_label_to_class('b maj') ) @@ -129,46 +186,179 @@ def test_values(self): class TestKeyEvaluationClass(unittest.TestCase): def setUp(self): - self.eval = KeyEvaluation( + # this one should have a score of 1 + self.eval_correct = KeyEvaluation( + load_key(join(DETECTIONS_PATH, 'dummy.correct.key.txt')), + load_key(join(ANNOTATIONS_PATH, 'dummy.key')), + name='eval_correct' + ) + # this one should have a score of 0.5 + self.eval_fifth = KeyEvaluation( + load_key(join(DETECTIONS_PATH, 'dummy.fifth.key.txt')), + load_key(join(ANNOTATIONS_PATH, 'dummy.key')), + name='eval_fifth' + ) + + # this one should have a score of 0.3 + self.eval_relative = KeyEvaluation( load_key(join(DETECTIONS_PATH, 'dummy.key.txt')), load_key(join(ANNOTATIONS_PATH, 'dummy.key')), - name='TestEval' + name='eval_relative' + ) + # this one should have a score of 0.2 + self.eval_parallel = KeyEvaluation( + load_key(join(DETECTIONS_PATH, 'dummy.parallel.key.txt')), + load_key(join(ANNOTATIONS_PATH, 'dummy.key')), + name='eval_parallel' + ) + # this one should have a score of 0.0 + self.eval_relative_of_fifth = KeyEvaluation( + load_key(join(DETECTIONS_PATH, 'dummy.relative_of_fifth.key.txt')), + load_key(join(ANNOTATIONS_PATH, 'dummy.key')), + relative_of_fifth=True, + name='eval_relative_of_fifth' + ) + # this one should have a score of 0.0 + self.eval_other = KeyEvaluation( + load_key(join(DETECTIONS_PATH, 'dummy.other.key.txt')), + load_key(join(ANNOTATIONS_PATH, 'dummy.key')), + name='eval_other' ) def test_init(self): - self.assertTrue(self.eval.name == 'TestEval') - self.assertTrue(self.eval.detection, 9) - self.assertTrue(self.eval.annotation, 18) + self.assertTrue(self.eval_relative.name == 'eval_relative') + self.assertTrue(self.eval_relative.detection, 9) + self.assertTrue(self.eval_relative.annotation, 18) def test_results(self): - self.assertEqual(self.eval.error_category, 'relative') - self.assertEqual(self.eval.score, 0.3) + # Correct + self.assertEqual(self.eval_correct.error_category, 'correct') + self.assertEqual(self.eval_correct.score, 1.0) + # Fifth + self.assertEqual(self.eval_fifth.error_category, 'fifth') + self.assertEqual(self.eval_fifth.score, 0.5) + # Relative + self.assertEqual(self.eval_relative.error_category, 'relative') + self.assertEqual(self.eval_relative.score, 0.3) + # Relative of Fifth + self.assertEqual(self.eval_relative_of_fifth.error_category, + 'relative_of_fifth') + self.assertEqual(self.eval_relative_of_fifth.score, 0.0) + # Parallel + self.assertEqual(self.eval_parallel.error_category, 'parallel') + self.assertEqual(self.eval_parallel.score, 0.2) + # Other + self.assertEqual(self.eval_other.error_category, 'other') + self.assertEqual(self.eval_other.score, 0.0) class TestKeyMeanEvaluation(unittest.TestCase): def setUp(self): # this one should have a score of 1 - self.eval1 = KeyEvaluation( - load_key(join(DETECTIONS_PATH, 'dummy.key.txt')), - load_key(join(DETECTIONS_PATH, 'dummy.key.txt')), - name='eval1' + self.eval_correct = KeyEvaluation( + load_key(join(DETECTIONS_PATH, 'dummy.correct.key.txt')), + load_key(join(ANNOTATIONS_PATH, 'dummy.key')), + name='eval_correct' ) - # this one should have a score of 0.3 - self.eval2 = KeyEvaluation( + # this one should have a score of 0.2 + self.eval_parallel = KeyEvaluation( + load_key(join(DETECTIONS_PATH, 'dummy.parallel.key.txt')), + load_key(join(ANNOTATIONS_PATH, 'dummy.key')), + name='eval_parallel' + ) + # this one should have a score of 0.0 + self.eval_relative = KeyEvaluation( load_key(join(DETECTIONS_PATH, 'dummy.key.txt')), load_key(join(ANNOTATIONS_PATH, 'dummy.key')), - name='eval2' + name='eval_relative' + ) + # this one should have a score of 0.0 + self.eval_other = KeyEvaluation( + load_key(join(DETECTIONS_PATH, 'dummy.other.key.txt')), + load_key(join(ANNOTATIONS_PATH, 'dummy.key')), + name='eval_other' + ) + # this one has has the same key BUT a different set of error scores + self.eval_different_scores = KeyEvaluation( + load_key(join(DETECTIONS_PATH, 'dummy.correct.key.txt')), + load_key(join(ANNOTATIONS_PATH, 'dummy.key')), + name='eval_correct_different_scores' + ) + self.eval_different_scores.error_scores = {'correct': 0.5} + + self.eval_correct_w_rel_of_fifth = KeyEvaluation( + load_key(join(DETECTIONS_PATH, 'dummy.correct.key.txt')), + load_key(join(ANNOTATIONS_PATH, 'dummy.key')), + relative_of_fifth=True, + name='eval_correct_w_rel_of_fifth' + ) + + self.eval_rel_of_fifth = KeyEvaluation( + load_key(join(DETECTIONS_PATH, 'dummy.relative_of_fifth.key.txt')), + load_key(join(ANNOTATIONS_PATH, 'dummy.key')), + relative_of_fifth=True, + name='eval_rel_of_fifth' ) + def test_check_key_eval_objects(self): + evals = [self.eval_correct, self.eval_parallel, + self.eval_different_scores, self.eval_other] + with self.assertRaises(ValueError): + KeyMeanEvaluation(evals) + + evals = [self.eval_correct, self.eval_parallel, + self.eval_rel_of_fifth] + with self.assertRaises(ValueError): + KeyMeanEvaluation(evals) + + def test_empty_eval_list(self): + with self.assertRaises(ValueError): + KeyMeanEvaluation([]) + def test_mean_results(self): - mean_eval = KeyMeanEvaluation([self.eval1, self.eval2]) - self.assertAlmostEqual(mean_eval.correct, 0.5) - self.assertAlmostEqual(mean_eval.fifth, 0.) - self.assertAlmostEqual(mean_eval.relative, 0.5) - self.assertAlmostEqual(mean_eval.parallel, 0.0) - self.assertAlmostEqual(mean_eval.other, 0.0) - self.assertAlmostEqual(mean_eval.weighted, 0.65) + evals = [self.eval_correct, self.eval_parallel, self.eval_relative, + self.eval_other] + + mean_eval = KeyMeanEvaluation(evals) + + self.assertAlmostEqual(mean_eval.correct, 1.0 / len(evals)) + self.assertAlmostEqual(mean_eval.fifth, 0.0) + self.assertAlmostEqual(mean_eval.relative, 1.0 / len(evals)) + self.assertAlmostEqual(mean_eval.parallel, 1.0 / len(evals)) + self.assertAlmostEqual(mean_eval.other, 1.0 / len(evals)) + self.assertAlmostEqual(mean_eval.weighted, 0.375) + self.assertEqual(mean_eval.tostring(), + 'mean for 4 files\n ' + 'Weighted: 0.375 ' + 'Correct: 0.250 ' + 'Fifth: 0.000 ' + 'Relative: 0.250 ' + 'Parallel: 0.250 ' + 'Other: 0.250') + + def test_mean_results_w_rel_of_fifth(self): + evals = [self.eval_correct_w_rel_of_fifth, + self.eval_rel_of_fifth] + + mean_eval = KeyMeanEvaluation(evals, name='Jean-Guy') + + self.assertAlmostEqual(mean_eval.correct, 1.0 / len(evals)) + self.assertAlmostEqual(mean_eval.fifth, 0.0) + self.assertAlmostEqual(mean_eval.relative, 0.0) + self.assertAlmostEqual(mean_eval.relative_of_fifth, 1.0 / len(evals)) + self.assertAlmostEqual(mean_eval.parallel, 0.0 / len(evals)) + self.assertAlmostEqual(mean_eval.other, 0.0 / len(evals)) + self.assertAlmostEqual(mean_eval.weighted, 0.5) + self.assertEqual(mean_eval.tostring(), + 'Jean-Guy\n ' + 'Weighted: 0.500 ' + 'Correct: 0.500 ' + 'Fifth: 0.000 ' + 'Relative: 0.000 ' + 'Relative of fifth: 0.500 ' + 'Parallel: 0.000 ' + 'Other: 0.000') class TestAddParserFunction(unittest.TestCase): diff --git a/tests/test_utils.py b/tests/test_utils.py index d9b8becb..e0218463 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -71,6 +71,11 @@ DETECTION_FILES = [pj(DETECTIONS_PATH, 'dummy.chords.txt'), pj(DETECTIONS_PATH, 'dummy.key.txt'), + pj(DETECTIONS_PATH, 'dummy.correct.key.txt'), + pj(DETECTIONS_PATH, 'dummy.fifth.key.txt'), + pj(DETECTIONS_PATH, 'dummy.other.key.txt'), + pj(DETECTIONS_PATH, 'dummy.parallel.key.txt'), + pj(DETECTIONS_PATH, 'dummy.relative_of_fifth.key.txt'), pj(DETECTIONS_PATH, 'sample.beat_detector.txt'), pj(DETECTIONS_PATH, 'sample.beat_tracker.txt'), pj(DETECTIONS_PATH, 'sample.cnn_chord_recognition.txt'),